afd: retrieve protocol info for afd driver socket on startup

This commit is contained in:
Bert Belder 2018-05-02 02:16:09 +02:00
parent c69f361564
commit 2789bad793
No known key found for this signature in database
GPG Key ID: 7A77887B2E2ED461
10 changed files with 152 additions and 47 deletions

100
src/afd.c
View File

@ -1,3 +1,6 @@
#include <malloc.h>
#include <stdlib.h>
#include "afd.h"
#include "error.h"
#include "nt.h"
@ -14,6 +17,8 @@
#define IOCTL_AFD_POLL _AFD_CONTROL_CODE(AFD_POLL, METHOD_BUFFERED)
#define _AFD_ANY_PROTOCOL -1
/* clang-format off */
static const GUID _AFD_PROVIDER_GUID_LIST[] = {
/* MSAFD Tcpip [TCP+UDP+RAW / IP] */
@ -30,6 +35,101 @@ static const GUID _AFD_PROVIDER_GUID_LIST[] = {
{0xb6, 0x55, 0x00, 0x80, 0x5f, 0x36, 0x42, 0xcc}}};
/* clang-format on */
/* This protocol info record is used by afd_create_driver_socket() to create
* sockets that can be used as the first argument to afd_poll(). It is
* populated on startup by afd_global_init().
*/
static WSAPROTOCOL_INFOW _afd_driver_socket_template;
static const WSAPROTOCOL_INFOW* _afd_find_protocol_info(
const WSAPROTOCOL_INFOW* infos, size_t infos_count, int protocol_id) {
size_t i, j;
for (i = 0; i < infos_count; i++) {
const WSAPROTOCOL_INFOW* info = &infos[i];
/* Apply protocol id filter. */
if (protocol_id != _AFD_ANY_PROTOCOL && protocol_id != info->iProtocol)
continue;
/* Filter out non-MSAFD protocols. */
for (j = 0; j < array_count(_AFD_PROVIDER_GUID_LIST); j++) {
if (memcmp(&info->ProviderId,
&_AFD_PROVIDER_GUID_LIST[j],
sizeof info->ProviderId) == 0)
return info;
}
}
return NULL; /* Not found. */
}
int afd_global_init(void) {
WSAPROTOCOL_INFOW* infos;
ssize_t infos_count;
const WSAPROTOCOL_INFOW* afd_info;
/* Load the winsock catalog. */
infos_count = ws_get_protocol_catalog(&infos);
if (infos_count < 0)
return_error(-1);
/* Find a WSAPROTOCOL_INDOW structure that we can use to create an MSAFD
* socket. Preferentially we pick a UDP socket, otherwise try TCP or any
* other type.
*/
do {
afd_info = _afd_find_protocol_info(infos, infos_count, IPPROTO_UDP);
if (afd_info != NULL)
break;
afd_info = _afd_find_protocol_info(infos, infos_count, IPPROTO_TCP);
if (afd_info != NULL)
break;
afd_info = _afd_find_protocol_info(infos, infos_count, _AFD_ANY_PROTOCOL);
if (afd_info != NULL)
break;
free(infos);
return_error(-1, WSAENETDOWN); /* No suitable protocol found. */
} while (0);
/* Copy found protocol information from the catalog to a static buffer. */
_afd_driver_socket_template = *afd_info;
free(infos);
return 0;
}
int afd_create_driver_socket(HANDLE iocp, SOCKET* driver_socket_out) {
SOCKET socket;
socket = WSASocketW(_afd_driver_socket_template.iAddressFamily,
_afd_driver_socket_template.iSocketType,
_afd_driver_socket_template.iProtocol,
&_afd_driver_socket_template,
0,
WSA_FLAG_OVERLAPPED);
if (socket == INVALID_SOCKET)
return_error(-1);
/* TODO: use WSA_FLAG_NOINHERIT on Windows versions that support it. */
if (!SetHandleInformation((HANDLE) socket, HANDLE_FLAG_INHERIT, 0))
goto error;
if (CreateIoCompletionPort((HANDLE) socket, iocp, 0, 0) == NULL)
goto error;
*driver_socket_out = socket;
return 0;
error:;
DWORD error = GetLastError();
closesocket(socket);
return_error(-1, error);
}
int afd_poll(SOCKET driver_socket,
AFD_POLL_INFO* poll_info,
OVERLAPPED* overlapped) {

View File

@ -52,6 +52,11 @@ typedef struct _AFD_POLL_INFO {
AFD_POLL_HANDLE_INFO Handles[1];
} AFD_POLL_INFO, *PAFD_POLL_INFO;
WEPOLL_INTERNAL int afd_global_init(void);
WEPOLL_INTERNAL int afd_create_driver_socket(HANDLE iocp,
SOCKET* driver_socket_out);
WEPOLL_INTERNAL int afd_poll(SOCKET driver_socket,
AFD_POLL_INFO* poll_info,
OVERLAPPED* overlapped);

View File

@ -1,5 +1,6 @@
#include <stdbool.h>
#include "afd.h"
#include "api.h"
#include "init.h"
#include "nt.h"
@ -18,7 +19,7 @@ static BOOL CALLBACK _init_once_callback(INIT_ONCE* once,
unused_var(context);
/* N.b. that initialization order matters here. */
if (ws_global_init() < 0 || nt_global_init() < 0 ||
if (ws_global_init() < 0 || nt_global_init() < 0 || afd_global_init() < 0 ||
reflock_global_init() < 0 || api_global_init() < 0)
return FALSE;

View File

@ -12,7 +12,6 @@ static const size_t _POLL_GROUP_MAX_SIZE = 32;
typedef struct poll_group_allocator {
ep_port_t* port_info;
queue_t poll_group_queue;
WSAPROTOCOL_INFOW protocol_info;
} poll_group_allocator_t;
typedef struct poll_group {
@ -22,35 +21,6 @@ typedef struct poll_group {
size_t group_size;
} poll_group_t;
static int _poll_group_create_socket(poll_group_t* poll_group,
WSAPROTOCOL_INFOW* protocol_info,
HANDLE iocp) {
SOCKET socket;
socket = WSASocketW(protocol_info->iAddressFamily,
protocol_info->iSocketType,
protocol_info->iProtocol,
protocol_info,
0,
WSA_FLAG_OVERLAPPED);
if (socket == INVALID_SOCKET)
return_error(-1);
if (!SetHandleInformation((HANDLE) socket, HANDLE_FLAG_INHERIT, 0))
goto error;
if (CreateIoCompletionPort((HANDLE) socket, iocp, 0, 0) == NULL)
goto error;
poll_group->socket = socket;
return 0;
error:;
DWORD error = GetLastError();
closesocket(socket);
return_error(-1, error);
}
static poll_group_t* _poll_group_new(poll_group_allocator_t* pga) {
poll_group_t* poll_group = malloc(sizeof *poll_group);
if (poll_group == NULL)
@ -61,8 +31,8 @@ static poll_group_t* _poll_group_new(poll_group_allocator_t* pga) {
queue_node_init(&poll_group->queue_node);
poll_group->allocator = pga;
if (_poll_group_create_socket(
poll_group, &pga->protocol_info, pga->port_info->iocp) < 0) {
if (afd_create_driver_socket(pga->port_info->iocp, &poll_group->socket) <
0) {
free(poll_group);
return NULL;
}
@ -83,15 +53,13 @@ SOCKET poll_group_get_socket(poll_group_t* poll_group) {
return poll_group->socket;
}
poll_group_allocator_t* poll_group_allocator_new(
ep_port_t* port_info, const WSAPROTOCOL_INFOW* protocol_info) {
poll_group_allocator_t* poll_group_allocator_new(ep_port_t* port_info) {
poll_group_allocator_t* pga = malloc(sizeof *pga);
if (pga == NULL)
return_error(NULL, ERROR_NOT_ENOUGH_MEMORY);
queue_init(&pga->poll_group_queue);
pga->port_info = port_info;
pga->protocol_info = *protocol_info;
return pga;
}

View File

@ -11,7 +11,7 @@ typedef struct poll_group_allocator poll_group_allocator_t;
typedef struct poll_group poll_group_t;
WEPOLL_INTERNAL poll_group_allocator_t* poll_group_allocator_new(
ep_port_t* port_info, const WSAPROTOCOL_INFOW* protocol_info);
ep_port_t* port_info);
WEPOLL_INTERNAL void poll_group_allocator_delete(poll_group_allocator_t* pga);
WEPOLL_INTERNAL poll_group_t* poll_group_acquire(poll_group_allocator_t* pga);

View File

@ -357,19 +357,16 @@ ep_sock_t* ep_port_find_socket(ep_port_t* port_info, SOCKET socket) {
}
static poll_group_allocator_t* _ep_port_get_poll_group_allocator(
ep_port_t* port_info, const WSAPROTOCOL_INFOW* protocol_info) {
ep_port_t* port_info) {
if (port_info->poll_group_allocator == NULL) {
port_info->poll_group_allocator =
poll_group_allocator_new(port_info, protocol_info);
port_info->poll_group_allocator = poll_group_allocator_new(port_info);
}
return port_info->poll_group_allocator;
}
poll_group_t* ep_port_acquire_poll_group(
ep_port_t* port_info, const WSAPROTOCOL_INFOW* protocol_info) {
poll_group_allocator_t* pga =
_ep_port_get_poll_group_allocator(port_info, protocol_info);
poll_group_t* ep_port_acquire_poll_group(ep_port_t* port_info) {
poll_group_allocator_t* pga = _ep_port_get_poll_group_allocator(port_info);
return poll_group_acquire(pga);
}

View File

@ -39,8 +39,7 @@ WEPOLL_INTERNAL int ep_port_ctl(ep_port_t* port_info,
SOCKET sock,
struct epoll_event* ev);
WEPOLL_INTERNAL poll_group_t* ep_port_acquire_poll_group(
ep_port_t* port_info, const WSAPROTOCOL_INFOW* protocol_info);
WEPOLL_INTERNAL poll_group_t* ep_port_acquire_poll_group(ep_port_t* port_info);
WEPOLL_INTERNAL void ep_port_release_poll_group(ep_port_t* port_info,
poll_group_t* poll_group);

View File

@ -183,7 +183,7 @@ ep_sock_t* ep_sock_new(ep_port_t* port_info, SOCKET socket) {
if (afd_get_protocol_info(socket, &afd_socket, &protocol_info) < 0)
return NULL;
poll_group = ep_port_acquire_poll_group(port_info, &protocol_info);
poll_group = ep_port_acquire_poll_group(port_info);
if (poll_group == NULL)
return NULL;

View File

@ -1,4 +1,8 @@
#include <malloc.h>
#include <stdlib.h>
#include "error.h"
#include "util.h"
#include "win.h"
#include "ws.h"
@ -6,6 +10,8 @@
#define SIO_BASE_HANDLE 0x48000022
#endif
#define _WS_INITIAL_CATALOG_BUFFER_SIZE 0x4000 /* 16kb. */
int ws_global_init(void) {
int r;
WSADATA wsa_data;
@ -34,3 +40,30 @@ SOCKET ws_get_base_socket(SOCKET socket) {
return base_socket;
}
/* Retrieves a copy of the winsock catalog.
* The infos pointer must be released by the caller with free().
*/
ssize_t ws_get_protocol_catalog(WSAPROTOCOL_INFOW** infos_out) {
DWORD buffer_size = _WS_INITIAL_CATALOG_BUFFER_SIZE;
int count;
WSAPROTOCOL_INFOW* infos;
for (;;) {
infos = malloc(buffer_size);
if (infos == NULL)
return_error(-1, ERROR_NOT_ENOUGH_MEMORY);
count = WSAEnumProtocolsW(NULL, infos, &buffer_size);
if (count == SOCKET_ERROR) {
free(infos);
if (WSAGetLastError() == WSAENOBUFS)
continue; /* Try again with bigger buffer size. */
else
return_error(-1);
}
*infos_out = infos;
return count;
}
}

View File

@ -2,10 +2,12 @@
#define WEPOLL_WS_H_
#include "internal.h"
#include "util.h"
#include "win.h"
WEPOLL_INTERNAL int ws_global_init(void);
WEPOLL_INTERNAL SOCKET ws_get_base_socket(SOCKET socket);
WEPOLL_INTERNAL ssize_t ws_get_protocol_catalog(WSAPROTOCOL_INFOW** infos_out);
#endif /* WEPOLL_WS_H_ */