afd: retrieve protocol info for afd driver socket on startup
This commit is contained in:
parent
c69f361564
commit
2789bad793
100
src/afd.c
100
src/afd.c
@ -1,3 +1,6 @@
|
||||
#include <malloc.h>
|
||||
#include <stdlib.h>
|
||||
|
||||
#include "afd.h"
|
||||
#include "error.h"
|
||||
#include "nt.h"
|
||||
@ -14,6 +17,8 @@
|
||||
|
||||
#define IOCTL_AFD_POLL _AFD_CONTROL_CODE(AFD_POLL, METHOD_BUFFERED)
|
||||
|
||||
#define _AFD_ANY_PROTOCOL -1
|
||||
|
||||
/* clang-format off */
|
||||
static const GUID _AFD_PROVIDER_GUID_LIST[] = {
|
||||
/* MSAFD Tcpip [TCP+UDP+RAW / IP] */
|
||||
@ -30,6 +35,101 @@ static const GUID _AFD_PROVIDER_GUID_LIST[] = {
|
||||
{0xb6, 0x55, 0x00, 0x80, 0x5f, 0x36, 0x42, 0xcc}}};
|
||||
/* clang-format on */
|
||||
|
||||
/* This protocol info record is used by afd_create_driver_socket() to create
|
||||
* sockets that can be used as the first argument to afd_poll(). It is
|
||||
* populated on startup by afd_global_init().
|
||||
*/
|
||||
static WSAPROTOCOL_INFOW _afd_driver_socket_template;
|
||||
|
||||
static const WSAPROTOCOL_INFOW* _afd_find_protocol_info(
|
||||
const WSAPROTOCOL_INFOW* infos, size_t infos_count, int protocol_id) {
|
||||
size_t i, j;
|
||||
|
||||
for (i = 0; i < infos_count; i++) {
|
||||
const WSAPROTOCOL_INFOW* info = &infos[i];
|
||||
|
||||
/* Apply protocol id filter. */
|
||||
if (protocol_id != _AFD_ANY_PROTOCOL && protocol_id != info->iProtocol)
|
||||
continue;
|
||||
|
||||
/* Filter out non-MSAFD protocols. */
|
||||
for (j = 0; j < array_count(_AFD_PROVIDER_GUID_LIST); j++) {
|
||||
if (memcmp(&info->ProviderId,
|
||||
&_AFD_PROVIDER_GUID_LIST[j],
|
||||
sizeof info->ProviderId) == 0)
|
||||
return info;
|
||||
}
|
||||
}
|
||||
|
||||
return NULL; /* Not found. */
|
||||
}
|
||||
|
||||
int afd_global_init(void) {
|
||||
WSAPROTOCOL_INFOW* infos;
|
||||
ssize_t infos_count;
|
||||
const WSAPROTOCOL_INFOW* afd_info;
|
||||
|
||||
/* Load the winsock catalog. */
|
||||
infos_count = ws_get_protocol_catalog(&infos);
|
||||
if (infos_count < 0)
|
||||
return_error(-1);
|
||||
|
||||
/* Find a WSAPROTOCOL_INDOW structure that we can use to create an MSAFD
|
||||
* socket. Preferentially we pick a UDP socket, otherwise try TCP or any
|
||||
* other type.
|
||||
*/
|
||||
do {
|
||||
afd_info = _afd_find_protocol_info(infos, infos_count, IPPROTO_UDP);
|
||||
if (afd_info != NULL)
|
||||
break;
|
||||
|
||||
afd_info = _afd_find_protocol_info(infos, infos_count, IPPROTO_TCP);
|
||||
if (afd_info != NULL)
|
||||
break;
|
||||
|
||||
afd_info = _afd_find_protocol_info(infos, infos_count, _AFD_ANY_PROTOCOL);
|
||||
if (afd_info != NULL)
|
||||
break;
|
||||
|
||||
free(infos);
|
||||
return_error(-1, WSAENETDOWN); /* No suitable protocol found. */
|
||||
} while (0);
|
||||
|
||||
/* Copy found protocol information from the catalog to a static buffer. */
|
||||
_afd_driver_socket_template = *afd_info;
|
||||
|
||||
free(infos);
|
||||
return 0;
|
||||
}
|
||||
|
||||
int afd_create_driver_socket(HANDLE iocp, SOCKET* driver_socket_out) {
|
||||
SOCKET socket;
|
||||
|
||||
socket = WSASocketW(_afd_driver_socket_template.iAddressFamily,
|
||||
_afd_driver_socket_template.iSocketType,
|
||||
_afd_driver_socket_template.iProtocol,
|
||||
&_afd_driver_socket_template,
|
||||
0,
|
||||
WSA_FLAG_OVERLAPPED);
|
||||
if (socket == INVALID_SOCKET)
|
||||
return_error(-1);
|
||||
|
||||
/* TODO: use WSA_FLAG_NOINHERIT on Windows versions that support it. */
|
||||
if (!SetHandleInformation((HANDLE) socket, HANDLE_FLAG_INHERIT, 0))
|
||||
goto error;
|
||||
|
||||
if (CreateIoCompletionPort((HANDLE) socket, iocp, 0, 0) == NULL)
|
||||
goto error;
|
||||
|
||||
*driver_socket_out = socket;
|
||||
return 0;
|
||||
|
||||
error:;
|
||||
DWORD error = GetLastError();
|
||||
closesocket(socket);
|
||||
return_error(-1, error);
|
||||
}
|
||||
|
||||
int afd_poll(SOCKET driver_socket,
|
||||
AFD_POLL_INFO* poll_info,
|
||||
OVERLAPPED* overlapped) {
|
||||
|
||||
@ -52,6 +52,11 @@ typedef struct _AFD_POLL_INFO {
|
||||
AFD_POLL_HANDLE_INFO Handles[1];
|
||||
} AFD_POLL_INFO, *PAFD_POLL_INFO;
|
||||
|
||||
WEPOLL_INTERNAL int afd_global_init(void);
|
||||
|
||||
WEPOLL_INTERNAL int afd_create_driver_socket(HANDLE iocp,
|
||||
SOCKET* driver_socket_out);
|
||||
|
||||
WEPOLL_INTERNAL int afd_poll(SOCKET driver_socket,
|
||||
AFD_POLL_INFO* poll_info,
|
||||
OVERLAPPED* overlapped);
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
#include <stdbool.h>
|
||||
|
||||
#include "afd.h"
|
||||
#include "api.h"
|
||||
#include "init.h"
|
||||
#include "nt.h"
|
||||
@ -18,7 +19,7 @@ static BOOL CALLBACK _init_once_callback(INIT_ONCE* once,
|
||||
unused_var(context);
|
||||
|
||||
/* N.b. that initialization order matters here. */
|
||||
if (ws_global_init() < 0 || nt_global_init() < 0 ||
|
||||
if (ws_global_init() < 0 || nt_global_init() < 0 || afd_global_init() < 0 ||
|
||||
reflock_global_init() < 0 || api_global_init() < 0)
|
||||
return FALSE;
|
||||
|
||||
|
||||
@ -12,7 +12,6 @@ static const size_t _POLL_GROUP_MAX_SIZE = 32;
|
||||
typedef struct poll_group_allocator {
|
||||
ep_port_t* port_info;
|
||||
queue_t poll_group_queue;
|
||||
WSAPROTOCOL_INFOW protocol_info;
|
||||
} poll_group_allocator_t;
|
||||
|
||||
typedef struct poll_group {
|
||||
@ -22,35 +21,6 @@ typedef struct poll_group {
|
||||
size_t group_size;
|
||||
} poll_group_t;
|
||||
|
||||
static int _poll_group_create_socket(poll_group_t* poll_group,
|
||||
WSAPROTOCOL_INFOW* protocol_info,
|
||||
HANDLE iocp) {
|
||||
SOCKET socket;
|
||||
|
||||
socket = WSASocketW(protocol_info->iAddressFamily,
|
||||
protocol_info->iSocketType,
|
||||
protocol_info->iProtocol,
|
||||
protocol_info,
|
||||
0,
|
||||
WSA_FLAG_OVERLAPPED);
|
||||
if (socket == INVALID_SOCKET)
|
||||
return_error(-1);
|
||||
|
||||
if (!SetHandleInformation((HANDLE) socket, HANDLE_FLAG_INHERIT, 0))
|
||||
goto error;
|
||||
|
||||
if (CreateIoCompletionPort((HANDLE) socket, iocp, 0, 0) == NULL)
|
||||
goto error;
|
||||
|
||||
poll_group->socket = socket;
|
||||
return 0;
|
||||
|
||||
error:;
|
||||
DWORD error = GetLastError();
|
||||
closesocket(socket);
|
||||
return_error(-1, error);
|
||||
}
|
||||
|
||||
static poll_group_t* _poll_group_new(poll_group_allocator_t* pga) {
|
||||
poll_group_t* poll_group = malloc(sizeof *poll_group);
|
||||
if (poll_group == NULL)
|
||||
@ -61,8 +31,8 @@ static poll_group_t* _poll_group_new(poll_group_allocator_t* pga) {
|
||||
queue_node_init(&poll_group->queue_node);
|
||||
poll_group->allocator = pga;
|
||||
|
||||
if (_poll_group_create_socket(
|
||||
poll_group, &pga->protocol_info, pga->port_info->iocp) < 0) {
|
||||
if (afd_create_driver_socket(pga->port_info->iocp, &poll_group->socket) <
|
||||
0) {
|
||||
free(poll_group);
|
||||
return NULL;
|
||||
}
|
||||
@ -83,15 +53,13 @@ SOCKET poll_group_get_socket(poll_group_t* poll_group) {
|
||||
return poll_group->socket;
|
||||
}
|
||||
|
||||
poll_group_allocator_t* poll_group_allocator_new(
|
||||
ep_port_t* port_info, const WSAPROTOCOL_INFOW* protocol_info) {
|
||||
poll_group_allocator_t* poll_group_allocator_new(ep_port_t* port_info) {
|
||||
poll_group_allocator_t* pga = malloc(sizeof *pga);
|
||||
if (pga == NULL)
|
||||
return_error(NULL, ERROR_NOT_ENOUGH_MEMORY);
|
||||
|
||||
queue_init(&pga->poll_group_queue);
|
||||
pga->port_info = port_info;
|
||||
pga->protocol_info = *protocol_info;
|
||||
|
||||
return pga;
|
||||
}
|
||||
|
||||
@ -11,7 +11,7 @@ typedef struct poll_group_allocator poll_group_allocator_t;
|
||||
typedef struct poll_group poll_group_t;
|
||||
|
||||
WEPOLL_INTERNAL poll_group_allocator_t* poll_group_allocator_new(
|
||||
ep_port_t* port_info, const WSAPROTOCOL_INFOW* protocol_info);
|
||||
ep_port_t* port_info);
|
||||
WEPOLL_INTERNAL void poll_group_allocator_delete(poll_group_allocator_t* pga);
|
||||
|
||||
WEPOLL_INTERNAL poll_group_t* poll_group_acquire(poll_group_allocator_t* pga);
|
||||
|
||||
11
src/port.c
11
src/port.c
@ -357,19 +357,16 @@ ep_sock_t* ep_port_find_socket(ep_port_t* port_info, SOCKET socket) {
|
||||
}
|
||||
|
||||
static poll_group_allocator_t* _ep_port_get_poll_group_allocator(
|
||||
ep_port_t* port_info, const WSAPROTOCOL_INFOW* protocol_info) {
|
||||
ep_port_t* port_info) {
|
||||
if (port_info->poll_group_allocator == NULL) {
|
||||
port_info->poll_group_allocator =
|
||||
poll_group_allocator_new(port_info, protocol_info);
|
||||
port_info->poll_group_allocator = poll_group_allocator_new(port_info);
|
||||
}
|
||||
|
||||
return port_info->poll_group_allocator;
|
||||
}
|
||||
|
||||
poll_group_t* ep_port_acquire_poll_group(
|
||||
ep_port_t* port_info, const WSAPROTOCOL_INFOW* protocol_info) {
|
||||
poll_group_allocator_t* pga =
|
||||
_ep_port_get_poll_group_allocator(port_info, protocol_info);
|
||||
poll_group_t* ep_port_acquire_poll_group(ep_port_t* port_info) {
|
||||
poll_group_allocator_t* pga = _ep_port_get_poll_group_allocator(port_info);
|
||||
return poll_group_acquire(pga);
|
||||
}
|
||||
|
||||
|
||||
@ -39,8 +39,7 @@ WEPOLL_INTERNAL int ep_port_ctl(ep_port_t* port_info,
|
||||
SOCKET sock,
|
||||
struct epoll_event* ev);
|
||||
|
||||
WEPOLL_INTERNAL poll_group_t* ep_port_acquire_poll_group(
|
||||
ep_port_t* port_info, const WSAPROTOCOL_INFOW* protocol_info);
|
||||
WEPOLL_INTERNAL poll_group_t* ep_port_acquire_poll_group(ep_port_t* port_info);
|
||||
WEPOLL_INTERNAL void ep_port_release_poll_group(ep_port_t* port_info,
|
||||
poll_group_t* poll_group);
|
||||
|
||||
|
||||
@ -183,7 +183,7 @@ ep_sock_t* ep_sock_new(ep_port_t* port_info, SOCKET socket) {
|
||||
if (afd_get_protocol_info(socket, &afd_socket, &protocol_info) < 0)
|
||||
return NULL;
|
||||
|
||||
poll_group = ep_port_acquire_poll_group(port_info, &protocol_info);
|
||||
poll_group = ep_port_acquire_poll_group(port_info);
|
||||
if (poll_group == NULL)
|
||||
return NULL;
|
||||
|
||||
|
||||
33
src/ws.c
33
src/ws.c
@ -1,4 +1,8 @@
|
||||
#include <malloc.h>
|
||||
#include <stdlib.h>
|
||||
|
||||
#include "error.h"
|
||||
#include "util.h"
|
||||
#include "win.h"
|
||||
#include "ws.h"
|
||||
|
||||
@ -6,6 +10,8 @@
|
||||
#define SIO_BASE_HANDLE 0x48000022
|
||||
#endif
|
||||
|
||||
#define _WS_INITIAL_CATALOG_BUFFER_SIZE 0x4000 /* 16kb. */
|
||||
|
||||
int ws_global_init(void) {
|
||||
int r;
|
||||
WSADATA wsa_data;
|
||||
@ -34,3 +40,30 @@ SOCKET ws_get_base_socket(SOCKET socket) {
|
||||
|
||||
return base_socket;
|
||||
}
|
||||
|
||||
/* Retrieves a copy of the winsock catalog.
|
||||
* The infos pointer must be released by the caller with free().
|
||||
*/
|
||||
ssize_t ws_get_protocol_catalog(WSAPROTOCOL_INFOW** infos_out) {
|
||||
DWORD buffer_size = _WS_INITIAL_CATALOG_BUFFER_SIZE;
|
||||
int count;
|
||||
WSAPROTOCOL_INFOW* infos;
|
||||
|
||||
for (;;) {
|
||||
infos = malloc(buffer_size);
|
||||
if (infos == NULL)
|
||||
return_error(-1, ERROR_NOT_ENOUGH_MEMORY);
|
||||
|
||||
count = WSAEnumProtocolsW(NULL, infos, &buffer_size);
|
||||
if (count == SOCKET_ERROR) {
|
||||
free(infos);
|
||||
if (WSAGetLastError() == WSAENOBUFS)
|
||||
continue; /* Try again with bigger buffer size. */
|
||||
else
|
||||
return_error(-1);
|
||||
}
|
||||
|
||||
*infos_out = infos;
|
||||
return count;
|
||||
}
|
||||
}
|
||||
|
||||
2
src/ws.h
2
src/ws.h
@ -2,10 +2,12 @@
|
||||
#define WEPOLL_WS_H_
|
||||
|
||||
#include "internal.h"
|
||||
#include "util.h"
|
||||
#include "win.h"
|
||||
|
||||
WEPOLL_INTERNAL int ws_global_init(void);
|
||||
|
||||
WEPOLL_INTERNAL SOCKET ws_get_base_socket(SOCKET socket);
|
||||
WEPOLL_INTERNAL ssize_t ws_get_protocol_catalog(WSAPROTOCOL_INFOW** infos_out);
|
||||
|
||||
#endif /* WEPOLL_WS_H_ */
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user