netio.sys: Implement wsk_accept() function.
Signed-off-by: Paul Gofman <pgofman@codeweavers.com> Signed-off-by: Alexandre Julliard <julliard@winehq.org>
This commit is contained in:
parent
61cdff3c62
commit
2d7a304c45
|
@ -44,14 +44,41 @@ struct _WSK_CLIENT
|
|||
WSK_CLIENT_NPI *client_npi;
|
||||
};
|
||||
|
||||
struct listen_socket_callback_context
|
||||
{
|
||||
SOCKADDR *remote_address;
|
||||
const void *client_dispatch;
|
||||
void *client_context;
|
||||
char addr_buffer[2 * (sizeof(SOCKADDR) + 16)];
|
||||
SOCKET acceptor;
|
||||
};
|
||||
|
||||
struct wsk_socket_internal
|
||||
{
|
||||
WSK_SOCKET wsk_socket;
|
||||
SOCKET s;
|
||||
const void *client_dispatch;
|
||||
void *client_context;
|
||||
ULONG flags;
|
||||
ADDRESS_FAMILY address_family;
|
||||
USHORT socket_type;
|
||||
ULONG protocol;
|
||||
OVERLAPPED ovr;
|
||||
TP_WAIT *tp_wait;
|
||||
IRP *pending_irp;
|
||||
|
||||
CRITICAL_SECTION cs_socket;
|
||||
|
||||
union
|
||||
{
|
||||
struct listen_socket_callback_context listen_socket_callback_context;
|
||||
}
|
||||
callback_context;
|
||||
};
|
||||
|
||||
static LPFN_ACCEPTEX pAcceptEx;
|
||||
static const WSK_PROVIDER_CONNECTION_DISPATCH wsk_provider_connection_dispatch;
|
||||
|
||||
static inline struct wsk_socket_internal *wsk_socket_internal_from_wsk_socket(WSK_SOCKET *wsk_socket)
|
||||
{
|
||||
return CONTAINING_RECORD(wsk_socket, struct wsk_socket_internal, wsk_socket);
|
||||
|
@ -79,7 +106,7 @@ static NTSTATUS sock_error_to_ntstatus(DWORD err)
|
|||
case WSAEAFNOSUPPORT:
|
||||
case WSAEPROTOTYPE: return STATUS_NOT_SUPPORTED;
|
||||
case WSAENOPROTOOPT: return STATUS_INVALID_PARAMETER;
|
||||
case WSAEOPNOTSUPP: return STATUS_NOT_SUPPORTED;
|
||||
case WSAEOPNOTSUPP: return STATUS_NOT_IMPLEMENTED;
|
||||
case WSAEADDRINUSE: return STATUS_ADDRESS_ALREADY_ASSOCIATED;
|
||||
case WSAEADDRNOTAVAIL: return STATUS_INVALID_PARAMETER;
|
||||
case WSAECONNREFUSED: return STATUS_CONNECTION_REFUSED;
|
||||
|
@ -97,6 +124,26 @@ static NTSTATUS sock_error_to_ntstatus(DWORD err)
|
|||
}
|
||||
}
|
||||
|
||||
static inline void lock_socket(struct wsk_socket_internal *socket)
|
||||
{
|
||||
EnterCriticalSection(&socket->cs_socket);
|
||||
}
|
||||
|
||||
static inline void unlock_socket(struct wsk_socket_internal *socket)
|
||||
{
|
||||
LeaveCriticalSection(&socket->cs_socket);
|
||||
}
|
||||
|
||||
static void socket_init(struct wsk_socket_internal *socket, PTP_WAIT_CALLBACK socket_async_callback)
|
||||
{
|
||||
InitializeCriticalSection(&socket->cs_socket);
|
||||
if (socket_async_callback)
|
||||
{
|
||||
socket->ovr.hEvent = CreateEventA(NULL, FALSE, FALSE, NULL);
|
||||
socket->tp_wait = CreateThreadpoolWait(socket_async_callback, socket, NULL);
|
||||
}
|
||||
}
|
||||
|
||||
static void dispatch_irp(IRP *irp, NTSTATUS status)
|
||||
{
|
||||
irp->IoStatus.u.Status = status;
|
||||
|
@ -124,7 +171,33 @@ static NTSTATUS WINAPI wsk_close_socket(WSK_SOCKET *socket, IRP *irp)
|
|||
|
||||
TRACE("socket %p, irp %p.\n", socket, irp);
|
||||
|
||||
lock_socket(s);
|
||||
|
||||
if (s->tp_wait)
|
||||
{
|
||||
CancelIoEx((HANDLE)s->s, &s->ovr);
|
||||
unlock_socket(s);
|
||||
WaitForThreadpoolWaitCallbacks(s->tp_wait, FALSE);
|
||||
lock_socket(s);
|
||||
CloseThreadpoolWait(s->tp_wait);
|
||||
}
|
||||
|
||||
if (s->flags & WSK_FLAG_LISTEN_SOCKET && s->callback_context.listen_socket_callback_context.acceptor)
|
||||
closesocket(s->callback_context.listen_socket_callback_context.acceptor);
|
||||
|
||||
status = closesocket(s->s) ? sock_error_to_ntstatus(WSAGetLastError()) : STATUS_SUCCESS;
|
||||
|
||||
if (s->ovr.hEvent)
|
||||
CloseHandle(s->ovr.hEvent);
|
||||
|
||||
if (s->pending_irp)
|
||||
{
|
||||
s->pending_irp->IoStatus.Information = 0;
|
||||
dispatch_irp(s->pending_irp, STATUS_CANCELLED);
|
||||
}
|
||||
|
||||
unlock_socket(s);
|
||||
DeleteCriticalSection(&s->cs_socket);
|
||||
heap_free(socket);
|
||||
|
||||
irp->IoStatus.Information = 0;
|
||||
|
@ -146,6 +219,8 @@ static NTSTATUS WINAPI wsk_bind(WSK_SOCKET *socket, SOCKADDR *local_address, ULO
|
|||
|
||||
if (bind(s->s, local_address, sizeof(*local_address)))
|
||||
status = sock_error_to_ntstatus(WSAGetLastError());
|
||||
else if (s->flags & WSK_FLAG_LISTEN_SOCKET && listen(s->s, SOMAXCONN))
|
||||
status = sock_error_to_ntstatus(WSAGetLastError());
|
||||
else
|
||||
status = STATUS_SUCCESS;
|
||||
|
||||
|
@ -155,16 +230,147 @@ static NTSTATUS WINAPI wsk_bind(WSK_SOCKET *socket, SOCKADDR *local_address, ULO
|
|||
return STATUS_PENDING;
|
||||
}
|
||||
|
||||
static void create_accept_socket(struct wsk_socket_internal *socket)
|
||||
{
|
||||
struct listen_socket_callback_context *context
|
||||
= &socket->callback_context.listen_socket_callback_context;
|
||||
struct wsk_socket_internal *accept_socket;
|
||||
NTSTATUS status;
|
||||
|
||||
if (!(accept_socket = heap_alloc_zero(sizeof(*accept_socket))))
|
||||
{
|
||||
ERR("No memory.\n");
|
||||
status = STATUS_NO_MEMORY;
|
||||
socket->pending_irp->IoStatus.Information = 0;
|
||||
}
|
||||
else
|
||||
{
|
||||
TRACE("accept_socket %p.\n", accept_socket);
|
||||
accept_socket->wsk_socket.Dispatch = &wsk_provider_connection_dispatch;
|
||||
accept_socket->s = context->acceptor;
|
||||
accept_socket->client_dispatch = context->client_dispatch;
|
||||
accept_socket->client_context = context->client_context;
|
||||
accept_socket->socket_type = socket->socket_type;
|
||||
accept_socket->address_family = socket->address_family;
|
||||
accept_socket->protocol = socket->protocol;
|
||||
accept_socket->flags = WSK_FLAG_CONNECTION_SOCKET;
|
||||
socket_init(accept_socket, NULL);
|
||||
/* TODO: fill local and remote addresses. */
|
||||
|
||||
socket->pending_irp->IoStatus.Information = (ULONG_PTR)&accept_socket->wsk_socket;
|
||||
status = STATUS_SUCCESS;
|
||||
}
|
||||
TRACE("status %#x.\n", status);
|
||||
dispatch_irp(socket->pending_irp, status);
|
||||
socket->pending_irp = NULL;
|
||||
}
|
||||
|
||||
static void WINAPI accept_callback(TP_CALLBACK_INSTANCE *instance, void *socket_, TP_WAIT *wait,
|
||||
TP_WAIT_RESULT wait_result)
|
||||
{
|
||||
struct listen_socket_callback_context *context;
|
||||
struct wsk_socket_internal *socket = socket_;
|
||||
DWORD size;
|
||||
|
||||
TRACE("instance %p, socket %p, wait %p, wait_result %#x.\n", instance, socket, wait, wait_result);
|
||||
|
||||
lock_socket(socket);
|
||||
context = &socket->callback_context.listen_socket_callback_context;
|
||||
|
||||
if (GetOverlappedResult((HANDLE)socket->s, &socket->ovr, &size, FALSE))
|
||||
{
|
||||
create_accept_socket(socket);
|
||||
}
|
||||
else
|
||||
{
|
||||
closesocket(context->acceptor);
|
||||
context->acceptor = 0;
|
||||
socket->pending_irp->IoStatus.Information = 0;
|
||||
dispatch_irp(socket->pending_irp, socket->ovr.Internal);
|
||||
socket->pending_irp = NULL;
|
||||
}
|
||||
unlock_socket(socket);
|
||||
}
|
||||
|
||||
static BOOL WINAPI init_accept_functions(INIT_ONCE *once, void *param, void **context)
|
||||
{
|
||||
GUID acceptex_guid = WSAID_ACCEPTEX;
|
||||
SOCKET s = (SOCKET)param;
|
||||
DWORD size;
|
||||
|
||||
if (WSAIoctl(s, SIO_GET_EXTENSION_FUNCTION_POINTER, &acceptex_guid, sizeof(acceptex_guid),
|
||||
&pAcceptEx, sizeof(pAcceptEx), &size, NULL, NULL))
|
||||
{
|
||||
ERR("Could not get AcceptEx address, error %u.\n", WSAGetLastError());
|
||||
return FALSE;
|
||||
}
|
||||
return TRUE;
|
||||
}
|
||||
|
||||
static NTSTATUS WINAPI wsk_accept(WSK_SOCKET *listen_socket, ULONG flags, void *accept_socket_context,
|
||||
const WSK_CLIENT_CONNECTION_DISPATCH *accept_socket_dispatch, SOCKADDR *local_address,
|
||||
SOCKADDR *remote_address, IRP *irp)
|
||||
{
|
||||
FIXME("listen_socket %p, flags %#x, accept_socket_context %p, accept_socket_dispatch %p, "
|
||||
"local_address %p, remote_address %p, irp %p stub.\n",
|
||||
struct wsk_socket_internal *s = wsk_socket_internal_from_wsk_socket(listen_socket);
|
||||
static INIT_ONCE init_once = INIT_ONCE_STATIC_INIT;
|
||||
struct listen_socket_callback_context *context;
|
||||
SOCKET acceptor;
|
||||
NTSTATUS status;
|
||||
DWORD size;
|
||||
int error;
|
||||
|
||||
TRACE("listen_socket %p, flags %#x, accept_socket_context %p, accept_socket_dispatch %p, "
|
||||
"local_address %p, remote_address %p, irp %p.\n",
|
||||
listen_socket, flags, accept_socket_context, accept_socket_dispatch, local_address,
|
||||
remote_address, irp);
|
||||
|
||||
return STATUS_NOT_IMPLEMENTED;
|
||||
if (!irp)
|
||||
return STATUS_INVALID_PARAMETER;
|
||||
|
||||
if (!InitOnceExecuteOnce(&init_once, init_accept_functions, (void *)s->s, NULL))
|
||||
{
|
||||
status = STATUS_UNSUCCESSFUL;
|
||||
dispatch_irp(irp, status);
|
||||
return status;
|
||||
}
|
||||
|
||||
lock_socket(s);
|
||||
context = &s->callback_context.listen_socket_callback_context;
|
||||
if ((acceptor = WSASocketW(s->address_family, s->socket_type, s->protocol, NULL, 0, WSA_FLAG_OVERLAPPED))
|
||||
== INVALID_SOCKET)
|
||||
{
|
||||
status = sock_error_to_ntstatus(WSAGetLastError());
|
||||
dispatch_irp(irp, status);
|
||||
unlock_socket(s);
|
||||
return status;
|
||||
}
|
||||
|
||||
s->pending_irp = irp;
|
||||
context->remote_address = remote_address;
|
||||
context->client_dispatch = accept_socket_dispatch;
|
||||
context->client_context = accept_socket_context;
|
||||
context->acceptor = acceptor;
|
||||
|
||||
if (pAcceptEx(s->s, acceptor, context->addr_buffer, 0,
|
||||
sizeof(SOCKADDR) + 16, sizeof(SOCKADDR) + 16, &size, &s->ovr))
|
||||
{
|
||||
create_accept_socket(s);
|
||||
}
|
||||
else if ((error = WSAGetLastError()) == ERROR_IO_PENDING)
|
||||
{
|
||||
SetThreadpoolWait(s->tp_wait, s->ovr.hEvent, NULL);
|
||||
}
|
||||
else
|
||||
{
|
||||
closesocket(acceptor);
|
||||
context->acceptor = 0;
|
||||
irp->IoStatus.Information = 0;
|
||||
dispatch_irp(irp, sock_error_to_ntstatus(error));
|
||||
s->pending_irp = NULL;
|
||||
}
|
||||
unlock_socket(s);
|
||||
|
||||
return STATUS_PENDING;
|
||||
}
|
||||
|
||||
static NTSTATUS WINAPI wsk_inspect_complete(WSK_SOCKET *listen_socket, WSK_INSPECT_ID *inspect_id,
|
||||
|
@ -284,6 +490,7 @@ static NTSTATUS WINAPI wsk_socket(WSK_CLIENT *client, ADDRESS_FAMILY address_fam
|
|||
PETHREAD owning_thread, SECURITY_DESCRIPTOR *security_descriptor, IRP *irp)
|
||||
{
|
||||
struct wsk_socket_internal *socket;
|
||||
PTP_WAIT_CALLBACK async_callback;
|
||||
NTSTATUS status;
|
||||
SOCKET s;
|
||||
|
||||
|
@ -300,13 +507,13 @@ static NTSTATUS WINAPI wsk_socket(WSK_CLIENT *client, ADDRESS_FAMILY address_fam
|
|||
|
||||
irp->IoStatus.Information = 0;
|
||||
|
||||
if ((s = WSASocketW(address_family, socket_type, protocol, NULL, 0, 0)) == INVALID_SOCKET)
|
||||
if ((s = WSASocketW(address_family, socket_type, protocol, NULL, 0, WSA_FLAG_OVERLAPPED)) == INVALID_SOCKET)
|
||||
{
|
||||
status = sock_error_to_ntstatus(WSAGetLastError());
|
||||
goto done;
|
||||
}
|
||||
|
||||
if (!(socket = heap_alloc(sizeof(*socket))))
|
||||
if (!(socket = heap_alloc_zero(sizeof(*socket))))
|
||||
{
|
||||
status = STATUS_NO_MEMORY;
|
||||
closesocket(s);
|
||||
|
@ -316,11 +523,16 @@ static NTSTATUS WINAPI wsk_socket(WSK_CLIENT *client, ADDRESS_FAMILY address_fam
|
|||
socket->s = s;
|
||||
socket->client_dispatch = dispatch;
|
||||
socket->client_context = socket_context;
|
||||
socket->socket_type = socket_type;
|
||||
socket->flags = flags;
|
||||
socket->address_family = address_family;
|
||||
socket->protocol = protocol;
|
||||
|
||||
switch (flags)
|
||||
{
|
||||
case WSK_FLAG_LISTEN_SOCKET:
|
||||
socket->wsk_socket.Dispatch = &wsk_provider_listen_dispatch;
|
||||
async_callback = accept_callback;
|
||||
break;
|
||||
|
||||
case WSK_FLAG_CONNECTION_SOCKET:
|
||||
|
@ -335,6 +547,8 @@ static NTSTATUS WINAPI wsk_socket(WSK_CLIENT *client, ADDRESS_FAMILY address_fam
|
|||
goto done;
|
||||
}
|
||||
|
||||
socket_init(socket, async_callback);
|
||||
|
||||
irp->IoStatus.Information = (ULONG_PTR)&socket->wsk_socket;
|
||||
status = STATUS_SUCCESS;
|
||||
|
||||
|
|
Loading…
Reference in New Issue