diff --git a/src/afd.c b/src/afd.c index bfd68c3..a2f90ec 100644 --- a/src/afd.c +++ b/src/afd.c @@ -1,6 +1,7 @@ #include "afd.h" #include "error.h" #include "nt.h" +#include "util.h" #include "win.h" #define FILE_DEVICE_NETWORK 0x00000012 @@ -12,6 +13,10 @@ #define IOCTL_AFD_POLL _AFD_CONTROL_CODE(AFD_POLL, METHOD_BUFFERED) +#ifndef SIO_BASE_HANDLE +#define SIO_BASE_HANDLE 0x48000022 +#endif + int afd_poll(SOCKET driver_socket, AFD_POLL_INFO* poll_info, OVERLAPPED* overlapped) { @@ -81,3 +86,85 @@ int afd_poll(SOCKET driver_socket, else return_error(-1, we_map_ntstatus_to_win_error(status)); } + +static SOCKET _afd_get_base_socket(SOCKET socket) { + SOCKET base_socket; + DWORD bytes; + + if (WSAIoctl(socket, + SIO_BASE_HANDLE, + NULL, + 0, + &base_socket, + sizeof base_socket, + &bytes, + NULL, + NULL) == SOCKET_ERROR) + return_error(INVALID_SOCKET); + + return base_socket; +} + +static ssize_t _afd_get_protocol_info(SOCKET socket, + WSAPROTOCOL_INFOW* protocol_info) { + ssize_t id; + int opt_len; + + opt_len = sizeof *protocol_info; + if (getsockopt(socket, + SOL_SOCKET, + SO_PROTOCOL_INFOW, + (char*) protocol_info, + &opt_len) != 0) + return_error(-1); + + id = -1; + for (size_t i = 0; i < array_count(AFD_PROVIDER_GUID_LIST); i++) { + if (memcmp(&protocol_info->ProviderId, + &AFD_PROVIDER_GUID_LIST[i], + sizeof protocol_info->ProviderId) == 0) { + id = i; + break; + } + } + + /* Check if the protocol uses an msafd socket. */ + if (id < 0) + return_error(-1, ERROR_NOT_SUPPORTED); + + return id; +} + +EPOLL_INTERNAL ssize_t afd_get_protocol(SOCKET socket, + SOCKET* afd_socket_out, + WSAPROTOCOL_INFOW* protocol_info) { + ssize_t id; + SOCKET afd_socket; + + /* Try to get protocol information, assuming that the given socket is an AFD + * socket. This should almost always be the case, and if it is, that saves us + * a call to WSAIoctl(). */ + afd_socket = socket; + id = _afd_get_protocol_info(afd_socket, protocol_info); + + if (id < 0) { + /* If getting protocol information failed, it might be due to the socket + * not being an AFD socket. If so, attempt to fetch the underlying base + * socket, then try again to obtain protocol information. If that also + * fails, return the *original* error. */ + DWORD original_error = GetLastError(); + if (original_error != ERROR_NOT_SUPPORTED) + return_error(-1); + + afd_socket = _afd_get_base_socket(socket); + if (afd_socket == INVALID_SOCKET || afd_socket == socket) + return_error(-1, original_error); + + id = _afd_get_protocol_info(afd_socket, protocol_info); + if (id < 0) + return_error(-1, original_error); + } + + *afd_socket_out = afd_socket; + return id; +} diff --git a/src/afd.h b/src/afd.h index 01a0a82..9f4ee52 100644 --- a/src/afd.h +++ b/src/afd.h @@ -3,6 +3,7 @@ #include "internal.h" #include "ntstatus.h" +#include "util.h" #include "win.h" /* clang-format off */ @@ -55,6 +56,10 @@ EPOLL_INTERNAL int afd_poll(SOCKET driver_socket, AFD_POLL_INFO* poll_info, OVERLAPPED* overlapped); +EPOLL_INTERNAL ssize_t afd_get_protocol(SOCKET socket, + SOCKET* afd_socket_out, + WSAPROTOCOL_INFOW* protocol_info); + /* clang-format off */ static const GUID AFD_PROVIDER_GUID_LIST[] = { diff --git a/src/epoll-socket.c b/src/epoll-socket.c index dfc4d76..a7faa3b 100644 --- a/src/epoll-socket.c +++ b/src/epoll-socket.c @@ -10,10 +10,6 @@ #include "poll-group.h" #include "port.h" -#ifndef SIO_BASE_HANDLE -#define SIO_BASE_HANDLE 0x48000022 -#endif - #define _EP_EVENT_MASK 0xffff typedef struct _poll_req { @@ -160,76 +156,48 @@ static inline void _ep_sock_free(_ep_sock_private_t* sock_private) { free(sock_private); } -static int _get_related_sockets(ep_port_t* port_info, - SOCKET socket, - SOCKET* afd_socket_out, - poll_group_t** poll_group_out) { - SOCKET afd_socket; - poll_group_t* poll_group; - DWORD bytes; - - /* Try to obtain a base handle for the socket, so we can bypass LSPs - * that get in the way if we want to talk to the kernel directly. If - * it fails we try if we work with the original socket. Note that on - * windows XP/2k3 this will always fail since they don't support the - * SIO_BASE_HANDLE ioctl. - */ - afd_socket = socket; - WSAIoctl(socket, - SIO_BASE_HANDLE, - NULL, - 0, - &afd_socket, - sizeof afd_socket, - &bytes, - NULL, - NULL); - - poll_group = ep_port_acquire_poll_group(port_info, afd_socket); - if (poll_group == NULL) - return -1; - - *afd_socket_out = afd_socket; - *poll_group_out = poll_group; - - return 0; -} - -static int _ep_sock_set_socket(ep_port_t* port_info, - _ep_sock_private_t* sock_private, - SOCKET socket) { - if (socket == 0 || socket == INVALID_SOCKET) - return_error(-1, ERROR_INVALID_HANDLE); - - assert(sock_private->afd_socket == 0); - - if (_get_related_sockets(port_info, - socket, - &sock_private->afd_socket, - &sock_private->poll_group) < 0) - return -1; - - if (ep_port_add_socket(port_info, &sock_private->pub.tree_node, socket) < 0) - return -1; - - return 0; -} - ep_sock_t* ep_sock_new(ep_port_t* port_info, SOCKET socket) { - _ep_sock_private_t* sock_private = _ep_sock_alloc(); - if (sock_private == NULL) + SOCKET afd_socket; + ssize_t protocol_id; + WSAPROTOCOL_INFOW protocol_info; + poll_group_t* poll_group; + _ep_sock_private_t* sock_private; + + if (socket == 0 || socket == INVALID_SOCKET) + return_error(NULL, ERROR_INVALID_HANDLE); + + protocol_id = afd_get_protocol(socket, &afd_socket, &protocol_info); + if (protocol_id < 0) return NULL; + poll_group = + ep_port_acquire_poll_group(port_info, protocol_id, &protocol_info); + if (poll_group == NULL) + return NULL; + + sock_private = _ep_sock_alloc(); + if (sock_private == NULL) + goto err1; + memset(sock_private, 0, sizeof *sock_private); + + sock_private->afd_socket = afd_socket; + sock_private->poll_group = poll_group; + tree_node_init(&sock_private->pub.tree_node); queue_node_init(&sock_private->pub.queue_node); - if (_ep_sock_set_socket(port_info, sock_private, socket) < 0) { - _ep_sock_free(sock_private); - return NULL; - } + if (ep_port_add_socket(port_info, &sock_private->pub.tree_node, socket) < 0) + goto err2; return &sock_private->pub; + +err2: + _ep_sock_free(sock_private); +err1: + ep_port_release_poll_group(poll_group); + + return NULL; } void ep_sock_delete(ep_port_t* port_info, ep_sock_t* sock_info) { diff --git a/src/epoll.c b/src/epoll.c index d39bd60..b12c55e 100644 --- a/src/epoll.c +++ b/src/epoll.c @@ -254,48 +254,25 @@ int ep_port_del_socket(ep_port_t* port_info, tree_node_t* tree_node) { static poll_group_allocator_t* _get_poll_group_allocator( ep_port_t* port_info, - size_t index, + size_t protocol_id, const WSAPROTOCOL_INFOW* protocol_info) { - poll_group_allocator_t** pga = &port_info->poll_group_allocators[index]; + poll_group_allocator_t** pga; + assert(protocol_id < array_count(port_info->poll_group_allocators)); + + pga = &port_info->poll_group_allocators[protocol_id]; if (*pga == NULL) *pga = poll_group_allocator_new(port_info, protocol_info); return *pga; } -poll_group_t* ep_port_acquire_poll_group(ep_port_t* port_info, SOCKET socket) { - ssize_t index; - size_t i; - WSAPROTOCOL_INFOW protocol_info; - int len; - poll_group_allocator_t* pga; - - /* Obtain protocol information about the socket. */ - len = sizeof protocol_info; - if (getsockopt(socket, - SOL_SOCKET, - SO_PROTOCOL_INFOW, - (char*) &protocol_info, - &len) != 0) - return_error(NULL); - - index = -1; - for (i = 0; i < array_count(AFD_PROVIDER_GUID_LIST); i++) { - if (memcmp((void*) &protocol_info.ProviderId, - (void*) &AFD_PROVIDER_GUID_LIST[i], - sizeof protocol_info.ProviderId) == 0) { - index = i; - break; - } - } - - /* Check if the protocol uses an msafd socket. */ - if (index < 0) - return_error(NULL, ERROR_NOT_SUPPORTED); - - pga = _get_poll_group_allocator(port_info, index, &protocol_info); - +poll_group_t* ep_port_acquire_poll_group( + ep_port_t* port_info, + size_t protocol_id, + const WSAPROTOCOL_INFOW* protocol_info) { + poll_group_allocator_t* pga = + _get_poll_group_allocator(port_info, protocol_id, protocol_info); return poll_group_acquire(pga); } diff --git a/src/port.h b/src/port.h index ee84424..cf7ebcd 100644 --- a/src/port.h +++ b/src/port.h @@ -25,8 +25,10 @@ typedef struct ep_port { EPOLL_INTERNAL ep_port_t* ep_port_new(HANDLE iocp); EPOLL_INTERNAL int ep_port_delete(ep_port_t* port_info); -EPOLL_INTERNAL poll_group_t* ep_port_acquire_poll_group(ep_port_t* port_info, - SOCKET socket); +EPOLL_INTERNAL poll_group_t* ep_port_acquire_poll_group( + ep_port_t* port_info, + size_t protocol_id, + const WSAPROTOCOL_INFOW* protocol_info); EPOLL_INTERNAL void ep_port_release_poll_group(poll_group_t* poll_group); EPOLL_INTERNAL int ep_port_add_socket(ep_port_t* port_info,