diff --git a/src/afd.c b/src/afd.c index 38025fa..886fd22 100644 --- a/src/afd.c +++ b/src/afd.c @@ -10,120 +10,50 @@ #define IOCTL_AFD_POLL 0x00012024 -/* clang-format off */ -static const GUID AFD__PROVIDER_GUID_LIST[] = { - /* MSAFD Tcpip [TCP+UDP+RAW / IP] */ - {0xe70f1aa0, 0xab8b, 0x11cf, - {0x8c, 0xa3, 0x00, 0x80, 0x5f, 0x48, 0xa1, 0x92}}, - /* MSAFD Tcpip [TCP+UDP+RAW / IPv6] */ - {0xf9eab0c0, 0x26d4, 0x11d0, - {0xbb, 0xbf, 0x00, 0xaa, 0x00, 0x6c, 0x34, 0xe4}}, - /* MSAFD RfComm [Bluetooth] */ - {0x9fc48064, 0x7298, 0x43e4, - {0xb7, 0xbd, 0x18, 0x1f, 0x20, 0x89, 0x79, 0x2a}}, - /* MSAFD Irda [IrDA] */ - {0x3972523d, 0x2af1, 0x11d1, - {0xb6, 0x55, 0x00, 0x80, 0x5f, 0x36, 0x42, 0xcc}}}; -/* clang-format on */ +static UNICODE_STRING afd__helper_name = + RTL_CONSTANT_STRING(L"\\Device\\Afd\\Wepoll"); -static const int AFD__ANY_PROTOCOL = -1; +static OBJECT_ATTRIBUTES afd__helper_attributes = + RTL_CONSTANT_OBJECT_ATTRIBUTES(&afd__helper_name, 0); -/* This protocol info record is used by afd_create_driver_socket() to create - * sockets that can be used as the first argument to afd_poll(). It is - * populated on startup by afd_global_init(). */ -static WSAPROTOCOL_INFOW afd__driver_socket_protocol_info; +int afd_create_helper_handle(HANDLE iocp, HANDLE* afd_helper_handle_out) { + HANDLE afd_helper_handle; + IO_STATUS_BLOCK iosb; + NTSTATUS status; -static const WSAPROTOCOL_INFOW* afd__find_protocol_info( - const WSAPROTOCOL_INFOW* infos, size_t infos_count, int protocol_id) { - size_t i, j; + /* By opening \Device\Afd without specifying any extended attributes, we'll + * get a handle that lets us talk to the AFD driver, but that doesn't have an + * associated endpoint (so it's not a socket). */ + status = NtCreateFile(&afd_helper_handle, + SYNCHRONIZE, + &afd__helper_attributes, + &iosb, + NULL, + 0, + FILE_SHARE_READ | FILE_SHARE_WRITE, + FILE_OPEN, + 0, + NULL, + 0); + if (status != STATUS_SUCCESS) + return_set_error(-1, RtlNtStatusToDosError(status)); - for (i = 0; i < infos_count; i++) { - const WSAPROTOCOL_INFOW* info = &infos[i]; - - /* Apply protocol id filter. */ - if (protocol_id != AFD__ANY_PROTOCOL && protocol_id != info->iProtocol) - continue; - - /* Filter out non-MSAFD protocols. */ - for (j = 0; j < array_count(AFD__PROVIDER_GUID_LIST); j++) { - if (memcmp(&info->ProviderId, - &AFD__PROVIDER_GUID_LIST[j], - sizeof info->ProviderId) == 0) - return info; - } - } - - return NULL; /* Not found. */ -} - -int afd_global_init(void) { - WSAPROTOCOL_INFOW* infos; - size_t infos_count; - const WSAPROTOCOL_INFOW* afd_info; - - /* Load the winsock catalog. */ - if (ws_get_protocol_catalog(&infos, &infos_count) < 0) - return -1; - - /* Find a WSAPROTOCOL_INFOW structure that we can use to create an MSAFD - * socket. Preferentially we pick a UDP socket, otherwise try TCP or any - * other type. */ - for (;;) { - afd_info = afd__find_protocol_info(infos, infos_count, IPPROTO_UDP); - if (afd_info != NULL) - break; - - afd_info = afd__find_protocol_info(infos, infos_count, IPPROTO_TCP); - if (afd_info != NULL) - break; - - afd_info = afd__find_protocol_info(infos, infos_count, AFD__ANY_PROTOCOL); - if (afd_info != NULL) - break; - - free(infos); - return_set_error(-1, WSAENETDOWN); /* No suitable protocol found. */ - } - - /* Copy found protocol information from the catalog to a static buffer. */ - afd__driver_socket_protocol_info = *afd_info; - - free(infos); - return 0; -} - -int afd_create_driver_socket(HANDLE iocp, SOCKET* driver_socket_out) { - SOCKET socket; - - socket = WSASocketW(afd__driver_socket_protocol_info.iAddressFamily, - afd__driver_socket_protocol_info.iSocketType, - afd__driver_socket_protocol_info.iProtocol, - &afd__driver_socket_protocol_info, - 0, - WSA_FLAG_OVERLAPPED); - if (socket == INVALID_SOCKET) - return_map_error(-1); - - /* TODO: use WSA_FLAG_NOINHERIT on Windows versions that support it. */ - if (!SetHandleInformation((HANDLE) socket, HANDLE_FLAG_INHERIT, 0)) + if (CreateIoCompletionPort(afd_helper_handle, iocp, 0, 0) == NULL) goto error; - if (CreateIoCompletionPort((HANDLE) socket, iocp, 0, 0) == NULL) - goto error; - - if (!SetFileCompletionNotificationModes((HANDLE) socket, + if (!SetFileCompletionNotificationModes(afd_helper_handle, FILE_SKIP_SET_EVENT_ON_HANDLE)) goto error; - *driver_socket_out = socket; + *afd_helper_handle_out = afd_helper_handle; return 0; error: - closesocket(socket); + CloseHandle(afd_helper_handle); return_map_error(-1); } -int afd_poll(SOCKET driver_socket, +int afd_poll(HANDLE afd_helper_handle, AFD_POLL_INFO* poll_info, OVERLAPPED* overlapped) { IO_STATUS_BLOCK* iosb; @@ -147,7 +77,7 @@ int afd_poll(SOCKET driver_socket, } iosb->Status = STATUS_PENDING; - status = NtDeviceIoControlFile((HANDLE) driver_socket, + status = NtDeviceIoControlFile(afd_helper_handle, event, NULL, apc_context, diff --git a/src/afd.h b/src/afd.h index a1180e7..b1e403c 100644 --- a/src/afd.h +++ b/src/afd.h @@ -31,12 +31,10 @@ typedef struct _AFD_POLL_INFO { AFD_POLL_HANDLE_INFO Handles[1]; } AFD_POLL_INFO, *PAFD_POLL_INFO; -WEPOLL_INTERNAL int afd_global_init(void); +WEPOLL_INTERNAL int afd_create_helper_handle(HANDLE iocp, + HANDLE* afd_helper_handle_out); -WEPOLL_INTERNAL int afd_create_driver_socket(HANDLE iocp, - SOCKET* driver_socket_out); - -WEPOLL_INTERNAL int afd_poll(SOCKET driver_socket, +WEPOLL_INTERNAL int afd_poll(HANDLE afd_helper_handle, AFD_POLL_INFO* poll_info, OVERLAPPED* overlapped); diff --git a/src/init.c b/src/init.c index 69a270b..0deae84 100644 --- a/src/init.c +++ b/src/init.c @@ -19,7 +19,7 @@ static BOOL CALLBACK init__once_callback(INIT_ONCE* once, unused_var(context); /* N.b. that initialization order matters here. */ - if (ws_global_init() < 0 || nt_global_init() < 0 || afd_global_init() < 0 || + if (ws_global_init() < 0 || nt_global_init() < 0 || reflock_global_init() < 0 || epoll_global_init() < 0) return FALSE; diff --git a/src/nt.h b/src/nt.h index 62cd98b..89c7eac 100644 --- a/src/nt.h +++ b/src/nt.h @@ -40,6 +40,9 @@ typedef struct _LSA_UNICODE_STRING { PWSTR Buffer; } LSA_UNICODE_STRING, *PLSA_UNICODE_STRING, UNICODE_STRING, *PUNICODE_STRING; +#define RTL_CONSTANT_STRING(s) \ + { sizeof(s) - sizeof((s)[0]), sizeof(s), s } + typedef struct _OBJECT_ATTRIBUTES { ULONG Length; HANDLE RootDirectory; @@ -49,7 +52,29 @@ typedef struct _OBJECT_ATTRIBUTES { PVOID SecurityQualityOfService; } OBJECT_ATTRIBUTES, *POBJECT_ATTRIBUTES; +#define RTL_CONSTANT_OBJECT_ATTRIBUTES(ObjectName, Attributes) \ + { sizeof(OBJECT_ATTRIBUTES), NULL, ObjectName, Attributes, NULL, NULL } + +#ifndef FILE_OPEN +#define FILE_OPEN 0x00000001UL +#endif + #define NT_NTDLL_IMPORT_LIST(X) \ + X(NTSTATUS, \ + NTAPI, \ + NtCreateFile, \ + (PHANDLE FileHandle, \ + ACCESS_MASK DesiredAccess, \ + POBJECT_ATTRIBUTES ObjectAttributes, \ + PIO_STATUS_BLOCK IoStatusBlock, \ + PLARGE_INTEGER AllocationSize, \ + ULONG FileAttributes, \ + ULONG ShareAccess, \ + ULONG CreateDisposition, \ + ULONG CreateOptions, \ + PVOID EaBuffer, \ + ULONG EaLength)) \ + \ X(NTSTATUS, \ NTAPI, \ NtDeviceIoControlFile, \ diff --git a/src/poll-group.c b/src/poll-group.c index 9a1b16a..53b9ca2 100644 --- a/src/poll-group.c +++ b/src/poll-group.c @@ -13,7 +13,7 @@ static const size_t POLL_GROUP__MAX_GROUP_SIZE = 32; typedef struct poll_group { port_state_t* port_state; queue_node_t queue_node; - SOCKET socket; + HANDLE afd_helper_handle; size_t group_size; } poll_group_t; @@ -27,7 +27,8 @@ static poll_group_t* poll_group__new(port_state_t* port_state) { queue_node_init(&poll_group->queue_node); poll_group->port_state = port_state; - if (afd_create_driver_socket(port_state->iocp, &poll_group->socket) < 0) { + if (afd_create_helper_handle(port_state->iocp, + &poll_group->afd_helper_handle) < 0) { free(poll_group); return NULL; } @@ -39,7 +40,7 @@ static poll_group_t* poll_group__new(port_state_t* port_state) { void poll_group_delete(poll_group_t* poll_group) { assert(poll_group->group_size == 0); - closesocket(poll_group->socket); + CloseHandle(poll_group->afd_helper_handle); queue_remove(&poll_group->queue_node); free(poll_group); } @@ -48,8 +49,8 @@ poll_group_t* poll_group_from_queue_node(queue_node_t* queue_node) { return container_of(queue_node, poll_group_t, queue_node); } -SOCKET poll_group_get_socket(poll_group_t* poll_group) { - return poll_group->socket; +HANDLE poll_group_get_afd_helper_handle(poll_group_t* poll_group) { + return poll_group->afd_helper_handle; } poll_group_t* poll_group_acquire(port_state_t* port_state) { diff --git a/src/poll-group.h b/src/poll-group.h index 659080e..2e5cd09 100644 --- a/src/poll-group.h +++ b/src/poll-group.h @@ -16,6 +16,7 @@ WEPOLL_INTERNAL void poll_group_delete(poll_group_t* poll_group); WEPOLL_INTERNAL poll_group_t* poll_group_from_queue_node( queue_node_t* queue_node); -WEPOLL_INTERNAL SOCKET poll_group_get_socket(poll_group_t* poll_group); +WEPOLL_INTERNAL HANDLE + poll_group_get_afd_helper_handle(poll_group_t* poll_group); #endif /* WEPOLL_POLL_GROUP_H_ */ diff --git a/src/sock.c b/src/sock.c index 4276362..c17c0f9 100644 --- a/src/sock.c +++ b/src/sock.c @@ -47,13 +47,13 @@ static inline void sock__free(sock_state_t* sock_state) { } static int sock__cancel_poll(sock_state_t* sock_state) { - HANDLE driver_handle = - (HANDLE)(uintptr_t) poll_group_get_socket(sock_state->poll_group); + HANDLE afd_helper_handle = + poll_group_get_afd_helper_handle(sock_state->poll_group); assert(sock_state->poll_status == SOCK__POLL_PENDING); /* CancelIoEx() may fail with ERROR_NOT_FOUND if the overlapped operation has * already completed. This is not a problem and we proceed normally. */ - if (!CancelIoEx(driver_handle, &sock_state->overlapped) && + if (!CancelIoEx(afd_helper_handle, &sock_state->overlapped) && GetLastError() != ERROR_NOT_FOUND) return_map_error(-1); @@ -232,7 +232,7 @@ int sock_update(port_state_t* port_state, sock_state_t* sock_state) { memset(&sock_state->overlapped, 0, sizeof sock_state->overlapped); - if (afd_poll(poll_group_get_socket(sock_state->poll_group), + if (afd_poll(poll_group_get_afd_helper_handle(sock_state->poll_group), &sock_state->poll_info, &sock_state->overlapped) < 0) { switch (GetLastError()) { diff --git a/src/ws.c b/src/ws.c index 6c67740..74b6bdf 100644 --- a/src/ws.c +++ b/src/ws.c @@ -10,8 +10,6 @@ #define SIO_BASE_HANDLE 0x48000022 #endif -#define WS__INITIAL_CATALOG_BUFFER_SIZE 0x4000 /* 16kb. */ - int ws_global_init(void) { int r; WSADATA wsa_data; @@ -40,30 +38,3 @@ SOCKET ws_get_base_socket(SOCKET socket) { return base_socket; } - -/* Retrieves a copy of the winsock catalog. - * The infos pointer must be released by the caller with free(). */ -int ws_get_protocol_catalog(WSAPROTOCOL_INFOW** infos_out, - size_t* infos_count_out) { - DWORD buffer_size = WS__INITIAL_CATALOG_BUFFER_SIZE; - int count; - WSAPROTOCOL_INFOW* infos; - -retry: - infos = malloc(buffer_size); - if (infos == NULL) - return_set_error(-1, ERROR_NOT_ENOUGH_MEMORY); - - count = WSAEnumProtocolsW(NULL, infos, &buffer_size); - if (count == SOCKET_ERROR) { - free(infos); - if (WSAGetLastError() == WSAENOBUFS) - goto retry; /* Try again with bigger buffer size. */ - else - return_map_error(-1); - } - - *infos_out = infos; - *infos_count_out = (size_t) count; - return 0; -} diff --git a/src/ws.h b/src/ws.h index f2dd66e..d688d27 100644 --- a/src/ws.h +++ b/src/ws.h @@ -6,9 +6,6 @@ #include "win.h" WEPOLL_INTERNAL int ws_global_init(void); - WEPOLL_INTERNAL SOCKET ws_get_base_socket(SOCKET socket); -WEPOLL_INTERNAL int ws_get_protocol_catalog(WSAPROTOCOL_INFOW** infos_out, - size_t* infos_count_out); #endif /* WEPOLL_WS_H_ */