diff --git a/src/afd.c b/src/afd.c index 29494a2..a426701 100644 --- a/src/afd.c +++ b/src/afd.c @@ -1,3 +1,6 @@ +#include +#include + #include "afd.h" #include "error.h" #include "nt.h" @@ -14,6 +17,8 @@ #define IOCTL_AFD_POLL _AFD_CONTROL_CODE(AFD_POLL, METHOD_BUFFERED) +#define _AFD_ANY_PROTOCOL -1 + /* clang-format off */ static const GUID _AFD_PROVIDER_GUID_LIST[] = { /* MSAFD Tcpip [TCP+UDP+RAW / IP] */ @@ -30,6 +35,101 @@ static const GUID _AFD_PROVIDER_GUID_LIST[] = { {0xb6, 0x55, 0x00, 0x80, 0x5f, 0x36, 0x42, 0xcc}}}; /* clang-format on */ +/* This protocol info record is used by afd_create_driver_socket() to create + * sockets that can be used as the first argument to afd_poll(). It is + * populated on startup by afd_global_init(). + */ +static WSAPROTOCOL_INFOW _afd_driver_socket_template; + +static const WSAPROTOCOL_INFOW* _afd_find_protocol_info( + const WSAPROTOCOL_INFOW* infos, size_t infos_count, int protocol_id) { + size_t i, j; + + for (i = 0; i < infos_count; i++) { + const WSAPROTOCOL_INFOW* info = &infos[i]; + + /* Apply protocol id filter. */ + if (protocol_id != _AFD_ANY_PROTOCOL && protocol_id != info->iProtocol) + continue; + + /* Filter out non-MSAFD protocols. */ + for (j = 0; j < array_count(_AFD_PROVIDER_GUID_LIST); j++) { + if (memcmp(&info->ProviderId, + &_AFD_PROVIDER_GUID_LIST[j], + sizeof info->ProviderId) == 0) + return info; + } + } + + return NULL; /* Not found. */ +} + +int afd_global_init(void) { + WSAPROTOCOL_INFOW* infos; + ssize_t infos_count; + const WSAPROTOCOL_INFOW* afd_info; + + /* Load the winsock catalog. */ + infos_count = ws_get_protocol_catalog(&infos); + if (infos_count < 0) + return_error(-1); + + /* Find a WSAPROTOCOL_INDOW structure that we can use to create an MSAFD + * socket. Preferentially we pick a UDP socket, otherwise try TCP or any + * other type. + */ + do { + afd_info = _afd_find_protocol_info(infos, infos_count, IPPROTO_UDP); + if (afd_info != NULL) + break; + + afd_info = _afd_find_protocol_info(infos, infos_count, IPPROTO_TCP); + if (afd_info != NULL) + break; + + afd_info = _afd_find_protocol_info(infos, infos_count, _AFD_ANY_PROTOCOL); + if (afd_info != NULL) + break; + + free(infos); + return_error(-1, WSAENETDOWN); /* No suitable protocol found. */ + } while (0); + + /* Copy found protocol information from the catalog to a static buffer. */ + _afd_driver_socket_template = *afd_info; + + free(infos); + return 0; +} + +int afd_create_driver_socket(HANDLE iocp, SOCKET* driver_socket_out) { + SOCKET socket; + + socket = WSASocketW(_afd_driver_socket_template.iAddressFamily, + _afd_driver_socket_template.iSocketType, + _afd_driver_socket_template.iProtocol, + &_afd_driver_socket_template, + 0, + WSA_FLAG_OVERLAPPED); + if (socket == INVALID_SOCKET) + return_error(-1); + + /* TODO: use WSA_FLAG_NOINHERIT on Windows versions that support it. */ + if (!SetHandleInformation((HANDLE) socket, HANDLE_FLAG_INHERIT, 0)) + goto error; + + if (CreateIoCompletionPort((HANDLE) socket, iocp, 0, 0) == NULL) + goto error; + + *driver_socket_out = socket; + return 0; + +error:; + DWORD error = GetLastError(); + closesocket(socket); + return_error(-1, error); +} + int afd_poll(SOCKET driver_socket, AFD_POLL_INFO* poll_info, OVERLAPPED* overlapped) { diff --git a/src/afd.h b/src/afd.h index 12b3e0e..07e65fd 100644 --- a/src/afd.h +++ b/src/afd.h @@ -52,6 +52,11 @@ typedef struct _AFD_POLL_INFO { AFD_POLL_HANDLE_INFO Handles[1]; } AFD_POLL_INFO, *PAFD_POLL_INFO; +WEPOLL_INTERNAL int afd_global_init(void); + +WEPOLL_INTERNAL int afd_create_driver_socket(HANDLE iocp, + SOCKET* driver_socket_out); + WEPOLL_INTERNAL int afd_poll(SOCKET driver_socket, AFD_POLL_INFO* poll_info, OVERLAPPED* overlapped); diff --git a/src/init.c b/src/init.c index 3e03af6..199ee90 100644 --- a/src/init.c +++ b/src/init.c @@ -1,5 +1,6 @@ #include +#include "afd.h" #include "api.h" #include "init.h" #include "nt.h" @@ -18,7 +19,7 @@ static BOOL CALLBACK _init_once_callback(INIT_ONCE* once, unused_var(context); /* N.b. that initialization order matters here. */ - if (ws_global_init() < 0 || nt_global_init() < 0 || + if (ws_global_init() < 0 || nt_global_init() < 0 || afd_global_init() < 0 || reflock_global_init() < 0 || api_global_init() < 0) return FALSE; diff --git a/src/poll-group.c b/src/poll-group.c index 076b2f4..c6423d4 100644 --- a/src/poll-group.c +++ b/src/poll-group.c @@ -12,7 +12,6 @@ static const size_t _POLL_GROUP_MAX_SIZE = 32; typedef struct poll_group_allocator { ep_port_t* port_info; queue_t poll_group_queue; - WSAPROTOCOL_INFOW protocol_info; } poll_group_allocator_t; typedef struct poll_group { @@ -22,35 +21,6 @@ typedef struct poll_group { size_t group_size; } poll_group_t; -static int _poll_group_create_socket(poll_group_t* poll_group, - WSAPROTOCOL_INFOW* protocol_info, - HANDLE iocp) { - SOCKET socket; - - socket = WSASocketW(protocol_info->iAddressFamily, - protocol_info->iSocketType, - protocol_info->iProtocol, - protocol_info, - 0, - WSA_FLAG_OVERLAPPED); - if (socket == INVALID_SOCKET) - return_error(-1); - - if (!SetHandleInformation((HANDLE) socket, HANDLE_FLAG_INHERIT, 0)) - goto error; - - if (CreateIoCompletionPort((HANDLE) socket, iocp, 0, 0) == NULL) - goto error; - - poll_group->socket = socket; - return 0; - -error:; - DWORD error = GetLastError(); - closesocket(socket); - return_error(-1, error); -} - static poll_group_t* _poll_group_new(poll_group_allocator_t* pga) { poll_group_t* poll_group = malloc(sizeof *poll_group); if (poll_group == NULL) @@ -61,8 +31,8 @@ static poll_group_t* _poll_group_new(poll_group_allocator_t* pga) { queue_node_init(&poll_group->queue_node); poll_group->allocator = pga; - if (_poll_group_create_socket( - poll_group, &pga->protocol_info, pga->port_info->iocp) < 0) { + if (afd_create_driver_socket(pga->port_info->iocp, &poll_group->socket) < + 0) { free(poll_group); return NULL; } @@ -83,15 +53,13 @@ SOCKET poll_group_get_socket(poll_group_t* poll_group) { return poll_group->socket; } -poll_group_allocator_t* poll_group_allocator_new( - ep_port_t* port_info, const WSAPROTOCOL_INFOW* protocol_info) { +poll_group_allocator_t* poll_group_allocator_new(ep_port_t* port_info) { poll_group_allocator_t* pga = malloc(sizeof *pga); if (pga == NULL) return_error(NULL, ERROR_NOT_ENOUGH_MEMORY); queue_init(&pga->poll_group_queue); pga->port_info = port_info; - pga->protocol_info = *protocol_info; return pga; } diff --git a/src/poll-group.h b/src/poll-group.h index 51889d2..7345f32 100644 --- a/src/poll-group.h +++ b/src/poll-group.h @@ -11,7 +11,7 @@ typedef struct poll_group_allocator poll_group_allocator_t; typedef struct poll_group poll_group_t; WEPOLL_INTERNAL poll_group_allocator_t* poll_group_allocator_new( - ep_port_t* port_info, const WSAPROTOCOL_INFOW* protocol_info); + ep_port_t* port_info); WEPOLL_INTERNAL void poll_group_allocator_delete(poll_group_allocator_t* pga); WEPOLL_INTERNAL poll_group_t* poll_group_acquire(poll_group_allocator_t* pga); diff --git a/src/port.c b/src/port.c index 938db25..c05aafe 100644 --- a/src/port.c +++ b/src/port.c @@ -357,19 +357,16 @@ ep_sock_t* ep_port_find_socket(ep_port_t* port_info, SOCKET socket) { } static poll_group_allocator_t* _ep_port_get_poll_group_allocator( - ep_port_t* port_info, const WSAPROTOCOL_INFOW* protocol_info) { + ep_port_t* port_info) { if (port_info->poll_group_allocator == NULL) { - port_info->poll_group_allocator = - poll_group_allocator_new(port_info, protocol_info); + port_info->poll_group_allocator = poll_group_allocator_new(port_info); } return port_info->poll_group_allocator; } -poll_group_t* ep_port_acquire_poll_group( - ep_port_t* port_info, const WSAPROTOCOL_INFOW* protocol_info) { - poll_group_allocator_t* pga = - _ep_port_get_poll_group_allocator(port_info, protocol_info); +poll_group_t* ep_port_acquire_poll_group(ep_port_t* port_info) { + poll_group_allocator_t* pga = _ep_port_get_poll_group_allocator(port_info); return poll_group_acquire(pga); } diff --git a/src/port.h b/src/port.h index a7dea7e..4ff7c4c 100644 --- a/src/port.h +++ b/src/port.h @@ -39,8 +39,7 @@ WEPOLL_INTERNAL int ep_port_ctl(ep_port_t* port_info, SOCKET sock, struct epoll_event* ev); -WEPOLL_INTERNAL poll_group_t* ep_port_acquire_poll_group( - ep_port_t* port_info, const WSAPROTOCOL_INFOW* protocol_info); +WEPOLL_INTERNAL poll_group_t* ep_port_acquire_poll_group(ep_port_t* port_info); WEPOLL_INTERNAL void ep_port_release_poll_group(ep_port_t* port_info, poll_group_t* poll_group); diff --git a/src/sock.c b/src/sock.c index 2b3b678..be5c45d 100644 --- a/src/sock.c +++ b/src/sock.c @@ -183,7 +183,7 @@ ep_sock_t* ep_sock_new(ep_port_t* port_info, SOCKET socket) { if (afd_get_protocol_info(socket, &afd_socket, &protocol_info) < 0) return NULL; - poll_group = ep_port_acquire_poll_group(port_info, &protocol_info); + poll_group = ep_port_acquire_poll_group(port_info); if (poll_group == NULL) return NULL; diff --git a/src/ws.c b/src/ws.c index bcc63aa..e100de5 100644 --- a/src/ws.c +++ b/src/ws.c @@ -1,4 +1,8 @@ +#include +#include + #include "error.h" +#include "util.h" #include "win.h" #include "ws.h" @@ -6,6 +10,8 @@ #define SIO_BASE_HANDLE 0x48000022 #endif +#define _WS_INITIAL_CATALOG_BUFFER_SIZE 0x4000 /* 16kb. */ + int ws_global_init(void) { int r; WSADATA wsa_data; @@ -34,3 +40,30 @@ SOCKET ws_get_base_socket(SOCKET socket) { return base_socket; } + +/* Retrieves a copy of the winsock catalog. + * The infos pointer must be released by the caller with free(). + */ +ssize_t ws_get_protocol_catalog(WSAPROTOCOL_INFOW** infos_out) { + DWORD buffer_size = _WS_INITIAL_CATALOG_BUFFER_SIZE; + int count; + WSAPROTOCOL_INFOW* infos; + + for (;;) { + infos = malloc(buffer_size); + if (infos == NULL) + return_error(-1, ERROR_NOT_ENOUGH_MEMORY); + + count = WSAEnumProtocolsW(NULL, infos, &buffer_size); + if (count == SOCKET_ERROR) { + free(infos); + if (WSAGetLastError() == WSAENOBUFS) + continue; /* Try again with bigger buffer size. */ + else + return_error(-1); + } + + *infos_out = infos; + return count; + } +} diff --git a/src/ws.h b/src/ws.h index 86d50cd..7ab527a 100644 --- a/src/ws.h +++ b/src/ws.h @@ -2,10 +2,12 @@ #define WEPOLL_WS_H_ #include "internal.h" +#include "util.h" #include "win.h" WEPOLL_INTERNAL int ws_global_init(void); WEPOLL_INTERNAL SOCKET ws_get_base_socket(SOCKET socket); +WEPOLL_INTERNAL ssize_t ws_get_protocol_catalog(WSAPROTOCOL_INFOW** infos_out); #endif /* WEPOLL_WS_H_ */