diff --git a/wepoll.c b/wepoll.c index 9504c30..11a440f 100644 --- a/wepoll.c +++ b/wepoll.c @@ -126,7 +126,7 @@ WEPOLL_EXPORT int epoll_wait(HANDLE ephnd, #pragma clang diagnostic ignored "-Wreserved-id-macro" #endif -#if defined(_WIN32_WINNT) +#ifdef _WIN32_WINNT #undef _WIN32_WINNT #endif @@ -170,10 +170,7 @@ typedef NTSTATUS* PNTSTATUS; #endif typedef struct _IO_STATUS_BLOCK { - union { - NTSTATUS Status; - PVOID Pointer; - }; + NTSTATUS Status; ULONG_PTR Information; } IO_STATUS_BLOCK, *PIO_STATUS_BLOCK; @@ -187,6 +184,9 @@ typedef struct _LSA_UNICODE_STRING { PWSTR Buffer; } LSA_UNICODE_STRING, *PLSA_UNICODE_STRING, UNICODE_STRING, *PUNICODE_STRING; +#define RTL_CONSTANT_STRING(s) \ + { sizeof(s) - sizeof((s)[0]), sizeof(s), s } + typedef struct _OBJECT_ATTRIBUTES { ULONG Length; HANDLE RootDirectory; @@ -196,7 +196,29 @@ typedef struct _OBJECT_ATTRIBUTES { PVOID SecurityQualityOfService; } OBJECT_ATTRIBUTES, *POBJECT_ATTRIBUTES; -#define NTDLL_IMPORT_LIST(X) \ +#define RTL_CONSTANT_OBJECT_ATTRIBUTES(ObjectName, Attributes) \ + { sizeof(OBJECT_ATTRIBUTES), NULL, ObjectName, Attributes, NULL, NULL } + +#ifndef FILE_OPEN +#define FILE_OPEN 0x00000001UL +#endif + +#define NT_NTDLL_IMPORT_LIST(X) \ + X(NTSTATUS, \ + NTAPI, \ + NtCreateFile, \ + (PHANDLE FileHandle, \ + ACCESS_MASK DesiredAccess, \ + POBJECT_ATTRIBUTES ObjectAttributes, \ + PIO_STATUS_BLOCK IoStatusBlock, \ + PLARGE_INTEGER AllocationSize, \ + ULONG FileAttributes, \ + ULONG ShareAccess, \ + ULONG CreateDisposition, \ + ULONG CreateOptions, \ + PVOID EaBuffer, \ + ULONG EaLength)) \ + \ X(NTSTATUS, \ NTAPI, \ NtDeviceIoControlFile, \ @@ -233,7 +255,7 @@ typedef struct _OBJECT_ATTRIBUTES { #define X(return_type, attributes, name, parameters) \ WEPOLL_INTERNAL_VAR return_type(attributes* name) parameters; -NTDLL_IMPORT_LIST(X) +NT_NTDLL_IMPORT_LIST(X) #undef X #include @@ -282,12 +304,10 @@ typedef struct _AFD_POLL_INFO { AFD_POLL_HANDLE_INFO Handles[1]; } AFD_POLL_INFO, *PAFD_POLL_INFO; -WEPOLL_INTERNAL int afd_global_init(void); +WEPOLL_INTERNAL int afd_create_helper_handle(HANDLE iocp, + HANDLE* afd_helper_handle_out); -WEPOLL_INTERNAL int afd_create_driver_socket(HANDLE iocp, - SOCKET* driver_socket_out); - -WEPOLL_INTERNAL int afd_poll(SOCKET driver_socket, +WEPOLL_INTERNAL int afd_poll(HANDLE afd_helper_handle, AFD_POLL_INFO* poll_info, OVERLAPPED* overlapped); @@ -308,127 +328,54 @@ WEPOLL_INTERNAL void err_set_win_error(DWORD error); WEPOLL_INTERNAL int err_check_handle(HANDLE handle); WEPOLL_INTERNAL int ws_global_init(void); - WEPOLL_INTERNAL SOCKET ws_get_base_socket(SOCKET socket); -WEPOLL_INTERNAL int ws_get_protocol_catalog(WSAPROTOCOL_INFOW** infos_out, - size_t* infos_count_out); #define IOCTL_AFD_POLL 0x00012024 -/* clang-format off */ -static const GUID AFD__PROVIDER_GUID_LIST[] = { - /* MSAFD Tcpip [TCP+UDP+RAW / IP] */ - {0xe70f1aa0, 0xab8b, 0x11cf, - {0x8c, 0xa3, 0x00, 0x80, 0x5f, 0x48, 0xa1, 0x92}}, - /* MSAFD Tcpip [TCP+UDP+RAW / IPv6] */ - {0xf9eab0c0, 0x26d4, 0x11d0, - {0xbb, 0xbf, 0x00, 0xaa, 0x00, 0x6c, 0x34, 0xe4}}, - /* MSAFD RfComm [Bluetooth] */ - {0x9fc48064, 0x7298, 0x43e4, - {0xb7, 0xbd, 0x18, 0x1f, 0x20, 0x89, 0x79, 0x2a}}, - /* MSAFD Irda [IrDA] */ - {0x3972523d, 0x2af1, 0x11d1, - {0xb6, 0x55, 0x00, 0x80, 0x5f, 0x36, 0x42, 0xcc}}}; -/* clang-format on */ +static UNICODE_STRING afd__helper_name = + RTL_CONSTANT_STRING(L"\\Device\\Afd\\Wepoll"); -static const int AFD__ANY_PROTOCOL = -1; +static OBJECT_ATTRIBUTES afd__helper_attributes = + RTL_CONSTANT_OBJECT_ATTRIBUTES(&afd__helper_name, 0); -/* This protocol info record is used by afd_create_driver_socket() to create - * sockets that can be used as the first argument to afd_poll(). It is - * populated on startup by afd_global_init(). */ -static WSAPROTOCOL_INFOW afd__driver_socket_protocol_info; +int afd_create_helper_handle(HANDLE iocp, HANDLE* afd_helper_handle_out) { + HANDLE afd_helper_handle; + IO_STATUS_BLOCK iosb; + NTSTATUS status; -static const WSAPROTOCOL_INFOW* afd__find_protocol_info( - const WSAPROTOCOL_INFOW* infos, size_t infos_count, int protocol_id) { - size_t i, j; + /* By opening \Device\Afd without specifying any extended attributes, we'll + * get a handle that lets us talk to the AFD driver, but that doesn't have an + * associated endpoint (so it's not a socket). */ + status = NtCreateFile(&afd_helper_handle, + SYNCHRONIZE, + &afd__helper_attributes, + &iosb, + NULL, + 0, + FILE_SHARE_READ | FILE_SHARE_WRITE, + FILE_OPEN, + 0, + NULL, + 0); + if (status != STATUS_SUCCESS) + return_set_error(-1, RtlNtStatusToDosError(status)); - for (i = 0; i < infos_count; i++) { - const WSAPROTOCOL_INFOW* info = &infos[i]; - - /* Apply protocol id filter. */ - if (protocol_id != AFD__ANY_PROTOCOL && protocol_id != info->iProtocol) - continue; - - /* Filter out non-MSAFD protocols. */ - for (j = 0; j < array_count(AFD__PROVIDER_GUID_LIST); j++) { - if (memcmp(&info->ProviderId, - &AFD__PROVIDER_GUID_LIST[j], - sizeof info->ProviderId) == 0) - return info; - } - } - - return NULL; /* Not found. */ -} - -int afd_global_init(void) { - WSAPROTOCOL_INFOW* infos; - size_t infos_count; - const WSAPROTOCOL_INFOW* afd_info; - - /* Load the winsock catalog. */ - if (ws_get_protocol_catalog(&infos, &infos_count) < 0) - return -1; - - /* Find a WSAPROTOCOL_INFOW structure that we can use to create an MSAFD - * socket. Preferentially we pick a UDP socket, otherwise try TCP or any - * other type. */ - for (;;) { - afd_info = afd__find_protocol_info(infos, infos_count, IPPROTO_UDP); - if (afd_info != NULL) - break; - - afd_info = afd__find_protocol_info(infos, infos_count, IPPROTO_TCP); - if (afd_info != NULL) - break; - - afd_info = afd__find_protocol_info(infos, infos_count, AFD__ANY_PROTOCOL); - if (afd_info != NULL) - break; - - free(infos); - return_set_error(-1, WSAENETDOWN); /* No suitable protocol found. */ - } - - /* Copy found protocol information from the catalog to a static buffer. */ - afd__driver_socket_protocol_info = *afd_info; - - free(infos); - return 0; -} - -int afd_create_driver_socket(HANDLE iocp, SOCKET* driver_socket_out) { - SOCKET socket; - - socket = WSASocketW(afd__driver_socket_protocol_info.iAddressFamily, - afd__driver_socket_protocol_info.iSocketType, - afd__driver_socket_protocol_info.iProtocol, - &afd__driver_socket_protocol_info, - 0, - WSA_FLAG_OVERLAPPED); - if (socket == INVALID_SOCKET) - return_map_error(-1); - - /* TODO: use WSA_FLAG_NOINHERIT on Windows versions that support it. */ - if (!SetHandleInformation((HANDLE) socket, HANDLE_FLAG_INHERIT, 0)) + if (CreateIoCompletionPort(afd_helper_handle, iocp, 0, 0) == NULL) goto error; - if (CreateIoCompletionPort((HANDLE) socket, iocp, 0, 0) == NULL) - goto error; - - if (!SetFileCompletionNotificationModes((HANDLE) socket, + if (!SetFileCompletionNotificationModes(afd_helper_handle, FILE_SKIP_SET_EVENT_ON_HANDLE)) goto error; - *driver_socket_out = socket; + *afd_helper_handle_out = afd_helper_handle; return 0; error: - closesocket(socket); + CloseHandle(afd_helper_handle); return_map_error(-1); } -int afd_poll(SOCKET driver_socket, +int afd_poll(HANDLE afd_helper_handle, AFD_POLL_INFO* poll_info, OVERLAPPED* overlapped) { IO_STATUS_BLOCK* iosb; @@ -452,7 +399,7 @@ int afd_poll(SOCKET driver_socket, } iosb->Status = STATUS_PENDING; - status = NtDeviceIoControlFile((HANDLE) driver_socket, + status = NtDeviceIoControlFile(afd_helper_handle, event, NULL, apc_context, @@ -513,7 +460,8 @@ WEPOLL_INTERNAL void poll_group_delete(poll_group_t* poll_group); WEPOLL_INTERNAL poll_group_t* poll_group_from_queue_node( queue_node_t* queue_node); -WEPOLL_INTERNAL SOCKET poll_group_get_socket(poll_group_t* poll_group); +WEPOLL_INTERNAL HANDLE + poll_group_get_afd_helper_handle(poll_group_t* poll_group); /* N.b.: the tree functions do not set errno or LastError when they fail. Each * of the API functions has at most one failure mode. It is up to the caller to @@ -961,7 +909,7 @@ static BOOL CALLBACK init__once_callback(INIT_ONCE* once, unused_var(context); /* N.b. that initialization order matters here. */ - if (ws_global_init() < 0 || nt_global_init() < 0 || afd_global_init() < 0 || + if (ws_global_init() < 0 || nt_global_init() < 0 || reflock_global_init() < 0 || epoll_global_init() < 0) return FALSE; @@ -979,7 +927,7 @@ int init(void) { #define X(return_type, attributes, name, parameters) \ WEPOLL_INTERNAL return_type(attributes* name) parameters = NULL; -NTDLL_IMPORT_LIST(X) +NT_NTDLL_IMPORT_LIST(X) #undef X int nt_global_init(void) { @@ -993,7 +941,7 @@ int nt_global_init(void) { name = (return_type(attributes*) parameters) GetProcAddress(ntdll, #name); \ if (name == NULL) \ return -1; - NTDLL_IMPORT_LIST(X) + NT_NTDLL_IMPORT_LIST(X) #undef X return 0; @@ -1006,7 +954,7 @@ static const size_t POLL_GROUP__MAX_GROUP_SIZE = 32; typedef struct poll_group { port_state_t* port_state; queue_node_t queue_node; - SOCKET socket; + HANDLE afd_helper_handle; size_t group_size; } poll_group_t; @@ -1020,7 +968,8 @@ static poll_group_t* poll_group__new(port_state_t* port_state) { queue_node_init(&poll_group->queue_node); poll_group->port_state = port_state; - if (afd_create_driver_socket(port_state->iocp, &poll_group->socket) < 0) { + if (afd_create_helper_handle(port_state->iocp, + &poll_group->afd_helper_handle) < 0) { free(poll_group); return NULL; } @@ -1032,7 +981,7 @@ static poll_group_t* poll_group__new(port_state_t* port_state) { void poll_group_delete(poll_group_t* poll_group) { assert(poll_group->group_size == 0); - closesocket(poll_group->socket); + CloseHandle(poll_group->afd_helper_handle); queue_remove(&poll_group->queue_node); free(poll_group); } @@ -1041,8 +990,8 @@ poll_group_t* poll_group_from_queue_node(queue_node_t* queue_node) { return container_of(queue_node, poll_group_t, queue_node); } -SOCKET poll_group_get_socket(poll_group_t* poll_group) { - return poll_group->socket; +HANDLE poll_group_get_afd_helper_handle(poll_group_t* poll_group) { + return poll_group->afd_helper_handle; } poll_group_t* poll_group_acquire(port_state_t* port_state) { @@ -1621,13 +1570,13 @@ static inline void sock__free(sock_state_t* sock_state) { } static int sock__cancel_poll(sock_state_t* sock_state) { - HANDLE driver_handle = - (HANDLE)(uintptr_t) poll_group_get_socket(sock_state->poll_group); + HANDLE afd_helper_handle = + poll_group_get_afd_helper_handle(sock_state->poll_group); assert(sock_state->poll_status == SOCK__POLL_PENDING); /* CancelIoEx() may fail with ERROR_NOT_FOUND if the overlapped operation has * already completed. This is not a problem and we proceed normally. */ - if (!CancelIoEx(driver_handle, &sock_state->overlapped) && + if (!CancelIoEx(afd_helper_handle, &sock_state->overlapped) && GetLastError() != ERROR_NOT_FOUND) return_map_error(-1); @@ -1806,7 +1755,7 @@ int sock_update(port_state_t* port_state, sock_state_t* sock_state) { memset(&sock_state->overlapped, 0, sizeof sock_state->overlapped); - if (afd_poll(poll_group_get_socket(sock_state->poll_group), + if (afd_poll(poll_group_get_afd_helper_handle(sock_state->poll_group), &sock_state->poll_info, &sock_state->overlapped) < 0) { switch (GetLastError()) { @@ -2191,8 +2140,6 @@ tree_node_t* tree_root(const tree_t* tree) { #define SIO_BASE_HANDLE 0x48000022 #endif -#define WS__INITIAL_CATALOG_BUFFER_SIZE 0x4000 /* 16kb. */ - int ws_global_init(void) { int r; WSADATA wsa_data; @@ -2221,30 +2168,3 @@ SOCKET ws_get_base_socket(SOCKET socket) { return base_socket; } - -/* Retrieves a copy of the winsock catalog. - * The infos pointer must be released by the caller with free(). */ -int ws_get_protocol_catalog(WSAPROTOCOL_INFOW** infos_out, - size_t* infos_count_out) { - DWORD buffer_size = WS__INITIAL_CATALOG_BUFFER_SIZE; - int count; - WSAPROTOCOL_INFOW* infos; - -retry: - infos = malloc(buffer_size); - if (infos == NULL) - return_set_error(-1, ERROR_NOT_ENOUGH_MEMORY); - - count = WSAEnumProtocolsW(NULL, infos, &buffer_size); - if (count == SOCKET_ERROR) { - free(infos); - if (WSAGetLastError() == WSAENOBUFS) - goto retry; /* Try again with bigger buffer size. */ - else - return_map_error(-1); - } - - *infos_out = infos; - *infos_count_out = (size_t) count; - return 0; -}