netio.sys: Fill socket addresses when accepting connection.

Signed-off-by: Paul Gofman <pgofman@codeweavers.com>
Signed-off-by: Alexandre Julliard <julliard@winehq.org>
This commit is contained in:
Paul Gofman 2020-06-24 15:09:02 +03:00 committed by Alexandre Julliard
parent aec5ccfc06
commit 290e1a11b3
2 changed files with 41 additions and 12 deletions

View File

@ -46,6 +46,7 @@ struct _WSK_CLIENT
struct listen_socket_callback_context
{
SOCKADDR *local_address;
SOCKADDR *remote_address;
const void *client_dispatch;
void *client_context;
@ -53,15 +54,6 @@ struct listen_socket_callback_context
SOCKET acceptor;
};
struct connect_socket_callback_context
{
struct wsk_socket_internal *socket;
SOCKADDR *remote_address;
const void *client_dispatch;
void *client_context;
IRP *pending_irp;
};
#define MAX_PENDING_IO 10
struct wsk_pending_io
@ -96,6 +88,7 @@ struct wsk_socket_internal
};
static LPFN_ACCEPTEX pAcceptEx;
static LPFN_GETACCEPTEXSOCKADDRS pGetAcceptExSockaddrs;
static LPFN_CONNECTEX pConnectEx;
static const WSK_PROVIDER_CONNECTION_DISPATCH wsk_provider_connection_dispatch;
@ -319,6 +312,8 @@ static void create_accept_socket(struct wsk_socket_internal *socket, struct wsk_
{
struct listen_socket_callback_context *context
= &socket->callback_context.listen_socket_callback_context;
INT local_address_len, remote_address_len;
SOCKADDR *local_address, *remote_address;
struct wsk_socket_internal *accept_socket;
if (!(accept_socket = heap_alloc_zero(sizeof(*accept_socket))))
@ -338,7 +333,17 @@ static void create_accept_socket(struct wsk_socket_internal *socket, struct wsk_
accept_socket->protocol = socket->protocol;
accept_socket->flags = WSK_FLAG_CONNECTION_SOCKET;
socket_init(accept_socket);
/* TODO: fill local and remote addresses. */
pGetAcceptExSockaddrs(context->addr_buffer, 0, sizeof(SOCKADDR) + 16, sizeof(SOCKADDR) + 16,
&local_address, &local_address_len, &remote_address, &remote_address_len);
if (context->local_address)
memcpy(context->local_address, local_address,
min(sizeof(*context->local_address), local_address_len));
if (context->remote_address)
memcpy(context->remote_address, remote_address,
min(sizeof(*context->remote_address), remote_address_len));
dispatch_pending_io(io, STATUS_SUCCESS, (ULONG_PTR)&accept_socket->wsk_socket);
}
@ -373,6 +378,7 @@ static void WINAPI accept_callback(TP_CALLBACK_INSTANCE *instance, void *socket_
static BOOL WINAPI init_accept_functions(INIT_ONCE *once, void *param, void **context)
{
GUID get_acceptex_guid = WSAID_GETACCEPTEXSOCKADDRS;
GUID acceptex_guid = WSAID_ACCEPTEX;
SOCKET s = (SOCKET)param;
DWORD size;
@ -383,6 +389,14 @@ static BOOL WINAPI init_accept_functions(INIT_ONCE *once, void *param, void **co
ERR("Could not get AcceptEx address, error %u.\n", WSAGetLastError());
return FALSE;
}
if (WSAIoctl(s, SIO_GET_EXTENSION_FUNCTION_POINTER, &get_acceptex_guid, sizeof(get_acceptex_guid),
&pGetAcceptExSockaddrs, sizeof(pGetAcceptExSockaddrs), &size, NULL, NULL))
{
ERR("Could not get AcceptEx address, error %u.\n", WSAGetLastError());
return FALSE;
}
return TRUE;
}
@ -430,6 +444,7 @@ static NTSTATUS WINAPI wsk_accept(WSK_SOCKET *listen_socket, ULONG flags, void *
return STATUS_PENDING;
}
context->local_address = local_address;
context->remote_address = remote_address;
context->client_dispatch = accept_socket_dispatch;
context->client_context = accept_socket_context;

View File

@ -177,10 +177,10 @@ static void test_wsk_listen_socket(void)
static const WSK_CLIENT_LISTEN_DISPATCH client_listen_dispatch;
const WSK_PROVIDER_CONNECTION_DISPATCH *accept_dispatch;
WSK_SOCKET *tcp_socket, *udp_socket, *accept_socket;
struct sockaddr_in addr, local_addr, remote_addr;
struct socket_context context;
WSK_BUF wsk_buf1, wsk_buf2;
void *buffer1, *buffer2;
struct sockaddr_in addr;
LARGE_INTEGER timeout;
MDL *mdl1, *mdl2;
NTSTATUS status;
@ -287,7 +287,10 @@ static void test_wsk_listen_socket(void)
IoReuseIrp(wsk_irp, STATUS_UNSUCCESSFUL);
IoSetCompletionRoutine(wsk_irp, irp_completion_routine, &irp_complete_event, TRUE, TRUE, TRUE);
status = tcp_dispatch->WskAccept(tcp_socket, 0, NULL, NULL, NULL, NULL, wsk_irp);
memset(&local_addr, 0, sizeof(local_addr));
memset(&remote_addr, 0, sizeof(remote_addr));
status = tcp_dispatch->WskAccept(tcp_socket, 0, NULL, NULL,
(SOCKADDR *)&local_addr, (SOCKADDR *)&remote_addr, wsk_irp);
ok(status == STATUS_PENDING, "Got unexpected status %#x.\n", status);
if (0)
@ -306,6 +309,17 @@ static void test_wsk_listen_socket(void)
if (status == STATUS_SUCCESS && wsk_irp->IoStatus.Status == STATUS_SUCCESS)
{
ok(local_addr.sin_family == AF_INET, "Got unexpected sin_family %u.\n", local_addr.sin_family);
ok(local_addr.sin_port == htons(SERVER_LISTEN_PORT), "Got unexpected sin_port %u.\n",
ntohs(local_addr.sin_port));
ok(local_addr.sin_addr.s_addr == htonl(0x7f000001), "Got unexpected sin_addr %#x.\n",
ntohl(local_addr.sin_addr.s_addr));
ok(remote_addr.sin_family == AF_INET, "Got unexpected sin_family %u.\n", remote_addr.sin_family);
ok(remote_addr.sin_port, "Got zero sin_port.\n");
ok(remote_addr.sin_addr.s_addr == htonl(0x7f000001), "Got unexpected sin_addr %#x.\n",
ntohl(remote_addr.sin_addr.s_addr));
accept_socket = (WSK_SOCKET *)wsk_irp->IoStatus.Information;
accept_dispatch = accept_socket->Dispatch;