wepoll/src/afd.c
2017-11-21 15:04:37 +01:00

171 lines
4.7 KiB
C

#include "afd.h"
#include "error.h"
#include "nt.h"
#include "util.h"
#include "win.h"
/* Building blocks for the AFD device control code used by afd_poll(). */
#define FILE_DEVICE_NETWORK 0x00000012
#define METHOD_BUFFERED 0
#define AFD_POLL 9

/* Packs a control code: the device type is shifted left by 12 bits, the
 * operation number by 2 bits, and the transfer method occupies the low bits.
 * Both arguments are parenthesized so that expression arguments (e.g.
 * `a | b` or `c ? x : y`) are evaluated as a unit before being combined. */
#define _AFD_CONTROL_CODE(operation, method) \
  (((FILE_DEVICE_NETWORK) << 12) | ((operation) << 2) | (method))

/* Control code passed to NtDeviceIoControlFile() to issue an AFD poll. */
#define IOCTL_AFD_POLL _AFD_CONTROL_CODE(AFD_POLL, METHOD_BUFFERED)

/* WSAIoctl() control code used to look up a socket's base handle. Only
 * defined here when the included headers have not already provided it. */
#ifndef SIO_BASE_HANDLE
#define SIO_BASE_HANDLE 0x48000022
#endif
/* Issues an IOCTL_AFD_POLL request on `driver_socket`, with `poll_info`
 * serving as both the input and the output buffer.
 *
 * - When `overlapped` is NULL the call is blocking: a temporary event is
 *   created, waited on if the request comes back as STATUS_PENDING, and
 *   closed again before returning.
 * - When `overlapped` is non-NULL the request is issued with the OVERLAPPED's
 *   `Internal` field acting as the I/O status block and its `hEvent` as the
 *   event; this function does not wait for completion.
 *
 * Returns 0 if the resulting status is STATUS_SUCCESS. Otherwise returns -1
 * through return_error(): with ERROR_IO_PENDING if the status is still
 * STATUS_PENDING, or with the RtlNtStatusToDosError() translation of any
 * other status. */
int afd_poll(SOCKET driver_socket,
             AFD_POLL_INFO* poll_info,
             OVERLAPPED* overlapped) {
  const int is_blocking = (overlapped == NULL);
  IO_STATUS_BLOCK stack_iosb;
  IO_STATUS_BLOCK* io_status;
  HANDLE wait_event = NULL;
  void* completion_context;
  NTSTATUS nt_status;

  if (is_blocking) {
    /* Blocking operation: private status block and a private event. */
    io_status = &stack_iosb;
    wait_event = CreateEventW(NULL, FALSE, FALSE, NULL);
    if (wait_event == NULL)
      return_error(-1);
    completion_context = NULL;
  } else {
    /* Overlapped operation: the OVERLAPPED doubles as the status block. */
    uintptr_t raw_event = (uintptr_t) overlapped->hEvent;

    io_status = (IO_STATUS_BLOCK*) &overlapped->Internal;

    /* A set low bit in hEvent is a tag meaning "do not report an iocp
     * completion": in that case the APC context is left NULL. The tag bit is
     * stripped before the handle is used; clearing it is a no-op when the
     * handle was not tagged. */
    completion_context = (raw_event & 1) ? NULL : overlapped;
    wait_event = (HANDLE)(raw_event & ~(uintptr_t) 1);
  }

  io_status->Status = STATUS_PENDING;
  nt_status = NtDeviceIoControlFile((HANDLE) driver_socket,
                                    wait_event,
                                    NULL,
                                    completion_context,
                                    io_status,
                                    IOCTL_AFD_POLL,
                                    poll_info,
                                    sizeof *poll_info,
                                    poll_info,
                                    sizeof *poll_info);

  if (is_blocking) {
    /* For a blocking call, wait until the event is signaled and then pick up
     * the real status from the io status block. */
    if (nt_status == STATUS_PENDING) {
      if (WaitForSingleObject(wait_event, INFINITE) == WAIT_FAILED) {
        /* Capture the error first; CloseHandle() runs before it is reported. */
        DWORD wait_error = GetLastError();
        CloseHandle(wait_event);
        return_error(-1, wait_error);
      }
      nt_status = io_status->Status;
    }
    CloseHandle(wait_event);
  }

  if (nt_status == STATUS_SUCCESS)
    return 0;

  if (nt_status == STATUS_PENDING)
    return_error(-1, ERROR_IO_PENDING);

  return_error(-1, RtlNtStatusToDosError(nt_status));
}
/* Queries `socket` with the SIO_BASE_HANDLE ioctl and returns the socket
 * handle that WSAIoctl() wrote to the output buffer. If WSAIoctl() reports
 * SOCKET_ERROR, returns INVALID_SOCKET through return_error(). */
static SOCKET _afd_get_base_socket(SOCKET socket) {
  SOCKET result;
  DWORD bytes_returned;
  int rc;

  rc = WSAIoctl(socket,
                SIO_BASE_HANDLE,
                NULL,            /* no input buffer */
                0,
                &result,
                sizeof result,
                &bytes_returned,
                NULL,            /* not overlapped */
                NULL);           /* no completion routine */
  if (rc == SOCKET_ERROR)
    return_error(INVALID_SOCKET);

  return result;
}
/* Fills `*protocol_info` with the SO_PROTOCOL_INFOW socket option of `socket`
 * and looks up its ProviderId in AFD_PROVIDER_GUID_LIST.
 *
 * Returns the index of the matching list entry. Returns -1 through
 * return_error() if getsockopt() fails, or with ERROR_NOT_SUPPORTED if the
 * provider id is not in the list (i.e. the protocol does not use an msafd
 * socket). */
static ssize_t _afd_get_protocol_info(SOCKET socket,
                                      WSAPROTOCOL_INFOW* protocol_info) {
  int info_size = sizeof *protocol_info;
  size_t index;

  if (getsockopt(socket,
                 SOL_SOCKET,
                 SO_PROTOCOL_INFOW,
                 (char*) protocol_info,
                 &info_size) != 0)
    return_error(-1);

  /* Byte-wise comparison of the provider GUID against every known entry;
   * the first hit wins. */
  for (index = 0; index < array_count(AFD_PROVIDER_GUID_LIST); index++) {
    if (memcmp(&protocol_info->ProviderId,
               &AFD_PROVIDER_GUID_LIST[index],
               sizeof protocol_info->ProviderId) == 0)
      return (ssize_t) index;
  }

  /* No entry matched: the protocol does not use an msafd socket. */
  return_error(-1, ERROR_NOT_SUPPORTED);
}
/* Determines the protocol id for `socket` and the socket handle for which
 * that lookup succeeded.
 *
 * On success, stores the socket that yielded the protocol information (either
 * `socket` itself or its base socket) in `*afd_socket_out`, leaves the
 * protocol information in `*protocol_info`, and returns the protocol id
 * (>= 0). On failure returns -1 through return_error(); `*afd_socket_out` is
 * not written in that case. */
WEPOLL_INTERNAL ssize_t afd_get_protocol(SOCKET socket,
                                         SOCKET* afd_socket_out,
                                         WSAPROTOCOL_INFOW* protocol_info) {
  ssize_t protocol_id;
  SOCKET base_socket;
  DWORD first_error;

  /* Fast path: assume the given socket already is an AFD socket. This should
   * almost always be true, and it avoids a call to WSAIoctl(). */
  protocol_id = _afd_get_protocol_info(socket, protocol_info);
  if (protocol_id >= 0) {
    *afd_socket_out = socket;
    return protocol_id;
  }

  /* The lookup failed. Only ERROR_NOT_SUPPORTED suggests that the socket may
   * not be an AFD socket; any other failure is passed on without retrying
   * (return_error() is given no explicit error code here). */
  first_error = GetLastError();
  if (first_error != ERROR_NOT_SUPPORTED)
    return_error(-1);

  /* Fall back to the underlying base socket and retry the lookup. Whenever
   * the fallback does not work out, report the *original* error rather than
   * whatever the fallback attempt produced. */
  base_socket = _afd_get_base_socket(socket);
  if (base_socket == INVALID_SOCKET || base_socket == socket)
    return_error(-1, first_error);

  protocol_id = _afd_get_protocol_info(base_socket, protocol_info);
  if (protocol_id < 0)
    return_error(-1, first_error);

  *afd_socket_out = base_socket;
  return protocol_id;
}