diff --git a/allinone/epoll-all-in-one.c b/allinone/epoll-all-in-one.c index d694773..0a569da 100644 --- a/allinone/epoll-all-in-one.c +++ b/allinone/epoll-all-in-one.c @@ -220,6 +220,20 @@ typedef NTSTATUS* PNTSTATUS; : ((NTSTATUS)(((error) &0x0000FFFF) | (FACILITY_NTWIN32 << 16) | \ ERROR_SEVERITY_WARNING))) +#include + +#ifndef _SSIZE_T_DEFINED +#define SSIZE_T_DEFINED +typedef intptr_t ssize_t; +#endif + +#define array_count(a) (sizeof(a) / (sizeof((a)[0]))) + +#define container_of(ptr, type, member) \ + ((type*) ((char*) (ptr) -offsetof(type, member))) + +#define unused(v) ((void) (v)) + /* clang-format off */ #define AFD_NO_FAST_IO 0x00000001 @@ -270,6 +284,10 @@ EPOLL_INTERNAL int afd_poll(SOCKET driver_socket, AFD_POLL_INFO* poll_info, OVERLAPPED* overlapped); +EPOLL_INTERNAL ssize_t afd_get_protocol(SOCKET socket, + SOCKET* afd_socket_out, + WSAPROTOCOL_INFOW* protocol_info); + /* clang-format off */ static const GUID AFD_PROVIDER_GUID_LIST[] = { @@ -349,6 +367,10 @@ NTDLL_IMPORT_LIST(X) #define IOCTL_AFD_POLL _AFD_CONTROL_CODE(AFD_POLL, METHOD_BUFFERED) +#ifndef SIO_BASE_HANDLE +#define SIO_BASE_HANDLE 0x48000022 +#endif + int afd_poll(SOCKET driver_socket, AFD_POLL_INFO* poll_info, OVERLAPPED* overlapped) { @@ -419,65 +441,116 @@ int afd_poll(SOCKET driver_socket, return_error(-1, we_map_ntstatus_to_win_error(status)); } -EPOLL_INTERNAL int init(void); +static SOCKET _afd_get_base_socket(SOCKET socket) { + SOCKET base_socket; + DWORD bytes; + if (WSAIoctl(socket, + SIO_BASE_HANDLE, + NULL, + 0, + &base_socket, + sizeof base_socket, + &bytes, + NULL, + NULL) == SOCKET_ERROR) + return_error(INVALID_SOCKET); + + return base_socket; +} + +static ssize_t _afd_get_protocol_info(SOCKET socket, + WSAPROTOCOL_INFOW* protocol_info) { + ssize_t id; + int opt_len; + + opt_len = sizeof *protocol_info; + if (getsockopt(socket, + SOL_SOCKET, + SO_PROTOCOL_INFOW, + (char*) protocol_info, + &opt_len) != 0) + return_error(-1); + + id = -1; + for (size_t i = 0; i < array_count(AFD_PROVIDER_GUID_LIST); i++) { + if (memcmp(&protocol_info->ProviderId, + &AFD_PROVIDER_GUID_LIST[i], + sizeof protocol_info->ProviderId) == 0) { + id = i; + break; + } + } + + /* Check if the protocol uses an msafd socket. */ + if (id < 0) + return_error(-1, ERROR_NOT_SUPPORTED); + + return id; +} + +EPOLL_INTERNAL ssize_t afd_get_protocol(SOCKET socket, + SOCKET* afd_socket_out, + WSAPROTOCOL_INFOW* protocol_info) { + ssize_t id; + SOCKET afd_socket; + + /* Try to get protocol information, assuming that the given socket is an AFD + * socket. This should almost always be the case, and if it is, that saves us + * a call to WSAIoctl(). */ + afd_socket = socket; + id = _afd_get_protocol_info(afd_socket, protocol_info); + + if (id < 0) { + /* If getting protocol information failed, it might be due to the socket + * not being an AFD socket. If so, attempt to fetch the underlying base + * socket, then try again to obtain protocol information. If that also + * fails, return the *original* error. */ + DWORD original_error = GetLastError(); + if (original_error != ERROR_NOT_SUPPORTED) + return_error(-1); + + afd_socket = _afd_get_base_socket(socket); + if (afd_socket == INVALID_SOCKET || afd_socket == socket) + return_error(-1, original_error); + + id = _afd_get_protocol_info(afd_socket, protocol_info); + if (id < 0) + return_error(-1, original_error); + } + + *afd_socket_out = afd_socket; + return id; +} +#include +#include #include -#include typedef struct queue_node queue_node_t; + typedef struct queue_node { queue_node_t* prev; queue_node_t* next; } queue_node_t; + typedef struct queue { queue_node_t head; } queue_t; -EPOLL_INTERNAL inline void queue_node_init(queue_node_t* node) { - node->prev = node; - node->next = node; -} +EPOLL_INTERNAL void queue_init(queue_t* queue); +EPOLL_INTERNAL void queue_node_init(queue_node_t* node); -EPOLL_INTERNAL inline void queue_init(queue_t* queue) { - queue_node_init(&queue->head); -} +EPOLL_INTERNAL queue_node_t* queue_first(const queue_t* queue); +EPOLL_INTERNAL queue_node_t* queue_last(const queue_t* queue); -EPOLL_INTERNAL inline bool queue_enqueued(const queue_node_t* node) { - return node->prev != node; -} +EPOLL_INTERNAL void queue_prepend(queue_t* queue, queue_node_t* node); +EPOLL_INTERNAL void queue_append(queue_t* queue, queue_node_t* node); +EPOLL_INTERNAL void queue_move_first(queue_t* queue, queue_node_t* node); +EPOLL_INTERNAL void queue_move_last(queue_t* queue, queue_node_t* node); +EPOLL_INTERNAL void queue_remove(queue_node_t* node); -EPOLL_INTERNAL inline bool queue_empty(const queue_t* queue) { - return !queue_enqueued(&queue->head); -} - -EPOLL_INTERNAL inline queue_node_t* queue_first(const queue_t* queue) { - return !queue_empty(queue) ? queue->head.next : NULL; -} - -EPOLL_INTERNAL inline queue_node_t* queue_last(const queue_t* queue) { - return !queue_empty(queue) ? queue->head.prev : NULL; -} - -EPOLL_INTERNAL inline void queue_prepend(queue_t* queue, queue_node_t* node) { - node->next = queue->head.next; - node->prev = &queue->head; - node->next->prev = node; - queue->head.next = node; -} - -EPOLL_INTERNAL inline void queue_append(queue_t* queue, queue_node_t* node) { - node->next = &queue->head; - node->prev = queue->head.prev; - node->prev->next = node; - queue->head.prev = node; -} - -EPOLL_INTERNAL inline void queue_remove(queue_node_t* node) { - node->prev->next = node->next; - node->next->prev = node->prev; - node->prev = node; - node->next = node; -} +EPOLL_INTERNAL bool queue_empty(const queue_t* queue); +EPOLL_INTERNAL bool queue_enqueued(const queue_node_t* node); #ifdef __clang__ #define RB_UNUSED __attribute__((__unused__)) @@ -955,20 +1028,6 @@ EPOLL_INTERNAL int tree_del(tree_t* tree, tree_node_t* node); EPOLL_INTERNAL tree_node_t* tree_find(tree_t* tree, uintptr_t key); EPOLL_INTERNAL tree_node_t* tree_root(tree_t* tree); -#include - -#ifndef _SSIZE_T_DEFINED -#define SSIZE_T_DEFINED -typedef intptr_t ssize_t; -#endif - -#define array_count(a) (sizeof(a) / (sizeof((a)[0]))) - -#define container_of(ptr, type, member) \ - ((type*) ((char*) (ptr) -offsetof(type, member))) - -#define unused(v) ((void) (v)) - typedef struct ep_port ep_port_t; typedef struct poll_req poll_req_t; @@ -982,7 +1041,7 @@ EPOLL_INTERNAL void ep_sock_delete(ep_port_t* port_info, ep_sock_t* sock_info); EPOLL_INTERNAL void ep_sock_force_delete(ep_port_t* port_info, ep_sock_t* sock_info); -EPOLL_INTERNAL ep_sock_t* ep_sock_find(tree_t* tree, SOCKET socket); +EPOLL_INTERNAL ep_sock_t* ep_sock_find_in_tree(tree_t* tree, SOCKET socket); EPOLL_INTERNAL ep_sock_t* ep_sock_from_overlapped(OVERLAPPED* overlapped); EPOLL_INTERNAL int ep_sock_set_event(ep_port_t* port_info, @@ -1021,15 +1080,26 @@ typedef struct ep_port { EPOLL_INTERNAL ep_port_t* ep_port_new(HANDLE iocp); EPOLL_INTERNAL int ep_port_delete(ep_port_t* port_info); -EPOLL_INTERNAL poll_group_t* ep_port_acquire_poll_group(ep_port_t* port_info, - SOCKET socket); +EPOLL_INTERNAL int ep_port_update_events(ep_port_t* port_info); +EPOLL_INTERNAL size_t ep_port_feed_events(ep_port_t* port_info, + OVERLAPPED_ENTRY* completion_list, + size_t completion_count, + struct epoll_event* event_list, + size_t max_event_count); + +EPOLL_INTERNAL poll_group_t* ep_port_acquire_poll_group( + ep_port_t* port_info, + size_t protocol_id, + const WSAPROTOCOL_INFOW* protocol_info); EPOLL_INTERNAL void ep_port_release_poll_group(poll_group_t* poll_group); EPOLL_INTERNAL int ep_port_add_socket(ep_port_t* port_info, - tree_node_t* tree_node, + ep_sock_t* sock_info, SOCKET socket); EPOLL_INTERNAL int ep_port_del_socket(ep_port_t* port_info, - tree_node_t* tree_node); + ep_sock_t* sock_info); +EPOLL_INTERNAL ep_sock_t* ep_port_find_socket(ep_port_t* port_info, + SOCKET socket); EPOLL_INTERNAL void ep_port_request_socket_update(ep_port_t* port_info, ep_sock_t* sock_info); @@ -1038,43 +1108,6 @@ EPOLL_INTERNAL void ep_port_clear_socket_update(ep_port_t* port_info, EPOLL_INTERNAL bool ep_port_is_socket_update_pending(ep_port_t* port_info, ep_sock_t* sock_info); -epoll_t epoll_create(void) { - ep_port_t* port_info; - HANDLE iocp; - - if (init() < 0) - return NULL; - - iocp = CreateIoCompletionPort(INVALID_HANDLE_VALUE, NULL, 0, 0); - if (iocp == INVALID_HANDLE_VALUE) - return_error(NULL); - - port_info = ep_port_new(iocp); - if (port_info == NULL) { - CloseHandle(iocp); - return NULL; - } - - return (epoll_t) port_info; -} - -int epoll_close(epoll_t port_handle) { - ep_port_t* port_info; - - if (init() < 0) - return -1; - - port_info = (ep_port_t*) port_handle; - - return ep_port_delete(port_info); -} -#include -#include - -#ifndef SIO_BASE_HANDLE -#define SIO_BASE_HANDLE 0x48000022 -#endif - #define _EP_EVENT_MASK 0xffff typedef struct _poll_req { @@ -1221,76 +1254,48 @@ static inline void _ep_sock_free(_ep_sock_private_t* sock_private) { free(sock_private); } -static int _get_related_sockets(ep_port_t* port_info, - SOCKET socket, - SOCKET* afd_socket_out, - poll_group_t** poll_group_out) { - SOCKET afd_socket; - poll_group_t* poll_group; - DWORD bytes; - - /* Try to obtain a base handle for the socket, so we can bypass LSPs - * that get in the way if we want to talk to the kernel directly. If - * it fails we try if we work with the original socket. Note that on - * windows XP/2k3 this will always fail since they don't support the - * SIO_BASE_HANDLE ioctl. - */ - afd_socket = socket; - WSAIoctl(socket, - SIO_BASE_HANDLE, - NULL, - 0, - &afd_socket, - sizeof afd_socket, - &bytes, - NULL, - NULL); - - poll_group = ep_port_acquire_poll_group(port_info, afd_socket); - if (poll_group == NULL) - return -1; - - *afd_socket_out = afd_socket; - *poll_group_out = poll_group; - - return 0; -} - -static int _ep_sock_set_socket(ep_port_t* port_info, - _ep_sock_private_t* sock_private, - SOCKET socket) { - if (socket == 0 || socket == INVALID_SOCKET) - return_error(-1, ERROR_INVALID_HANDLE); - - assert(sock_private->afd_socket == 0); - - if (_get_related_sockets(port_info, - socket, - &sock_private->afd_socket, - &sock_private->poll_group) < 0) - return -1; - - if (ep_port_add_socket(port_info, &sock_private->pub.tree_node, socket) < 0) - return -1; - - return 0; -} - ep_sock_t* ep_sock_new(ep_port_t* port_info, SOCKET socket) { - _ep_sock_private_t* sock_private = _ep_sock_alloc(); - if (sock_private == NULL) + SOCKET afd_socket; + ssize_t protocol_id; + WSAPROTOCOL_INFOW protocol_info; + poll_group_t* poll_group; + _ep_sock_private_t* sock_private; + + if (socket == 0 || socket == INVALID_SOCKET) + return_error(NULL, ERROR_INVALID_HANDLE); + + protocol_id = afd_get_protocol(socket, &afd_socket, &protocol_info); + if (protocol_id < 0) return NULL; + poll_group = + ep_port_acquire_poll_group(port_info, protocol_id, &protocol_info); + if (poll_group == NULL) + return NULL; + + sock_private = _ep_sock_alloc(); + if (sock_private == NULL) + goto err1; + memset(sock_private, 0, sizeof *sock_private); + + sock_private->afd_socket = afd_socket; + sock_private->poll_group = poll_group; + tree_node_init(&sock_private->pub.tree_node); queue_node_init(&sock_private->pub.queue_node); - if (_ep_sock_set_socket(port_info, sock_private, socket) < 0) { - _ep_sock_free(sock_private); - return NULL; - } + if (ep_port_add_socket(port_info, &sock_private->pub, socket) < 0) + goto err2; return &sock_private->pub; + +err2: + _ep_sock_free(sock_private); +err1: + ep_port_release_poll_group(poll_group); + + return NULL; } void ep_sock_delete(ep_port_t* port_info, ep_sock_t* sock_info) { @@ -1306,7 +1311,7 @@ void ep_sock_delete(ep_port_t* port_info, ep_sock_t* sock_info) { sock_private->pending_events = 0; } - ep_port_del_socket(port_info, &sock_info->tree_node); + ep_port_del_socket(port_info, sock_info); ep_port_clear_socket_update(port_info, sock_info); ep_port_release_poll_group(sock_private->poll_group); sock_private->poll_group = NULL; @@ -1323,7 +1328,7 @@ void ep_sock_force_delete(ep_port_t* port_info, ep_sock_t* sock_info) { ep_sock_delete(port_info, sock_info); } -ep_sock_t* ep_sock_find(tree_t* tree, SOCKET socket) { +ep_sock_t* ep_sock_find_in_tree(tree_t* tree, SOCKET socket) { tree_node_t* tree_node = tree_find(tree, socket); if (tree_node == NULL) return NULL; @@ -1461,11 +1466,42 @@ int ep_sock_feed_event(ep_port_t* port_info, return ev_count; } -#define _EP_COMPLETION_LIST_LENGTH 64 +#include -typedef struct ep_port ep_port_t; -typedef struct poll_req poll_req_t; -typedef struct ep_sock ep_sock_t; +EPOLL_INTERNAL int init(void); + +#define _EPOLL_MAX_COMPLETION_COUNT 64 + +epoll_t epoll_create(void) { + ep_port_t* port_info; + HANDLE iocp; + + if (init() < 0) + return NULL; + + iocp = CreateIoCompletionPort(INVALID_HANDLE_VALUE, NULL, 0, 0); + if (iocp == INVALID_HANDLE_VALUE) + return_error(NULL); + + port_info = ep_port_new(iocp); + if (port_info == NULL) { + CloseHandle(iocp); + return NULL; + } + + return (epoll_t) port_info; +} + +int epoll_close(epoll_t port_handle) { + ep_port_t* port_info; + + if (init() < 0) + return -1; + + port_info = (ep_port_t*) port_handle; + + return ep_port_delete(port_info); +} static int _ep_ctl_add(ep_port_t* port_info, uintptr_t socket, @@ -1485,7 +1521,7 @@ static int _ep_ctl_add(ep_port_t* port_info, static int _ep_ctl_mod(ep_port_t* port_info, uintptr_t socket, struct epoll_event* ev) { - ep_sock_t* sock_info = ep_sock_find(&port_info->sock_tree, socket); + ep_sock_t* sock_info = ep_port_find_socket(port_info, socket); if (sock_info == NULL) return -1; @@ -1496,7 +1532,7 @@ static int _ep_ctl_mod(ep_port_t* port_info, } static int _ep_ctl_del(ep_port_t* port_info, uintptr_t socket) { - ep_sock_t* sock_info = ep_sock_find(&port_info->sock_tree, socket); + ep_sock_t* sock_info = ep_port_find_socket(port_info, socket); if (sock_info == NULL) return -1; @@ -1526,46 +1562,6 @@ int epoll_ctl(epoll_t port_handle, return_error(-1, ERROR_INVALID_PARAMETER); } -static int _ep_port_update_events(ep_port_t* port_info) { - queue_t* update_queue = &port_info->update_queue; - - /* Walk the queue, submitting new poll requests for every socket that needs - * it. */ - while (!queue_empty(update_queue)) { - queue_node_t* queue_node = queue_first(update_queue); - ep_sock_t* sock_info = container_of(queue_node, ep_sock_t, queue_node); - - if (ep_sock_update(port_info, sock_info) < 0) - return -1; - - /* ep_sock_update() removes the socket from the update list if - * successfull. */ - } - - return 0; -} - -static size_t _ep_port_feed_events(ep_port_t* port_info, - OVERLAPPED_ENTRY* completion_list, - size_t completion_count, - struct epoll_event* event_list, - size_t max_event_count) { - if (completion_count > max_event_count) - abort(); - - size_t event_count = 0; - - for (size_t i = 0; i < completion_count; i++) { - OVERLAPPED* overlapped = completion_list[i].lpOverlapped; - ep_sock_t* sock_info = ep_sock_from_overlapped(overlapped); - struct epoll_event* ev = &event_list[event_count]; - - event_count += ep_sock_feed_event(port_info, sock_info, ev); - } - - return event_count; -} - int epoll_wait(epoll_t port_handle, struct epoll_event* events, int maxevents, @@ -1592,18 +1588,18 @@ int epoll_wait(epoll_t port_handle, } /* Compute how much overlapped entries can be dequeued at most. */ - if ((size_t) maxevents > _EP_COMPLETION_LIST_LENGTH) - maxevents = _EP_COMPLETION_LIST_LENGTH; + if ((size_t) maxevents > _EPOLL_MAX_COMPLETION_COUNT) + maxevents = _EPOLL_MAX_COMPLETION_COUNT; /* Dequeue completion packets until either at least one interesting event * has been discovered, or the timeout is reached. */ do { - OVERLAPPED_ENTRY completion_list[_EP_COMPLETION_LIST_LENGTH]; + OVERLAPPED_ENTRY completion_list[_EPOLL_MAX_COMPLETION_COUNT]; ULONG completion_count; ssize_t event_count; - if (_ep_port_update_events(port_info) < 0) + if (ep_port_update_events(port_info) < 0) return -1; BOOL r = GetQueuedCompletionStatusEx(port_info->iocp, @@ -1619,7 +1615,7 @@ int epoll_wait(epoll_t port_handle, return_error(-1); } - event_count = _ep_port_feed_events( + event_count = ep_port_feed_events( port_info, completion_list, completion_count, events, maxevents); if (event_count > 0) return (int) event_count; @@ -1634,139 +1630,6 @@ int epoll_wait(epoll_t port_handle, return 0; } -static ep_port_t* _ep_port_alloc(void) { - ep_port_t* port_info = malloc(sizeof *port_info); - if (port_info == NULL) - return_error(NULL, ERROR_NOT_ENOUGH_MEMORY); - - return port_info; -} - -static void _ep_port_free(ep_port_t* port) { - assert(port != NULL); - free(port); -} - -ep_port_t* ep_port_new(HANDLE iocp) { - ep_port_t* port_info; - - port_info = _ep_port_alloc(); - if (port_info == NULL) - return NULL; - - memset(port_info, 0, sizeof *port_info); - - port_info->iocp = iocp; - queue_init(&port_info->update_queue); - tree_init(&port_info->sock_tree); - - return port_info; -} - -int ep_port_delete(ep_port_t* port_info) { - tree_node_t* tree_node; - - if (!CloseHandle(port_info->iocp)) - return_error(-1); - port_info->iocp = NULL; - - while ((tree_node = tree_root(&port_info->sock_tree)) != NULL) { - ep_sock_t* sock_info = container_of(tree_node, ep_sock_t, tree_node); - ep_sock_force_delete(port_info, sock_info); - } - - for (size_t i = 0; i < array_count(port_info->poll_group_allocators); i++) { - poll_group_allocator_t* pga = port_info->poll_group_allocators[i]; - if (pga != NULL) - poll_group_allocator_delete(pga); - } - - _ep_port_free(port_info); - - return 0; -} - -int ep_port_add_socket(ep_port_t* port_info, - tree_node_t* tree_node, - SOCKET socket) { - return tree_add(&port_info->sock_tree, tree_node, socket); -} - -int ep_port_del_socket(ep_port_t* port_info, tree_node_t* tree_node) { - return tree_del(&port_info->sock_tree, tree_node); -} - -poll_group_allocator_t* _get_poll_group_allocator( - ep_port_t* port_info, - size_t index, - const WSAPROTOCOL_INFOW* protocol_info) { - poll_group_allocator_t** pga = &port_info->poll_group_allocators[index]; - - if (*pga == NULL) - *pga = poll_group_allocator_new(port_info, protocol_info); - - return *pga; -} - -poll_group_t* ep_port_acquire_poll_group(ep_port_t* port_info, SOCKET socket) { - ssize_t index; - size_t i; - WSAPROTOCOL_INFOW protocol_info; - int len; - poll_group_allocator_t* pga; - - /* Obtain protocol information about the socket. */ - len = sizeof protocol_info; - if (getsockopt(socket, - SOL_SOCKET, - SO_PROTOCOL_INFOW, - (char*) &protocol_info, - &len) != 0) - return_error(NULL); - - index = -1; - for (i = 0; i < array_count(AFD_PROVIDER_GUID_LIST); i++) { - if (memcmp((void*) &protocol_info.ProviderId, - (void*) &AFD_PROVIDER_GUID_LIST[i], - sizeof protocol_info.ProviderId) == 0) { - index = i; - break; - } - } - - /* Check if the protocol uses an msafd socket. */ - if (index < 0) - return_error(NULL, ERROR_NOT_SUPPORTED); - - pga = _get_poll_group_allocator(port_info, index, &protocol_info); - - return poll_group_acquire(pga); -} - -void ep_port_release_poll_group(poll_group_t* poll_group) { - poll_group_release(poll_group); -} - -bool ep_port_is_socket_update_pending(ep_port_t* port_info, - ep_sock_t* sock_info) { - unused(port_info); - return queue_enqueued(&sock_info->queue_node); -} - -void ep_port_request_socket_update(ep_port_t* port_info, - ep_sock_t* sock_info) { - if (ep_port_is_socket_update_pending(port_info, sock_info)) - return; - queue_append(&port_info->update_queue, &sock_info->queue_node); - assert(ep_port_is_socket_update_pending(port_info, sock_info)); -} - -void ep_port_clear_socket_update(ep_port_t* port_info, ep_sock_t* sock_info) { - if (!ep_port_is_socket_update_pending(port_info, sock_info)) - return; - queue_remove(&sock_info->queue_node); -} - /* clang-format off */ #define WE_ERROR_MAP(X) \ @@ -2572,11 +2435,8 @@ poll_group_t* poll_group_acquire(poll_group_allocator_t* pga) { if (poll_group == NULL) return NULL; - if (++poll_group->group_size == _POLL_GROUP_MAX_SIZE) { - /* Move to the front of the queue. */ - queue_remove(&poll_group->queue_node); - queue_prepend(&pga->poll_group_queue, &poll_group->queue_node); - } + if (++poll_group->group_size == _POLL_GROUP_MAX_SIZE) + queue_move_first(&pga->poll_group_queue, &poll_group->queue_node); return poll_group; } @@ -2587,13 +2447,224 @@ void poll_group_release(poll_group_t* poll_group) { poll_group->group_size--; assert(poll_group->group_size < _POLL_GROUP_MAX_SIZE); - /* Move to the back of the queue. */ - queue_remove(&poll_group->queue_node); - queue_append(&pga->poll_group_queue, &poll_group->queue_node); + queue_move_last(&pga->poll_group_queue, &poll_group->queue_node); /* TODO: free the poll_group_t* item at some point. */ } +static ep_port_t* _ep_port_alloc(void) { + ep_port_t* port_info = malloc(sizeof *port_info); + if (port_info == NULL) + return_error(NULL, ERROR_NOT_ENOUGH_MEMORY); + + return port_info; +} + +static void _ep_port_free(ep_port_t* port) { + assert(port != NULL); + free(port); +} + +ep_port_t* ep_port_new(HANDLE iocp) { + ep_port_t* port_info; + + port_info = _ep_port_alloc(); + if (port_info == NULL) + return NULL; + + memset(port_info, 0, sizeof *port_info); + + port_info->iocp = iocp; + queue_init(&port_info->update_queue); + tree_init(&port_info->sock_tree); + + return port_info; +} + +int ep_port_delete(ep_port_t* port_info) { + tree_node_t* tree_node; + + if (!CloseHandle(port_info->iocp)) + return_error(-1); + port_info->iocp = NULL; + + while ((tree_node = tree_root(&port_info->sock_tree)) != NULL) { + ep_sock_t* sock_info = container_of(tree_node, ep_sock_t, tree_node); + ep_sock_force_delete(port_info, sock_info); + } + + for (size_t i = 0; i < array_count(port_info->poll_group_allocators); i++) { + poll_group_allocator_t* pga = port_info->poll_group_allocators[i]; + if (pga != NULL) + poll_group_allocator_delete(pga); + } + + _ep_port_free(port_info); + + return 0; +} + +int ep_port_update_events(ep_port_t* port_info) { + queue_t* update_queue = &port_info->update_queue; + + /* Walk the queue, submitting new poll requests for every socket that needs + * it. */ + while (!queue_empty(update_queue)) { + queue_node_t* queue_node = queue_first(update_queue); + ep_sock_t* sock_info = container_of(queue_node, ep_sock_t, queue_node); + + if (ep_sock_update(port_info, sock_info) < 0) + return -1; + + /* ep_sock_update() removes the socket from the update list if + * successfull. */ + } + + return 0; +} + +size_t ep_port_feed_events(ep_port_t* port_info, + OVERLAPPED_ENTRY* completion_list, + size_t completion_count, + struct epoll_event* event_list, + size_t max_event_count) { + if (completion_count > max_event_count) + abort(); + + size_t event_count = 0; + + for (size_t i = 0; i < completion_count; i++) { + OVERLAPPED* overlapped = completion_list[i].lpOverlapped; + ep_sock_t* sock_info = ep_sock_from_overlapped(overlapped); + struct epoll_event* ev = &event_list[event_count]; + + event_count += ep_sock_feed_event(port_info, sock_info, ev); + } + + return event_count; +} + +int ep_port_add_socket(ep_port_t* port_info, + ep_sock_t* sock_info, + SOCKET socket) { + return tree_add(&port_info->sock_tree, &sock_info->tree_node, socket); +} + +int ep_port_del_socket(ep_port_t* port_info, ep_sock_t* sock_info) { + return tree_del(&port_info->sock_tree, &sock_info->tree_node); +} + +ep_sock_t* ep_port_find_socket(ep_port_t* port_info, SOCKET socket) { + return ep_sock_find_in_tree(&port_info->sock_tree, socket); +} + +static poll_group_allocator_t* _ep_port_get_poll_group_allocator( + ep_port_t* port_info, + size_t protocol_id, + const WSAPROTOCOL_INFOW* protocol_info) { + poll_group_allocator_t** pga; + + assert(protocol_id < array_count(port_info->poll_group_allocators)); + + pga = &port_info->poll_group_allocators[protocol_id]; + if (*pga == NULL) + *pga = poll_group_allocator_new(port_info, protocol_info); + + return *pga; +} + +poll_group_t* ep_port_acquire_poll_group( + ep_port_t* port_info, + size_t protocol_id, + const WSAPROTOCOL_INFOW* protocol_info) { + poll_group_allocator_t* pga = + _ep_port_get_poll_group_allocator(port_info, protocol_id, protocol_info); + return poll_group_acquire(pga); +} + +void ep_port_release_poll_group(poll_group_t* poll_group) { + poll_group_release(poll_group); +} + +void ep_port_request_socket_update(ep_port_t* port_info, + ep_sock_t* sock_info) { + if (ep_port_is_socket_update_pending(port_info, sock_info)) + return; + queue_append(&port_info->update_queue, &sock_info->queue_node); + assert(ep_port_is_socket_update_pending(port_info, sock_info)); +} + +void ep_port_clear_socket_update(ep_port_t* port_info, ep_sock_t* sock_info) { + if (!ep_port_is_socket_update_pending(port_info, sock_info)) + return; + queue_remove(&sock_info->queue_node); +} + +bool ep_port_is_socket_update_pending(ep_port_t* port_info, + ep_sock_t* sock_info) { + unused(port_info); + return queue_enqueued(&sock_info->queue_node); +} + +void queue_init(queue_t* queue) { + queue_node_init(&queue->head); +} + +void queue_node_init(queue_node_t* node) { + node->prev = node; + node->next = node; +} + +static inline void _queue_detach(queue_node_t* node) { + node->prev->next = node->next; + node->next->prev = node->prev; +} + +queue_node_t* queue_first(const queue_t* queue) { + return !queue_empty(queue) ? queue->head.next : NULL; +} + +queue_node_t* queue_last(const queue_t* queue) { + return !queue_empty(queue) ? queue->head.prev : NULL; +} + +void queue_prepend(queue_t* queue, queue_node_t* node) { + node->next = queue->head.next; + node->prev = &queue->head; + node->next->prev = node; + queue->head.next = node; +} + +void queue_append(queue_t* queue, queue_node_t* node) { + node->next = &queue->head; + node->prev = queue->head.prev; + node->prev->next = node; + queue->head.prev = node; +} + +void queue_move_first(queue_t* queue, queue_node_t* node) { + _queue_detach(node); + queue_prepend(queue, node); +} + +void queue_move_last(queue_t* queue, queue_node_t* node) { + _queue_detach(node); + queue_append(queue, node); +} + +void queue_remove(queue_node_t* node) { + _queue_detach(node); + queue_node_init(node); +} + +bool queue_empty(const queue_t* queue) { + return !queue_enqueued(&queue->head); +} + +bool queue_enqueued(const queue_node_t* node) { + return node->prev != node; +} + static inline int _tree_compare(tree_node_t* a, tree_node_t* b) { if (a->key < b->key) return -1; @@ -2616,18 +2687,11 @@ void tree_node_init(tree_node_t* node) { int tree_add(tree_t* tree, tree_node_t* node, uintptr_t key) { tree_node_t* existing_node; - if (key == 0) - return_error(-1, ERROR_INVALID_PARAMETER); - if (node->key != 0) - return_error(-1, ERROR_ALREADY_EXISTS); - node->key = key; existing_node = RB_INSERT(tree, tree, node); - if (existing_node != NULL) { - node->key = 0; + if (existing_node != NULL) return_error(-1, ERROR_ALREADY_EXISTS); - } return 0; } @@ -2635,9 +2699,6 @@ int tree_add(tree_t* tree, tree_node_t* node, uintptr_t key) { int tree_del(tree_t* tree, tree_node_t* node) { tree_node_t* removed_node; - if (node->key == 0) - return_error(-1, ERROR_NOT_FOUND); - removed_node = RB_REMOVE(tree, tree, node); if (removed_node == NULL) @@ -2645,8 +2706,6 @@ int tree_del(tree_t* tree, tree_node_t* node) { else assert(removed_node == node); - node->key = 0; - return 0; } @@ -2654,9 +2713,6 @@ tree_node_t* tree_find(tree_t* tree, uintptr_t key) { tree_node_t* node; tree_node_t lookup; - if (key == 0) - return_error(NULL, ERROR_INVALID_PARAMETER); - memset(&lookup, 0, sizeof lookup); lookup.key = key;