diff --git a/dlls/netio.sys/netio.c b/dlls/netio.sys/netio.c index 09f09f1e988..b0ca6801799 100644 --- a/dlls/netio.sys/netio.c +++ b/dlls/netio.sys/netio.c @@ -46,6 +46,7 @@ struct _WSK_CLIENT struct listen_socket_callback_context { + SOCKADDR *local_address; SOCKADDR *remote_address; const void *client_dispatch; void *client_context; @@ -53,15 +54,6 @@ struct listen_socket_callback_context SOCKET acceptor; }; -struct connect_socket_callback_context -{ - struct wsk_socket_internal *socket; - SOCKADDR *remote_address; - const void *client_dispatch; - void *client_context; - IRP *pending_irp; -}; - #define MAX_PENDING_IO 10 struct wsk_pending_io @@ -96,6 +88,7 @@ struct wsk_socket_internal }; static LPFN_ACCEPTEX pAcceptEx; +static LPFN_GETACCEPTEXSOCKADDRS pGetAcceptExSockaddrs; static LPFN_CONNECTEX pConnectEx; static const WSK_PROVIDER_CONNECTION_DISPATCH wsk_provider_connection_dispatch; @@ -319,6 +312,8 @@ static void create_accept_socket(struct wsk_socket_internal *socket, struct wsk_ { struct listen_socket_callback_context *context = &socket->callback_context.listen_socket_callback_context; + INT local_address_len, remote_address_len; + SOCKADDR *local_address, *remote_address; struct wsk_socket_internal *accept_socket; if (!(accept_socket = heap_alloc_zero(sizeof(*accept_socket)))) @@ -338,7 +333,17 @@ static void create_accept_socket(struct wsk_socket_internal *socket, struct wsk_ accept_socket->protocol = socket->protocol; accept_socket->flags = WSK_FLAG_CONNECTION_SOCKET; socket_init(accept_socket); - /* TODO: fill local and remote addresses. */ + + pGetAcceptExSockaddrs(context->addr_buffer, 0, sizeof(SOCKADDR) + 16, sizeof(SOCKADDR) + 16, + &local_address, &local_address_len, &remote_address, &remote_address_len); + + if (context->local_address) + memcpy(context->local_address, local_address, + min(sizeof(*context->local_address), local_address_len)); + + if (context->remote_address) + memcpy(context->remote_address, remote_address, + min(sizeof(*context->remote_address), remote_address_len)); dispatch_pending_io(io, STATUS_SUCCESS, (ULONG_PTR)&accept_socket->wsk_socket); } @@ -373,6 +378,7 @@ static void WINAPI accept_callback(TP_CALLBACK_INSTANCE *instance, void *socket_ static BOOL WINAPI init_accept_functions(INIT_ONCE *once, void *param, void **context) { + GUID get_acceptex_guid = WSAID_GETACCEPTEXSOCKADDRS; GUID acceptex_guid = WSAID_ACCEPTEX; SOCKET s = (SOCKET)param; DWORD size; @@ -383,6 +389,14 @@ static BOOL WINAPI init_accept_functions(INIT_ONCE *once, void *param, void **co ERR("Could not get AcceptEx address, error %u.\n", WSAGetLastError()); return FALSE; } + + if (WSAIoctl(s, SIO_GET_EXTENSION_FUNCTION_POINTER, &get_acceptex_guid, sizeof(get_acceptex_guid), + &pGetAcceptExSockaddrs, sizeof(pGetAcceptExSockaddrs), &size, NULL, NULL)) + { + ERR("Could not get AcceptEx address, error %u.\n", WSAGetLastError()); + return FALSE; + } + return TRUE; } @@ -430,6 +444,7 @@ static NTSTATUS WINAPI wsk_accept(WSK_SOCKET *listen_socket, ULONG flags, void * return STATUS_PENDING; } + context->local_address = local_address; context->remote_address = remote_address; context->client_dispatch = accept_socket_dispatch; context->client_context = accept_socket_context; diff --git a/dlls/ntoskrnl.exe/tests/driver4.c b/dlls/ntoskrnl.exe/tests/driver4.c index dad1d6a04fd..48de151f31f 100644 --- a/dlls/ntoskrnl.exe/tests/driver4.c +++ b/dlls/ntoskrnl.exe/tests/driver4.c @@ -177,10 +177,10 @@ static void test_wsk_listen_socket(void) static const WSK_CLIENT_LISTEN_DISPATCH client_listen_dispatch; const WSK_PROVIDER_CONNECTION_DISPATCH *accept_dispatch; WSK_SOCKET *tcp_socket, *udp_socket, *accept_socket; + struct sockaddr_in addr, local_addr, remote_addr; struct socket_context context; WSK_BUF wsk_buf1, wsk_buf2; void *buffer1, *buffer2; - struct sockaddr_in addr; LARGE_INTEGER timeout; MDL *mdl1, *mdl2; NTSTATUS status; @@ -287,7 +287,10 @@ static void test_wsk_listen_socket(void) IoReuseIrp(wsk_irp, STATUS_UNSUCCESSFUL); IoSetCompletionRoutine(wsk_irp, irp_completion_routine, &irp_complete_event, TRUE, TRUE, TRUE); - status = tcp_dispatch->WskAccept(tcp_socket, 0, NULL, NULL, NULL, NULL, wsk_irp); + memset(&local_addr, 0, sizeof(local_addr)); + memset(&remote_addr, 0, sizeof(remote_addr)); + status = tcp_dispatch->WskAccept(tcp_socket, 0, NULL, NULL, + (SOCKADDR *)&local_addr, (SOCKADDR *)&remote_addr, wsk_irp); ok(status == STATUS_PENDING, "Got unexpected status %#x.\n", status); if (0) @@ -306,6 +309,17 @@ static void test_wsk_listen_socket(void) if (status == STATUS_SUCCESS && wsk_irp->IoStatus.Status == STATUS_SUCCESS) { + ok(local_addr.sin_family == AF_INET, "Got unexpected sin_family %u.\n", local_addr.sin_family); + ok(local_addr.sin_port == htons(SERVER_LISTEN_PORT), "Got unexpected sin_port %u.\n", + ntohs(local_addr.sin_port)); + ok(local_addr.sin_addr.s_addr == htonl(0x7f000001), "Got unexpected sin_addr %#x.\n", + ntohl(local_addr.sin_addr.s_addr)); + + ok(remote_addr.sin_family == AF_INET, "Got unexpected sin_family %u.\n", remote_addr.sin_family); + ok(remote_addr.sin_port, "Got zero sin_port.\n"); + ok(remote_addr.sin_addr.s_addr == htonl(0x7f000001), "Got unexpected sin_addr %#x.\n", + ntohl(remote_addr.sin_addr.s_addr)); + accept_socket = (WSK_SOCKET *)wsk_irp->IoStatus.Information; accept_dispatch = accept_socket->Dispatch;