From 78a607b3793ff703426faf4398085a33974e1567 Mon Sep 17 00:00:00 2001 From: Bert Belder Date: Thu, 14 Sep 2017 01:46:56 +0200 Subject: [PATCH] all-in-one: check in the latest build --- allinone/epoll-all-in-one.c | 792 ++++++++++++++++++++++-------------- 1 file changed, 487 insertions(+), 305 deletions(-) diff --git a/allinone/epoll-all-in-one.c b/allinone/epoll-all-in-one.c index b4452fa..2cfb9fc 100644 --- a/allinone/epoll-all-in-one.c +++ b/allinone/epoll-all-in-one.c @@ -185,6 +185,10 @@ typedef NTSTATUS* PNTSTATUS; #define STATUS_PENDING ((NTSTATUS) 0x00000103L) #endif +#ifndef STATUS_CANCELLED +#define STATUS_CANCELLED ((NTSTATUS) 0xC0000120L) +#endif + #ifndef STATUS_SEVERITY_SUCCESS #define STATUS_SEVERITY_SUCCESS 0x0 #endif @@ -263,8 +267,8 @@ typedef struct _AFD_POLL_INFO { AFD_POLL_HANDLE_INFO Handles[1]; } AFD_POLL_INFO, *PAFD_POLL_INFO; -EPOLL_INTERNAL int afd_poll(SOCKET socket, - AFD_POLL_INFO* info, +EPOLL_INTERNAL int afd_poll(SOCKET driver_socket, + AFD_POLL_INFO* poll_info, OVERLAPPED* overlapped); /* clang-format off */ @@ -302,7 +306,7 @@ EPOLL_INTERNAL void we_set_win_error(DWORD error); #define return_error(value, ...) _return_error_helper(__VA_ARGS__ + 0, value) -EPOLL_INTERNAL int nt_initialize(void); +EPOLL_INTERNAL int nt_init(void); typedef struct _IO_STATUS_BLOCK { union { @@ -346,7 +350,9 @@ NTDLL_IMPORT_LIST(X) #define IOCTL_AFD_POLL _AFD_CONTROL_CODE(AFD_POLL, METHOD_BUFFERED) -int afd_poll(SOCKET socket, AFD_POLL_INFO* info, OVERLAPPED* overlapped) { +int afd_poll(SOCKET driver_socket, + AFD_POLL_INFO* poll_info, + OVERLAPPED* overlapped) { IO_STATUS_BLOCK iosb; IO_STATUS_BLOCK* iosb_ptr; HANDLE event = NULL; @@ -376,16 +382,16 @@ int afd_poll(SOCKET socket, AFD_POLL_INFO* info, OVERLAPPED* overlapped) { } iosb_ptr->Status = STATUS_PENDING; - status = NtDeviceIoControlFile((HANDLE) socket, + status = NtDeviceIoControlFile((HANDLE) driver_socket, event, NULL, apc_context, iosb_ptr, IOCTL_AFD_POLL, - info, - sizeof *info, - info, - sizeof *info); + poll_info, + sizeof *poll_info, + poll_info, + sizeof *poll_info); if (overlapped == NULL) { /* If this is a blocking operation, wait for the event to become @@ -413,10 +419,10 @@ int afd_poll(SOCKET socket, AFD_POLL_INFO* info, OVERLAPPED* overlapped) { else return_error(-1, we_map_ntstatus_to_win_error(status)); } -#include -#include -#include +EPOLL_INTERNAL int init(void); + +#include #include typedef struct queue_node queue_node_t; @@ -972,13 +978,13 @@ typedef struct ep_sock { queue_node_t queue_node; } ep_sock_t; -EPOLL_INTERNAL ep_sock_t* ep_sock_new(ep_port_t* port_info); -EPOLL_INTERNAL int ep_sock_delete(ep_port_t* port_info, ep_sock_t* sock_info); +EPOLL_INTERNAL ep_sock_t* ep_sock_new(ep_port_t* port_info, SOCKET socket); +EPOLL_INTERNAL void ep_sock_delete(ep_port_t* port_info, ep_sock_t* sock_info); +EPOLL_INTERNAL void ep_sock_force_delete(ep_port_t* port_info, + ep_sock_t* sock_info); + EPOLL_INTERNAL ep_sock_t* ep_sock_find(tree_t* tree, SOCKET socket); -EPOLL_INTERNAL int ep_sock_set_socket(ep_port_t* port_info, - ep_sock_t* sock_info, - SOCKET socket); EPOLL_INTERNAL int ep_sock_set_event(ep_port_t* port_info, ep_sock_t* sock_info, const struct epoll_event* ev); @@ -993,6 +999,87 @@ EPOLL_INTERNAL void ep_sock_register_poll_req(ep_port_t* port_info, EPOLL_INTERNAL void ep_sock_unregister_poll_req(ep_port_t* port_info, ep_sock_t* sock_info); +typedef struct ep_port ep_port_t; +typedef struct poll_group_allocator poll_group_allocator_t; +typedef struct poll_group poll_group_t; + +EPOLL_INTERNAL poll_group_allocator_t* poll_group_allocator_new( + ep_port_t* port_info, const WSAPROTOCOL_INFOW* protocol_info); +EPOLL_INTERNAL void poll_group_allocator_delete(poll_group_allocator_t* pga); + +EPOLL_INTERNAL poll_group_t* poll_group_acquire(poll_group_allocator_t* pga); +EPOLL_INTERNAL void poll_group_release(poll_group_t* ds); + +EPOLL_INTERNAL SOCKET poll_group_get_socket(poll_group_t* poll_group); + +typedef struct ep_port ep_port_t; +typedef struct ep_sock ep_sock_t; + +typedef struct ep_port { + HANDLE iocp; + poll_group_allocator_t* + poll_group_allocators[array_count(AFD_PROVIDER_GUID_LIST)]; + tree_t sock_tree; + queue_t update_queue; + size_t poll_req_count; +} ep_port_t; + +EPOLL_INTERNAL ep_port_t* ep_port_new(HANDLE iocp); +EPOLL_INTERNAL int ep_port_delete(ep_port_t* port_info); + +EPOLL_INTERNAL poll_group_t* ep_port_acquire_poll_group(ep_port_t* port_info, + SOCKET socket); +EPOLL_INTERNAL void ep_port_release_poll_group(poll_group_t* poll_group); + +EPOLL_INTERNAL int ep_port_add_socket(ep_port_t* port_info, + tree_node_t* tree_node, + SOCKET socket); +EPOLL_INTERNAL int ep_port_del_socket(ep_port_t* port_info, + tree_node_t* tree_node); + +EPOLL_INTERNAL void ep_port_add_req(ep_port_t* port_info); +EPOLL_INTERNAL void ep_port_del_req(ep_port_t* port_info); + +EPOLL_INTERNAL void ep_port_request_socket_update(ep_port_t* port_info, + ep_sock_t* sock_info); +EPOLL_INTERNAL void ep_port_clear_socket_update(ep_port_t* port_info, + ep_sock_t* sock_info); +EPOLL_INTERNAL bool ep_port_is_socket_update_pending(ep_port_t* port_info, + ep_sock_t* sock_info); + +epoll_t epoll_create(void) { + ep_port_t* port_info; + HANDLE iocp; + + if (init() < 0) + return NULL; + + iocp = CreateIoCompletionPort(INVALID_HANDLE_VALUE, NULL, 0, 0); + if (iocp == INVALID_HANDLE_VALUE) + return_error(NULL); + + port_info = ep_port_new(iocp); + if (port_info == NULL) { + CloseHandle(iocp); + return NULL; + } + + return (epoll_t) port_info; +} + +int epoll_close(epoll_t port_handle) { + ep_port_t* port_info; + + if (init() < 0) + return -1; + + port_info = (ep_port_t*) port_handle; + + return ep_port_delete(port_info); +} +#include +#include + typedef struct ep_port ep_port_t; typedef struct ep_sock ep_sock_t; typedef struct poll_req poll_req_t; @@ -1013,40 +1100,11 @@ EPOLL_INTERNAL int poll_req_submit(poll_req_t* poll_req, SOCKET socket, SOCKET driver_socket); +EPOLL_INTERNAL int poll_req_cancel(poll_req_t* poll_req, SOCKET group_socket); EPOLL_INTERNAL void poll_req_complete(const poll_req_t* poll_req, uint32_t* epoll_events_out, bool* socket_closed_out); -typedef struct ep_port ep_port_t; -typedef struct ep_sock ep_sock_t; - -typedef struct ep_port { - HANDLE iocp; - SOCKET driver_sockets[array_count(AFD_PROVIDER_GUID_LIST)]; - tree_t sock_tree; - queue_t update_queue; - size_t poll_req_count; -} ep_port_t; - -EPOLL_INTERNAL SOCKET ep_port_get_driver_socket(ep_port_t* port_info, - SOCKET socket); - -EPOLL_INTERNAL int ep_port_add_socket(ep_port_t* port_info, - tree_node_t* tree_node, - SOCKET socket); -EPOLL_INTERNAL int ep_port_del_socket(ep_port_t* port_info, - tree_node_t* tree_node); - -EPOLL_INTERNAL void ep_port_add_req(ep_port_t* port_info); -EPOLL_INTERNAL void ep_port_del_req(ep_port_t* port_info); - -EPOLL_INTERNAL void ep_port_request_socket_update(ep_port_t* port_info, - ep_sock_t* sock_info); -EPOLL_INTERNAL void ep_port_clear_socket_update(ep_port_t* port_info, - ep_sock_t* sock_info); -EPOLL_INTERNAL bool ep_port_is_socket_update_pending(ep_port_t* port_info, - ep_sock_t* sock_info); - #ifndef SIO_BASE_HANDLE #define SIO_BASE_HANDLE 0x48000022 #endif @@ -1058,13 +1116,14 @@ EPOLL_INTERNAL bool ep_port_is_socket_update_pending(ep_port_t* port_info, typedef struct _ep_sock_private { ep_sock_t pub; SOCKET afd_socket; - SOCKET driver_socket; + poll_group_t* poll_group; epoll_data_t user_data; poll_req_t* latest_poll_req; uint32_t user_events; uint32_t latest_poll_req_events; uint32_t poll_req_count; uint32_t flags; + bool poll_req_active; } _ep_sock_private_t; static inline _ep_sock_private_t* _ep_sock_private(ep_sock_t* sock_info) { @@ -1086,76 +1145,12 @@ static inline void _ep_sock_free(_ep_sock_private_t* sock_private) { free(sock_private); } -ep_sock_t* ep_sock_new(ep_port_t* port_info) { - _ep_sock_private_t* sock_private = _ep_sock_alloc(); - if (sock_private == NULL) - return NULL; - - unused(port_info); - - memset(sock_private, 0, sizeof *sock_private); - tree_node_init(&sock_private->pub.tree_node); - queue_node_init(&sock_private->pub.queue_node); - - return &sock_private->pub; -} - -void _ep_sock_maybe_free(_ep_sock_private_t* sock_private) { - /* The socket may still have pending overlapped requests that have yet to be - * reported by the completion port. If that's the case the memory can't be - * released yet. It'll be released later as ep_sock_unregister_poll_req() - * calls this function. - */ - if (_ep_sock_is_deleted(sock_private) && sock_private->poll_req_count == 0) - _ep_sock_free(sock_private); -} - -int ep_sock_delete(ep_port_t* port_info, ep_sock_t* sock_info) { - _ep_sock_private_t* sock_private = _ep_sock_private(sock_info); - - assert(!_ep_sock_is_deleted(sock_private)); - - ep_port_del_socket(port_info, &sock_info->tree_node); - ep_port_clear_socket_update(port_info, sock_info); - - sock_private->flags |= _EP_SOCK_DELETED; - - _ep_sock_maybe_free(sock_private); - - return 0; -} - -ep_sock_t* ep_sock_find(tree_t* tree, SOCKET socket) { - tree_node_t* tree_node = tree_find(tree, socket); - if (tree_node == NULL) - return NULL; - - return container_of(tree, ep_sock_t, tree_node); -} - -void ep_sock_register_poll_req(ep_port_t* port_info, ep_sock_t* sock_info) { - _ep_sock_private_t* sock_private = _ep_sock_private(sock_info); - - assert(!_ep_sock_is_deleted(sock_private)); - - ep_port_add_req(port_info); - sock_private->poll_req_count++; -} - -void ep_sock_unregister_poll_req(ep_port_t* port_info, ep_sock_t* sock_info) { - _ep_sock_private_t* sock_private = _ep_sock_private(sock_info); - - ep_port_del_req(port_info); - sock_private->poll_req_count--; - - _ep_sock_maybe_free(sock_private); -} - static int _get_related_sockets(ep_port_t* port_info, SOCKET socket, SOCKET* afd_socket_out, - SOCKET* driver_socket_out) { - SOCKET afd_socket, driver_socket; + poll_group_t** poll_group_out) { + SOCKET afd_socket; + poll_group_t* poll_group; DWORD bytes; /* Try to obtain a base handle for the socket, so we can bypass LSPs @@ -1175,38 +1170,115 @@ static int _get_related_sockets(ep_port_t* port_info, NULL, NULL); - driver_socket = ep_port_get_driver_socket(port_info, afd_socket); - if (driver_socket == INVALID_SOCKET) + poll_group = ep_port_acquire_poll_group(port_info, afd_socket); + if (poll_group == NULL) return -1; *afd_socket_out = afd_socket; - *driver_socket_out = driver_socket; + *poll_group_out = poll_group; return 0; } -int ep_sock_set_socket(ep_port_t* port_info, - ep_sock_t* sock_info, - SOCKET socket) { - _ep_sock_private_t* sock_private = _ep_sock_private(sock_info); - +static int _ep_sock_set_socket(ep_port_t* port_info, + _ep_sock_private_t* sock_private, + SOCKET socket) { if (socket == 0 || socket == INVALID_SOCKET) return_error(-1, ERROR_INVALID_HANDLE); - if (sock_private->afd_socket != 0) - return_error(-1, ERROR_ALREADY_ASSIGNED); + + assert(sock_private->afd_socket == 0); if (_get_related_sockets(port_info, socket, &sock_private->afd_socket, - &sock_private->driver_socket) < 0) + &sock_private->poll_group) < 0) return -1; - if (ep_port_add_socket(port_info, &sock_info->tree_node, socket) < 0) + if (ep_port_add_socket(port_info, &sock_private->pub.tree_node, socket) < 0) return -1; return 0; } +ep_sock_t* ep_sock_new(ep_port_t* port_info, SOCKET socket) { + _ep_sock_private_t* sock_private = _ep_sock_alloc(); + if (sock_private == NULL) + return NULL; + + unused(port_info); + + memset(sock_private, 0, sizeof *sock_private); + tree_node_init(&sock_private->pub.tree_node); + queue_node_init(&sock_private->pub.queue_node); + + if (_ep_sock_set_socket(port_info, sock_private, socket) < 0) { + _ep_sock_free(sock_private); + return NULL; + } + + return &sock_private->pub; +} + +void _ep_sock_maybe_free(_ep_sock_private_t* sock_private) { + /* The socket may still have pending overlapped requests that have yet to be + * reported by the completion port. If that's the case the memory can't be + * released yet. It'll be released later as ep_sock_unregister_poll_req() + * calls this function. + */ + if (_ep_sock_is_deleted(sock_private) && sock_private->poll_req_count == 0) + _ep_sock_free(sock_private); +} + +void ep_sock_delete(ep_port_t* port_info, ep_sock_t* sock_info) { + _ep_sock_private_t* sock_private = _ep_sock_private(sock_info); + + assert(!_ep_sock_is_deleted(sock_private)); + sock_private->flags |= _EP_SOCK_DELETED; + + ep_port_del_socket(port_info, &sock_info->tree_node); + ep_port_clear_socket_update(port_info, sock_info); + ep_port_release_poll_group(sock_private->poll_group); + sock_private->poll_group = NULL; + + _ep_sock_maybe_free(sock_private); +} + +void ep_sock_force_delete(ep_port_t* port_info, ep_sock_t* sock_info) { + _ep_sock_private_t* sock_private = _ep_sock_private(sock_info); + if (sock_private->latest_poll_req != NULL) + poll_req_delete(port_info, sock_info, sock_private->latest_poll_req); + assert(sock_private->poll_req_count == 0); + ep_sock_delete(port_info, sock_info); +} + +ep_sock_t* ep_sock_find(tree_t* tree, SOCKET socket) { + tree_node_t* tree_node = tree_find(tree, socket); + if (tree_node == NULL) + return NULL; + + return container_of(tree_node, ep_sock_t, tree_node); +} + +void ep_sock_register_poll_req(ep_port_t* port_info, ep_sock_t* sock_info) { + _ep_sock_private_t* sock_private = _ep_sock_private(sock_info); + + assert(!_ep_sock_is_deleted(sock_private)); + + ep_port_add_req(port_info); + sock_private->poll_req_count++; + assert(sock_private->poll_req_count == 1); +} + +void ep_sock_unregister_poll_req(ep_port_t* port_info, ep_sock_t* sock_info) { + _ep_sock_private_t* sock_private = _ep_sock_private(sock_info); + + ep_port_del_req(port_info); + sock_private->poll_req_count--; + assert(sock_private->poll_req_count == 0); + + _ep_sock_maybe_free(sock_private); +} + int ep_sock_set_event(ep_port_t* port_info, ep_sock_t* sock_info, const struct epoll_event* ev) { @@ -1218,7 +1290,7 @@ int ep_sock_set_event(ep_port_t* port_info, sock_private->user_events = events; sock_private->user_data = ev->data; - if (events & _EP_EVENT_MASK & ~(sock_private->latest_poll_req_events)) + if ((events & _EP_EVENT_MASK & ~(sock_private->latest_poll_req_events)) != 0) ep_port_request_socket_update(port_info, sock_info); return 0; @@ -1226,12 +1298,15 @@ int ep_sock_set_event(ep_port_t* port_info, static inline bool _is_latest_poll_req(_ep_sock_private_t* sock_private, poll_req_t* poll_req) { + assert(sock_private->latest_poll_req == poll_req || + sock_private->latest_poll_req == NULL); return poll_req == sock_private->latest_poll_req; } static inline void _clear_latest_poll_req(_ep_sock_private_t* sock_private) { sock_private->latest_poll_req = NULL; sock_private->latest_poll_req_events = 0; + sock_private->poll_req_active = false; } static inline void _set_latest_poll_req(_ep_sock_private_t* sock_private, @@ -1239,55 +1314,61 @@ static inline void _set_latest_poll_req(_ep_sock_private_t* sock_private, uint32_t epoll_events) { sock_private->latest_poll_req = poll_req; sock_private->latest_poll_req_events = epoll_events; -} - -static int _ep_submit_poll_req(ep_port_t* port_info, - _ep_sock_private_t* sock_private) { - poll_req_t* poll_req; - uint32_t epoll_events = sock_private->user_events; - - poll_req = poll_req_new(port_info, &sock_private->pub); - if (poll_req == NULL) - return -1; - - if (poll_req_submit(poll_req, - epoll_events, - sock_private->afd_socket, - sock_private->driver_socket) < 0) { - poll_req_delete(port_info, &sock_private->pub, poll_req); - return -1; - } - - _set_latest_poll_req(sock_private, poll_req, epoll_events); - - return 0; + sock_private->poll_req_active = true; } int ep_sock_update(ep_port_t* port_info, ep_sock_t* sock_info) { _ep_sock_private_t* sock_private = _ep_sock_private(sock_info); bool broken = false; + SOCKET driver_socket; assert(ep_port_is_socket_update_pending(port_info, sock_info)); + driver_socket = poll_group_get_socket(sock_private->poll_group); + /* Check if there are events registered that are not yet submitted. In * that case we need to submit another req. */ if ((sock_private->user_events & _EP_EVENT_MASK & - ~sock_private->latest_poll_req_events) == 0) + ~sock_private->latest_poll_req_events) == 0) { /* All the events the user is interested in are already being monitored - * by the latest poll request. */ - goto done; + * by the latest poll request. It might spuriously complete because of an + * event that we're no longer interested in; if that happens we just + * submit another poll request with the right event mask. + */ + assert(sock_private->latest_poll_req != NULL); - if (_ep_submit_poll_req(port_info, sock_private) < 0) { - if (GetLastError() == ERROR_INVALID_HANDLE) - /* The socket is broken. It will be dropped from the epoll set. */ - broken = true; - else - /* Another error occurred, which is propagated to the caller. */ + } else if (sock_private->latest_poll_req != NULL) { + /* A poll request is already pending. Cancel the old one first; when it + * completes, we'll submit the new one. */ + if (sock_private->poll_req_active) { + poll_req_cancel(sock_private->latest_poll_req, driver_socket); + sock_private->poll_req_active = false; + } + + } else { + poll_req_t* poll_req = poll_req_new(port_info, &sock_private->pub); + if (poll_req == NULL) return -1; + + if (poll_req_submit(poll_req, + sock_private->user_events, + sock_private->afd_socket, + driver_socket) < 0) { + poll_req_delete(port_info, &sock_private->pub, poll_req); + + if (GetLastError() == ERROR_INVALID_HANDLE) + /* The socket is broken. It will be dropped from the epoll set. */ + broken = true; + else + /* Another error occurred, which is propagated to the caller. */ + return -1; + } + + if (!broken) + _set_latest_poll_req(sock_private, poll_req, sock_private->user_events); } -done: ep_port_clear_socket_update(port_info, sock_info); /* If we saw an ERROR_INVALID_HANDLE error, drop the socket. */ @@ -1323,10 +1404,10 @@ int ep_sock_feed_event(ep_port_t* port_info, /* Filter events that the user didn't ask for. */ epoll_events &= sock_private->user_events; - /* Drop the socket if the EPOLLONESHOT flag is set and there are any events + /* Clear the event mask if EPOLLONESHOT is set and there are any events * to report. */ if (epoll_events != 0 && (sock_private->user_events & EPOLLONESHOT)) - drop_socket = true; + sock_private->user_events = EPOLLERR | EPOLLHUP; /* Fill the ev structure if there are any events to report. */ if (epoll_events != 0) { @@ -1348,29 +1429,20 @@ int ep_sock_feed_event(ep_port_t* port_info, return ev_count; } -#include - #define _EP_COMPLETION_LIST_LENGTH 64 typedef struct ep_port ep_port_t; typedef struct poll_req poll_req_t; typedef struct ep_sock ep_sock_t; -static int _ep_initialize(void); -static SOCKET _ep_create_driver_socket(HANDLE iocp, - WSAPROTOCOL_INFOW* protocol_info); - -static int _ep_initialized = 0; - static int _ep_ctl_add(ep_port_t* port_info, uintptr_t socket, struct epoll_event* ev) { - ep_sock_t* sock_info = ep_sock_new(port_info); + ep_sock_t* sock_info = ep_sock_new(port_info, socket); if (sock_info == NULL) return -1; - if (ep_sock_set_socket(port_info, sock_info, socket) < 0 || - ep_sock_set_event(port_info, sock_info, ev) < 0) { + if (ep_sock_set_event(port_info, sock_info, ev) < 0) { ep_sock_delete(port_info, sock_info); return -1; } @@ -1396,8 +1468,7 @@ static int _ep_ctl_del(ep_port_t* port_info, uintptr_t socket) { if (sock_info == NULL) return -1; - if (ep_sock_delete(port_info, sock_info) < 0) - return -1; + ep_sock_delete(port_info, sock_info); return 0; } @@ -1408,6 +1479,9 @@ int epoll_ctl(epoll_t port_handle, struct epoll_event* ev) { ep_port_t* port_info = (ep_port_t*) port_handle; + if (init() < 0) + return -1; + switch (op) { case EPOLL_CTL_ADD: return _ep_ctl_add(port_info, socket, ev); @@ -1468,6 +1542,9 @@ int epoll_wait(epoll_t port_handle, ULONGLONG due = 0; DWORD gqcs_timeout; + if (init() < 0) + return -1; + port_info = (ep_port_t*) port_handle; /* Compute the timeout for GetQueuedCompletionStatus, and the wait end @@ -1525,109 +1602,54 @@ int epoll_wait(epoll_t port_handle, return 0; } -epoll_t epoll_create(void) { - ep_port_t* port_info; - HANDLE iocp; - - /* If necessary, do global initialization first. This is totally not - * thread-safe at the moment. - */ - if (!_ep_initialized) { - if (_ep_initialize() < 0) - return NULL; - _ep_initialized = 1; - } - - port_info = malloc(sizeof *port_info); +static ep_port_t* _ep_port_alloc(void) { + ep_port_t* port_info = malloc(sizeof *port_info); if (port_info == NULL) return_error(NULL, ERROR_NOT_ENOUGH_MEMORY); - iocp = CreateIoCompletionPort(INVALID_HANDLE_VALUE, NULL, 0, 0); - if (iocp == INVALID_HANDLE_VALUE) { - free(port_info); - return_error(NULL); - } + return port_info; +} + +static void _ep_port_free(ep_port_t* port) { + assert(port != NULL); + free(port); +} + +ep_port_t* ep_port_new(HANDLE iocp) { + ep_port_t* port_info; + + port_info = _ep_port_alloc(); + if (port_info == NULL) + return NULL; + + memset(port_info, 0, sizeof *port_info); port_info->iocp = iocp; - port_info->poll_req_count = 0; - queue_init(&port_info->update_queue); - - memset(&port_info->driver_sockets, 0, sizeof port_info->driver_sockets); tree_init(&port_info->sock_tree); - return (epoll_t) port_info; + return port_info; } -int epoll_close(epoll_t port_handle) { - ep_port_t* port_info; +int ep_port_delete(ep_port_t* port_info) { tree_node_t* tree_node; - port_info = (ep_port_t*) port_handle; - - /* Close all peer sockets. This will make all pending io requests return. */ - for (size_t i = 0; i < array_count(port_info->driver_sockets); i++) { - SOCKET driver_socket = port_info->driver_sockets[i]; - if (driver_socket != 0 && driver_socket != INVALID_SOCKET) { - if (closesocket(driver_socket) != 0) - return_error(-1); - - port_info->driver_sockets[i] = 0; - } - } - - /* There is no list of io requests to free. And even if there was, just - * freeing them would be dangerous since the kernel might still alter - * the overlapped status contained in them. But since we are sure that - * all requests will soon return, just await them all. - */ - while (port_info->poll_req_count > 0) { - OVERLAPPED_ENTRY entries[64]; - DWORD result; - ULONG count, i; - - result = GetQueuedCompletionStatusEx(port_info->iocp, - entries, - array_count(entries), - &count, - INFINITE, - FALSE); - - if (!result) - return_error(-1); - - for (i = 0; i < count; i++) { - poll_req_t* poll_req = overlapped_to_poll_req(entries[i].lpOverlapped); - poll_req_delete(port_info, poll_req_get_sock_data(poll_req), poll_req); - } - } - - /* Remove all entries from the socket_state tree. */ - while ((tree_node = tree_root(&port_info->sock_tree)) != NULL) { - ep_sock_t* sock_info = container_of(tree_node, ep_sock_t, tree_node); - ep_sock_delete(port_info, sock_info); - } - - /* Close the I/O completion port. */ if (!CloseHandle(port_info->iocp)) return_error(-1); + port_info->iocp = NULL; - /* Finally, remove the port data. */ - free(port_info); + while ((tree_node = tree_root(&port_info->sock_tree)) != NULL) { + ep_sock_t* sock_info = container_of(tree_node, ep_sock_t, tree_node); + ep_sock_force_delete(port_info, sock_info); + } - return 0; -} + for (size_t i = 0; i < array_count(port_info->poll_group_allocators); i++) { + poll_group_allocator_t* pga = port_info->poll_group_allocators[i]; + if (pga != NULL) + poll_group_allocator_delete(pga); + } -static int _ep_initialize(void) { - int r; - WSADATA wsa_data; - - r = WSAStartup(MAKEWORD(2, 2), &wsa_data); - if (r != 0) - return_error(-1); - - if (nt_initialize() < 0) - return -1; + _ep_port_free(port_info); return 0; } @@ -1650,12 +1672,24 @@ void ep_port_del_req(ep_port_t* port_info) { port_info->poll_req_count--; } -SOCKET ep_port_get_driver_socket(ep_port_t* port_info, SOCKET socket) { +poll_group_allocator_t* _get_poll_group_allocator( + ep_port_t* port_info, + size_t index, + const WSAPROTOCOL_INFOW* protocol_info) { + poll_group_allocator_t** pga = &port_info->poll_group_allocators[index]; + + if (*pga == NULL) + *pga = poll_group_allocator_new(port_info, protocol_info); + + return *pga; +} + +poll_group_t* ep_port_acquire_poll_group(ep_port_t* port_info, SOCKET socket) { ssize_t index; size_t i; - SOCKET driver_socket; WSAPROTOCOL_INFOW protocol_info; int len; + poll_group_allocator_t* pga; /* Obtain protocol information about the socket. */ len = sizeof protocol_info; @@ -1664,7 +1698,7 @@ SOCKET ep_port_get_driver_socket(ep_port_t* port_info, SOCKET socket) { SO_PROTOCOL_INFOW, (char*) &protocol_info, &len) != 0) - return_error(INVALID_SOCKET); + return_error(NULL); index = -1; for (i = 0; i < array_count(AFD_PROVIDER_GUID_LIST); i++) { @@ -1678,46 +1712,15 @@ SOCKET ep_port_get_driver_socket(ep_port_t* port_info, SOCKET socket) { /* Check if the protocol uses an msafd socket. */ if (index < 0) - return_error(INVALID_SOCKET, ERROR_NOT_SUPPORTED); + return_error(NULL, ERROR_NOT_SUPPORTED); - /* If we didn't (try) to create a peer socket yet, try to make one. Don't - * try again if the peer socket creation failed earlier for the same - * protocol. - */ - driver_socket = port_info->driver_sockets[index]; - if (driver_socket == 0) { - driver_socket = _ep_create_driver_socket(port_info->iocp, &protocol_info); - port_info->driver_sockets[index] = driver_socket; - } + pga = _get_poll_group_allocator(port_info, index, &protocol_info); - return driver_socket; + return poll_group_acquire(pga); } -static SOCKET _ep_create_driver_socket(HANDLE iocp, - WSAPROTOCOL_INFOW* protocol_info) { - SOCKET socket = 0; - - socket = WSASocketW(protocol_info->iAddressFamily, - protocol_info->iSocketType, - protocol_info->iProtocol, - protocol_info, - 0, - WSA_FLAG_OVERLAPPED); - if (socket == INVALID_SOCKET) - return_error(INVALID_SOCKET); - - if (!SetHandleInformation((HANDLE) socket, HANDLE_FLAG_INHERIT, 0)) - goto error; - - if (CreateIoCompletionPort((HANDLE) socket, iocp, 0, 0) == NULL) - goto error; - - return socket; - -error:; - DWORD error = GetLastError(); - closesocket(socket); - return_error(INVALID_SOCKET, error); +void ep_port_release_poll_group(poll_group_t* poll_group) { + poll_group_release(poll_group); } bool ep_port_is_socket_update_pending(ep_port_t* port_info, @@ -2381,12 +2384,40 @@ void we_set_win_error(DWORD error) { errno = we_map_win_error_to_errno(error); } +static bool _initialized = false; + +static int _init_winsock(void) { + int r; + WSADATA wsa_data; + + r = WSAStartup(MAKEWORD(2, 2), &wsa_data); + if (r != 0) + return_error(-1); + + return 0; +} + +static int _init_once(void) { + if (_init_winsock() < 0 || nt_init() < 0) + return -1; + + _initialized = true; + return 0; +} + +int init(void) { + if (_initialized) + return 0; + + return _init_once(); +} + #define X(return_type, declarators, name, parameters) \ EPOLL_INTERNAL return_type(declarators* name) parameters = NULL; NTDLL_IMPORT_LIST(X) #undef X -int nt_initialize(void) { +int nt_init(void) { HMODULE ntdll; ntdll = GetModuleHandleW(L"ntdll.dll"); @@ -2403,6 +2434,142 @@ int nt_initialize(void) { return 0; } +static const size_t _DS_MAX_USERS = 32; + +typedef struct poll_group_allocator { + ep_port_t* port_info; + queue_t poll_group_queue; + WSAPROTOCOL_INFOW protocol_info; +} poll_group_allocator_t; + +typedef struct poll_group { + poll_group_allocator_t* allocator; + queue_node_t queue_node; + SOCKET socket; + size_t user_count; +} poll_group_t; + +static int _poll_group_create_socket(poll_group_t* poll_group, + WSAPROTOCOL_INFOW* protocol_info, + HANDLE iocp) { + SOCKET socket; + + socket = WSASocketW(protocol_info->iAddressFamily, + protocol_info->iSocketType, + protocol_info->iProtocol, + protocol_info, + 0, + WSA_FLAG_OVERLAPPED); + if (socket == INVALID_SOCKET) + return_error(-1); + + if (!SetHandleInformation((HANDLE) socket, HANDLE_FLAG_INHERIT, 0)) + goto error; + + if (CreateIoCompletionPort((HANDLE) socket, iocp, 0, 0) == NULL) + goto error; + + poll_group->socket = socket; + return 0; + +error:; + DWORD error = GetLastError(); + closesocket(socket); + return_error(-1, error); +} + +static poll_group_t* _poll_group_new(poll_group_allocator_t* pga) { + poll_group_t* poll_group = malloc(sizeof *poll_group); + if (poll_group == NULL) + return_error(NULL, ERROR_NOT_ENOUGH_MEMORY); + + memset(poll_group, 0, sizeof *poll_group); + + queue_node_init(&poll_group->queue_node); + poll_group->allocator = pga; + + if (_poll_group_create_socket( + poll_group, &pga->protocol_info, pga->port_info->iocp) < 0) { + free(poll_group); + return NULL; + } + + queue_append(&pga->poll_group_queue, &poll_group->queue_node); + + return poll_group; +} + +static void _poll_group_delete(poll_group_t* poll_group) { + assert(poll_group->user_count == 0); + closesocket(poll_group->socket); + queue_remove(&poll_group->queue_node); + free(poll_group); +} + +SOCKET poll_group_get_socket(poll_group_t* poll_group) { + return poll_group->socket; +} + +poll_group_allocator_t* poll_group_allocator_new( + ep_port_t* port_info, const WSAPROTOCOL_INFOW* protocol_info) { + poll_group_allocator_t* pga = malloc(sizeof *pga); + if (pga == NULL) + return_error(NULL, ERROR_NOT_ENOUGH_MEMORY); + + queue_init(&pga->poll_group_queue); + pga->port_info = port_info; + pga->protocol_info = *protocol_info; + + return pga; +} + +void poll_group_allocator_delete(poll_group_allocator_t* pga) { + queue_t* poll_group_queue = &pga->poll_group_queue; + + while (!queue_empty(poll_group_queue)) { + queue_node_t* queue_node = queue_first(poll_group_queue); + poll_group_t* poll_group = + container_of(queue_node, poll_group_t, queue_node); + _poll_group_delete(poll_group); + } + + free(pga); +} + +poll_group_t* poll_group_acquire(poll_group_allocator_t* pga) { + queue_t* queue = &pga->poll_group_queue; + poll_group_t* poll_group = + !queue_empty(queue) + ? container_of(queue_last(queue), poll_group_t, queue_node) + : NULL; + + if (poll_group == NULL || poll_group->user_count >= _DS_MAX_USERS) + poll_group = _poll_group_new(pga); + if (poll_group == NULL) + return NULL; + + if (++poll_group->user_count == _DS_MAX_USERS) { + /* Move to the front of the queue. */ + queue_remove(&poll_group->queue_node); + queue_prepend(&pga->poll_group_queue, &poll_group->queue_node); + } + + return poll_group; +} + +void poll_group_release(poll_group_t* poll_group) { + poll_group_allocator_t* pga = poll_group->allocator; + + poll_group->user_count--; + assert(poll_group->user_count < _DS_MAX_USERS); + + /* Move to the back of the queue. */ + queue_remove(&poll_group->queue_node); + queue_append(&pga->poll_group_queue, &poll_group->queue_node); + + /* TODO: free the poll_group_t* item at some point. */ +} + typedef struct poll_req { OVERLAPPED overlapped; AFD_POLL_INFO poll_info; @@ -2499,7 +2666,7 @@ int poll_req_submit(poll_req_t* poll_req, memset(&poll_req->overlapped, 0, sizeof poll_req->overlapped); - poll_req->poll_info.Exclusive = TRUE; + poll_req->poll_info.Exclusive = FALSE; poll_req->poll_info.NumberOfHandles = 1; poll_req->poll_info.Timeout.QuadPart = INT64_MAX; poll_req->poll_info.Handles[0].Handle = (HANDLE) socket; @@ -2514,6 +2681,20 @@ int poll_req_submit(poll_req_t* poll_req, return 0; } +int poll_req_cancel(poll_req_t* poll_req, SOCKET driver_socket) { + OVERLAPPED* overlapped = &poll_req->overlapped; + + if (CancelIoEx((HANDLE) driver_socket, overlapped)) { + DWORD error = GetLastError(); + if (error == ERROR_NOT_FOUND) + return 0; /* Already completed or canceled. */ + else + return_error(-1); + } + + return 0; +} + void poll_req_complete(const poll_req_t* poll_req, uint32_t* epoll_events_out, bool* socket_closed_out) { @@ -2522,9 +2703,10 @@ void poll_req_complete(const poll_req_t* poll_req, uint32_t epoll_events = 0; bool socket_closed = false; - if (!NT_SUCCESS(overlapped->Internal)) { - /* The overlapped request itself failed, there are no events to consider. - */ + if ((NTSTATUS) overlapped->Internal == STATUS_CANCELLED) { + /* The poll request was cancelled by CancelIoEx. */ + } else if (!NT_SUCCESS(overlapped->Internal)) { + /* The overlapped request itself failed in an unexpected way. */ epoll_events = EPOLLERR; } else if (poll_req->poll_info.NumberOfHandles < 1) { /* This overlapped request succeeded but didn't report any events. */