diff --git a/wepoll.c b/wepoll.c index a9c59bb..8d20a32 100644 --- a/wepoll.c +++ b/wepoll.c @@ -440,11 +440,11 @@ WEPOLL_INTERNAL int port_ctl(port_state_t* port_state, SOCKET sock, struct epoll_event* ev); -WEPOLL_INTERNAL int port_register_socket_handle(port_state_t* port_state, - sock_state_t* sock_state, - SOCKET socket); -WEPOLL_INTERNAL void port_unregister_socket_handle(port_state_t* port_state, - sock_state_t* sock_state); +WEPOLL_INTERNAL int port_register_socket(port_state_t* port_state, + sock_state_t* sock_state, + SOCKET socket); +WEPOLL_INTERNAL void port_unregister_socket(port_state_t* port_state, + sock_state_t* sock_state); WEPOLL_INTERNAL sock_state_t* port_find_socket(port_state_t* port_state, SOCKET socket); @@ -941,12 +941,12 @@ WEPOLL_INTERNAL queue_node_t* queue_last(const queue_t* queue); WEPOLL_INTERNAL void queue_prepend(queue_t* queue, queue_node_t* node); WEPOLL_INTERNAL void queue_append(queue_t* queue, queue_node_t* node); -WEPOLL_INTERNAL void queue_move_first(queue_t* queue, queue_node_t* node); -WEPOLL_INTERNAL void queue_move_last(queue_t* queue, queue_node_t* node); +WEPOLL_INTERNAL void queue_move_to_start(queue_t* queue, queue_node_t* node); +WEPOLL_INTERNAL void queue_move_to_end(queue_t* queue, queue_node_t* node); WEPOLL_INTERNAL void queue_remove(queue_node_t* node); -WEPOLL_INTERNAL bool queue_empty(const queue_t* queue); -WEPOLL_INTERNAL bool queue_enqueued(const queue_node_t* node); +WEPOLL_INTERNAL bool queue_is_empty(const queue_t* queue); +WEPOLL_INTERNAL bool queue_is_enqueued(const queue_node_t* node); static const size_t POLL_GROUP__MAX_GROUP_SIZE = 32; @@ -999,7 +999,7 @@ HANDLE poll_group_get_afd_helper_handle(poll_group_t* poll_group) { poll_group_t* poll_group_acquire(port_state_t* port_state) { queue_t* poll_group_queue = port_get_poll_group_queue(port_state); poll_group_t* poll_group = - !queue_empty(poll_group_queue) + !queue_is_empty(poll_group_queue) ? container_of( queue_last(poll_group_queue), poll_group_t, queue_node) : NULL; @@ -1011,7 +1011,7 @@ poll_group_t* poll_group_acquire(port_state_t* port_state) { return NULL; if (++poll_group->group_size == POLL_GROUP__MAX_GROUP_SIZE) - queue_move_first(poll_group_queue, &poll_group->queue_node); + queue_move_to_start(poll_group_queue, &poll_group->queue_node); return poll_group; } @@ -1023,7 +1023,7 @@ void poll_group_release(poll_group_t* poll_group) { poll_group->group_size--; assert(poll_group->group_size < POLL_GROUP__MAX_GROUP_SIZE); - queue_move_last(poll_group_queue, &poll_group->queue_node); + queue_move_to_end(poll_group_queue, &poll_group->queue_node); /* Poll groups are currently only freed when the epoll port is closed. */ } @@ -1161,7 +1161,7 @@ int port_delete(port_state_t* port_state) { poll_group_delete(poll_group); } - assert(queue_empty(&port_state->sock_update_queue)); + assert(queue_is_empty(&port_state->sock_update_queue)); DeleteCriticalSection(&port_state->lock); @@ -1175,7 +1175,7 @@ static int port__update_events(port_state_t* port_state) { /* Walk the queue, submitting new poll requests for every socket that needs * it. */ - while (!queue_empty(sock_update_queue)) { + while (!queue_is_empty(sock_update_queue)) { queue_node_t* queue_node = queue_first(sock_update_queue); sock_state_t* sock_state = sock_state_from_queue_node(queue_node); @@ -1392,9 +1392,9 @@ int port_ctl(port_state_t* port_state, return result; } -int port_register_socket_handle(port_state_t* port_state, - sock_state_t* sock_state, - SOCKET socket) { +int port_register_socket(port_state_t* port_state, + sock_state_t* sock_state, + SOCKET socket) { if (tree_add(&port_state->sock_tree, sock_state_to_tree_node(sock_state), socket) < 0) @@ -1402,8 +1402,8 @@ int port_register_socket_handle(port_state_t* port_state, return 0; } -void port_unregister_socket_handle(port_state_t* port_state, - sock_state_t* sock_state) { +void port_unregister_socket(port_state_t* port_state, + sock_state_t* sock_state) { tree_del(&port_state->sock_tree, sock_state_to_tree_node(sock_state)); } @@ -1416,7 +1416,7 @@ sock_state_t* port_find_socket(port_state_t* port_state, SOCKET socket) { void port_request_socket_update(port_state_t* port_state, sock_state_t* sock_state) { - if (queue_enqueued(sock_state_to_queue_node(sock_state))) + if (queue_is_enqueued(sock_state_to_queue_node(sock_state))) return; queue_append(&port_state->sock_update_queue, sock_state_to_queue_node(sock_state)); @@ -1425,14 +1425,14 @@ void port_request_socket_update(port_state_t* port_state, void port_cancel_socket_update(port_state_t* port_state, sock_state_t* sock_state) { unused_var(port_state); - if (!queue_enqueued(sock_state_to_queue_node(sock_state))) + if (!queue_is_enqueued(sock_state_to_queue_node(sock_state))) return; queue_remove(sock_state_to_queue_node(sock_state)); } void port_add_deleted_socket(port_state_t* port_state, sock_state_t* sock_state) { - if (queue_enqueued(sock_state_to_queue_node(sock_state))) + if (queue_is_enqueued(sock_state_to_queue_node(sock_state))) return; queue_append(&port_state->sock_deleted_queue, sock_state_to_queue_node(sock_state)); @@ -1441,7 +1441,7 @@ void port_add_deleted_socket(port_state_t* port_state, void port_remove_deleted_socket(port_state_t* port_state, sock_state_t* sock_state) { unused_var(port_state); - if (!queue_enqueued(sock_state_to_queue_node(sock_state))) + if (!queue_is_enqueued(sock_state_to_queue_node(sock_state))) return; queue_remove(sock_state_to_queue_node(sock_state)); } @@ -1478,11 +1478,11 @@ static inline void queue__detach_node(queue_node_t* node) { } queue_node_t* queue_first(const queue_t* queue) { - return !queue_empty(queue) ? queue->head.next : NULL; + return !queue_is_empty(queue) ? queue->head.next : NULL; } queue_node_t* queue_last(const queue_t* queue) { - return !queue_empty(queue) ? queue->head.prev : NULL; + return !queue_is_empty(queue) ? queue->head.prev : NULL; } void queue_prepend(queue_t* queue, queue_node_t* node) { @@ -1499,12 +1499,12 @@ void queue_append(queue_t* queue, queue_node_t* node) { queue->head.prev = node; } -void queue_move_first(queue_t* queue, queue_node_t* node) { +void queue_move_to_start(queue_t* queue, queue_node_t* node) { queue__detach_node(node); queue_prepend(queue, node); } -void queue_move_last(queue_t* queue, queue_node_t* node) { +void queue_move_to_end(queue_t* queue, queue_node_t* node) { queue__detach_node(node); queue_append(queue, node); } @@ -1514,11 +1514,11 @@ void queue_remove(queue_node_t* node) { queue_node_init(node); } -bool queue_empty(const queue_t* queue) { - return !queue_enqueued(&queue->head); +bool queue_is_empty(const queue_t* queue) { + return !queue_is_enqueued(&queue->head); } -bool queue_enqueued(const queue_node_t* node) { +bool queue_is_enqueued(const queue_node_t* node) { return node->prev != node; } @@ -1664,7 +1664,7 @@ sock_state_t* sock_new(port_state_t* port_state, SOCKET socket) { tree_node_init(&sock_state->tree_node); queue_node_init(&sock_state->queue_node); - if (port_register_socket_handle(port_state, sock_state, socket) < 0) + if (port_register_socket(port_state, sock_state, socket) < 0) goto err2; return sock_state; @@ -1685,7 +1685,7 @@ static int sock__delete(port_state_t* port_state, sock__cancel_poll(sock_state); port_cancel_socket_update(port_state, sock_state); - port_unregister_socket_handle(port_state, sock_state); + port_unregister_socket(port_state, sock_state); sock_state->delete_pending = true; } @@ -2187,6 +2187,10 @@ tree_node_t* tree_root(const tree_t* tree) { return tree->root; } +#ifndef SIO_BSP_HANDLE_POLL +#define SIO_BSP_HANDLE_POLL 0x4800001D +#endif + #ifndef SIO_BASE_HANDLE #define SIO_BASE_HANDLE 0x48000022 #endif @@ -2202,20 +2206,53 @@ int ws_global_init(void) { return 0; } -SOCKET ws_get_base_socket(SOCKET socket) { - SOCKET base_socket; +static inline SOCKET ws__ioctl_get_bsp_socket(SOCKET socket, DWORD ioctl) { + SOCKET bsp_socket; DWORD bytes; if (WSAIoctl(socket, - SIO_BASE_HANDLE, + ioctl, NULL, 0, - &base_socket, - sizeof base_socket, + &bsp_socket, + sizeof bsp_socket, &bytes, NULL, - NULL) == SOCKET_ERROR) - return_map_error(INVALID_SOCKET); - - return base_socket; + NULL) != SOCKET_ERROR) + return bsp_socket; + else + return INVALID_SOCKET; +} + +SOCKET ws_get_base_socket(SOCKET socket) { + SOCKET base_socket; + DWORD error; + + for (;;) { + base_socket = ws__ioctl_get_bsp_socket(socket, SIO_BASE_HANDLE); + if (base_socket != INVALID_SOCKET) + return base_socket; + + error = GetLastError(); + if (error == WSAENOTSOCK) + return_set_error(INVALID_SOCKET, error); + + /* Even though Microsoft documentation clearly states that LSPs should + * never intercept the `SIO_BASE_HANDLE` ioctl [1], Komodia based LSPs do + * so anyway, breaking it, with the apparent intention of preventing LSP + * bypass [2]. Fortunately they don't handle `SIO_BSP_HANDLE_POLL`, which + * we can use to obtain the socket associated with the next protocol chain + * entry. If this succeeds, loop around and call `SIO_BASE_HANDLE` again + * with the retrieved BSP socket to be sure that we actually got all the + * way to the base. + * [1] https://docs.microsoft.com/en-us/windows/win32/winsock/winsock-ioctls + * [2] https://www.komodia.com/newwiki/index.php?title=Komodia%27s_Redirector_bug_fixes#Version_2.2.2.6 + */ + + base_socket = ws__ioctl_get_bsp_socket(socket, SIO_BSP_HANDLE_POLL); + if (base_socket != INVALID_SOCKET && base_socket != socket) + socket = base_socket; + else + return_set_error(INVALID_SOCKET, error); + } }