diff --git a/dlls/ws2_32/tests/afd.c b/dlls/ws2_32/tests/afd.c index 48b177ee845..97b35e1f3ea 100644 --- a/dlls/ws2_32/tests/afd.c +++ b/dlls/ws2_32/tests/afd.c @@ -31,6 +31,41 @@ #include "wine/afd.h" #include "wine/test.h" +static void tcp_socketpair(SOCKET *src, SOCKET *dst) +{ + SOCKET server = INVALID_SOCKET; + struct sockaddr_in addr; + int len, ret; + + *src = WSASocketW(AF_INET, SOCK_STREAM, IPPROTO_TCP, NULL, 0, WSA_FLAG_OVERLAPPED); + ok(*src != INVALID_SOCKET, "failed to create socket, error %u\n", WSAGetLastError()); + + server = WSASocketW(AF_INET, SOCK_STREAM, IPPROTO_TCP, NULL, 0, WSA_FLAG_OVERLAPPED); + ok(server != INVALID_SOCKET, "failed to create socket, error %u\n", WSAGetLastError()); + + memset(&addr, 0, sizeof(addr)); + addr.sin_family = AF_INET; + addr.sin_addr.s_addr = htonl(INADDR_LOOPBACK); + ret = bind(server, (struct sockaddr *)&addr, sizeof(addr)); + ok(!ret, "failed to bind socket, error %u\n", WSAGetLastError()); + + len = sizeof(addr); + ret = getsockname(server, (struct sockaddr *)&addr, &len); + ok(!ret, "failed to get address, error %u\n", WSAGetLastError()); + + ret = listen(server, 1); + ok(!ret, "failed to listen, error %u\n", WSAGetLastError()); + + ret = connect(*src, (struct sockaddr *)&addr, sizeof(addr)); + ok(!ret, "failed to connect, error %u\n", WSAGetLastError()); + + len = sizeof(addr); + *dst = accept(server, (struct sockaddr *)&addr, &len); + ok(*dst != INVALID_SOCKET, "failed to accept socket, error %u\n", WSAGetLastError()); + + closesocket(server); +} + static void set_blocking(SOCKET s, ULONG blocking) { int ret; @@ -76,6 +111,652 @@ static void test_open_device(void) closesocket(s); } +#define check_poll(a, b, c) check_poll_(__LINE__, a, b, c, FALSE) +#define check_poll_todo(a, b, c) check_poll_(__LINE__, a, b, c, TRUE) +static void check_poll_(int line, SOCKET s, HANDLE event, int expect, BOOL todo) +{ + struct afd_poll_params in_params = {0}, out_params = {0}; + IO_STATUS_BLOCK io; + NTSTATUS ret; + + in_params.timeout = -1000 * 10000; + in_params.count = 1; + in_params.sockets[0].socket = s; + in_params.sockets[0].flags = ~0; + in_params.sockets[0].status = 0xdeadbeef; + + ret = NtDeviceIoControlFile((HANDLE)s, event, NULL, NULL, &io, + IOCTL_AFD_POLL, &in_params, sizeof(in_params), &out_params, sizeof(out_params)); + ok_(__FILE__, line)(!ret, "got %#x\n", ret); + ok_(__FILE__, line)(!io.Status, "got %#x\n", io.Status); + ok_(__FILE__, line)(io.Information == sizeof(out_params), "got %#Ix\n", io.Information); + ok_(__FILE__, line)(out_params.timeout == in_params.timeout, "got timeout %I64d\n", out_params.timeout); + ok_(__FILE__, line)(out_params.count == 1, "got count %u\n", out_params.count); + ok_(__FILE__, line)(out_params.sockets[0].socket == s, "got socket %#Ix\n", out_params.sockets[0].socket); + todo_wine_if (todo) ok_(__FILE__, line)(out_params.sockets[0].flags == expect, "got flags %#x\n", out_params.sockets[0].flags); + ok_(__FILE__, line)(!out_params.sockets[0].status, "got status %#x\n", out_params.sockets[0].status); +} + +static void test_poll(void) +{ + const struct sockaddr_in bind_addr = {.sin_family = AF_INET, .sin_addr.s_addr = htonl(INADDR_LOOPBACK)}; + char in_buffer[offsetof(struct afd_poll_params, sockets[3])]; + char out_buffer[offsetof(struct afd_poll_params, sockets[3])]; + struct afd_poll_params *in_params = (struct afd_poll_params *)in_buffer; + struct afd_poll_params *out_params = (struct afd_poll_params *)out_buffer; + int large_buffer_size = 1024 * 1024; + SOCKET client, server, listener; + struct sockaddr_in addr; + char *large_buffer; + IO_STATUS_BLOCK io; + LARGE_INTEGER now; + ULONG params_size; + HANDLE event; + int ret, len; + + large_buffer = malloc(large_buffer_size); + memset(in_buffer, 0, sizeof(in_buffer)); + memset(out_buffer, 0, sizeof(out_buffer)); + event = CreateEventW(NULL, TRUE, FALSE, NULL); + + listener = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP); + ret = bind(listener, (const struct sockaddr *)&bind_addr, sizeof(bind_addr)); + ok(!ret, "got error %u\n", WSAGetLastError()); + ret = listen(listener, 1); + ok(!ret, "got error %u\n", WSAGetLastError()); + len = sizeof(addr); + ret = getsockname(listener, (struct sockaddr *)&addr, &len); + ok(!ret, "got error %u\n", WSAGetLastError()); + + params_size = offsetof(struct afd_poll_params, sockets[1]); + in_params->count = 1; + + /* out_size must be at least as large as in_size. */ + + ret = NtDeviceIoControlFile((HANDLE)listener, event, NULL, NULL, &io, + IOCTL_AFD_POLL, in_params, params_size, NULL, 0); + ok(ret == STATUS_INVALID_PARAMETER, "got %#x\n", ret); + + ret = NtDeviceIoControlFile((HANDLE)listener, event, NULL, NULL, &io, + IOCTL_AFD_POLL, NULL, 0, out_params, params_size); + ok(ret == STATUS_INVALID_PARAMETER, "got %#x\n", ret); + + ret = NtDeviceIoControlFile((HANDLE)listener, event, NULL, NULL, &io, + IOCTL_AFD_POLL, in_params, params_size, out_params, params_size + 1); + ok(ret == STATUS_INVALID_HANDLE, "got %#x\n", ret); + + ret = NtDeviceIoControlFile((HANDLE)listener, event, NULL, NULL, &io, + IOCTL_AFD_POLL, in_params, params_size + 1, out_params, params_size); + ok(ret == STATUS_INVALID_PARAMETER, "got %#x\n", ret); + + ret = NtDeviceIoControlFile((HANDLE)listener, event, NULL, NULL, &io, + IOCTL_AFD_POLL, in_params, params_size - 1, out_params, params_size - 1); + ok(ret == STATUS_INVALID_PARAMETER, "got %#x\n", ret); + + ret = NtDeviceIoControlFile((HANDLE)listener, event, NULL, NULL, &io, + IOCTL_AFD_POLL, in_params, params_size + 1, out_params, params_size + 1); + ok(ret == STATUS_INVALID_HANDLE, "got %#x\n", ret); + + in_params->count = 0; + ret = NtDeviceIoControlFile((HANDLE)listener, event, NULL, NULL, &io, + IOCTL_AFD_POLL, in_params, params_size, out_params, params_size); + ok(ret == STATUS_INVALID_PARAMETER, "got %#x\n", ret); + + /* Basic semantics of the ioctl. */ + + in_params->timeout = 0; + in_params->count = 1; + in_params->sockets[0].socket = listener; + in_params->sockets[0].flags = ~0; + in_params->sockets[0].status = 0xdeadbeef; + + memset(out_params, 0, params_size); + ret = NtDeviceIoControlFile((HANDLE)listener, event, NULL, NULL, &io, + IOCTL_AFD_POLL, in_params, params_size, out_params, params_size); + ok(!ret, "got %#x\n", ret); + ok(!io.Status, "got %#x\n", io.Status); + ok(io.Information == offsetof(struct afd_poll_params, sockets[0]), "got %#Ix\n", io.Information); + ok(!out_params->timeout, "got timeout %#I64x\n", out_params->timeout); + ok(!out_params->count, "got count %u\n", out_params->count); + ok(!out_params->sockets[0].socket, "got socket %#Ix\n", out_params->sockets[0].socket); + ok(!out_params->sockets[0].flags, "got flags %#x\n", out_params->sockets[0].flags); + ok(!out_params->sockets[0].status, "got status %#x\n", out_params->sockets[0].status); + + NtQuerySystemTime(&now); + in_params->timeout = now.QuadPart; + + ret = NtDeviceIoControlFile((HANDLE)listener, event, NULL, NULL, &io, + IOCTL_AFD_POLL, in_params, params_size, out_params, params_size); + ok(ret == STATUS_PENDING, "got %#x\n", ret); + ret = WaitForSingleObject(event, 100); + ok(!ret, "got %#x\n", ret); + ok(io.Status == STATUS_TIMEOUT, "got %#x\n", io.Status); + ok(io.Information == offsetof(struct afd_poll_params, sockets[0]), "got %#Ix\n", io.Information); + ok(out_params->timeout == now.QuadPart, "got timeout %#I64x\n", out_params->timeout); + ok(!out_params->count, "got count %u\n", out_params->count); + + in_params->timeout = -1000 * 10000; + + ret = NtDeviceIoControlFile((HANDLE)listener, event, NULL, NULL, &io, + IOCTL_AFD_POLL, in_params, params_size, out_params, params_size); + ok(ret == STATUS_PENDING, "got %#x\n", ret); + + client = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP); + set_blocking(client, FALSE); + ret = connect(client, (struct sockaddr *)&addr, sizeof(addr)); + ok(!ret || WSAGetLastError() == WSAEWOULDBLOCK, "got error %u\n", WSAGetLastError()); + + ret = WaitForSingleObject(event, 100); + ok(!ret, "got %#x\n", ret); + ok(!io.Status, "got %#x\n", io.Status); + ok(io.Information == offsetof(struct afd_poll_params, sockets[1]), "got %#Ix\n", io.Information); + ok(out_params->timeout == -1000 * 10000, "got timeout %#I64x\n", out_params->timeout); + ok(out_params->count == 1, "got count %u\n", out_params->count); + ok(out_params->sockets[0].socket == listener, "got socket %#Ix\n", out_params->sockets[0].socket); + ok(out_params->sockets[0].flags == AFD_POLL_ACCEPT, "got flags %#x\n", out_params->sockets[0].flags); + ok(!out_params->sockets[0].status, "got status %#x\n", out_params->sockets[0].status); + + ret = NtDeviceIoControlFile((HANDLE)listener, event, NULL, NULL, &io, + IOCTL_AFD_POLL, in_params, params_size, out_params, params_size); + ok(!ret, "got %#x\n", ret); + ok(!io.Status, "got %#x\n", io.Status); + ok(io.Information == offsetof(struct afd_poll_params, sockets[1]), "got %#Ix\n", io.Information); + ok(out_params->timeout == -1000 * 10000, "got timeout %#I64x\n", out_params->timeout); + ok(out_params->count == 1, "got count %u\n", out_params->count); + ok(out_params->sockets[0].socket == listener, "got socket %#Ix\n", out_params->sockets[0].socket); + ok(out_params->sockets[0].flags == AFD_POLL_ACCEPT, "got flags %#x\n", out_params->sockets[0].flags); + ok(!out_params->sockets[0].status, "got status %#x\n", out_params->sockets[0].status); + + in_params->timeout = now.QuadPart; + in_params->sockets[0].flags = (~0) & ~AFD_POLL_ACCEPT; + + ret = NtDeviceIoControlFile((HANDLE)listener, event, NULL, NULL, &io, + IOCTL_AFD_POLL, in_params, params_size, out_params, params_size); + ok(ret == STATUS_PENDING, "got %#x\n", ret); + ret = WaitForSingleObject(event, 100); + ok(!ret, "got %#x\n", ret); + ok(io.Status == STATUS_TIMEOUT, "got %#x\n", io.Status); + ok(io.Information == offsetof(struct afd_poll_params, sockets[0]), "got %#Ix\n", io.Information); + ok(!out_params->count, "got count %u\n", out_params->count); + + server = accept(listener, NULL, NULL); + ok(server != -1, "got error %u\n", WSAGetLastError()); + set_blocking(server, FALSE); + + /* Test flags exposed by connected sockets. */ + + check_poll(client, event, AFD_POLL_WRITE | AFD_POLL_CONNECT); + check_poll(server, event, AFD_POLL_WRITE | AFD_POLL_CONNECT); + + /* It is valid to poll on a socket other than the one passed to + * NtDeviceIoControlFile(). */ + + in_params->count = 1; + in_params->sockets[0].socket = server; + in_params->sockets[0].flags = ~0; + + ret = NtDeviceIoControlFile((HANDLE)listener, event, NULL, NULL, &io, + IOCTL_AFD_POLL, in_params, params_size, out_params, params_size); + ok(!ret, "got %#x\n", ret); + ok(!io.Status, "got %#x\n", io.Status); + ok(io.Information == offsetof(struct afd_poll_params, sockets[1]), "got %#Ix\n", io.Information); + ok(out_params->count == 1, "got count %u\n", out_params->count); + ok(out_params->sockets[0].socket == server, "got socket %#Ix\n", out_params->sockets[0].socket); + ok(out_params->sockets[0].flags == (AFD_POLL_WRITE | AFD_POLL_CONNECT), + "got flags %#x\n", out_params->sockets[0].flags); + ok(!out_params->sockets[0].status, "got status %#x\n", out_params->sockets[0].status); + + /* Test sending data. */ + + ret = send(server, "data", 5, 0); + ok(ret == 5, "got %d\n", ret); + + check_poll(client, event, AFD_POLL_WRITE | AFD_POLL_CONNECT | AFD_POLL_READ); + check_poll(server, event, AFD_POLL_WRITE | AFD_POLL_CONNECT); + + while (send(server, large_buffer, large_buffer_size, 0) == large_buffer_size); + + check_poll(client, event, AFD_POLL_WRITE | AFD_POLL_CONNECT | AFD_POLL_READ); + check_poll(server, event, AFD_POLL_CONNECT); + + /* Test sending out-of-band data. */ + + ret = send(client, "a", 1, MSG_OOB); + ok(ret == 1, "got %d\n", ret); + + check_poll(client, event, AFD_POLL_WRITE | AFD_POLL_CONNECT | AFD_POLL_READ); + check_poll(server, event, AFD_POLL_CONNECT | AFD_POLL_OOB); + + ret = recv(server, large_buffer, 1, MSG_OOB); + ok(ret == 1, "got %d\n", ret); + + check_poll(client, event, AFD_POLL_WRITE | AFD_POLL_CONNECT | AFD_POLL_READ); + check_poll(server, event, AFD_POLL_CONNECT); + + ret = 1; + ret = setsockopt(server, SOL_SOCKET, SO_OOBINLINE, (char *)&ret, sizeof(ret)); + ok(!ret, "got error %u\n", WSAGetLastError()); + + ret = send(client, "a", 1, MSG_OOB); + ok(ret == 1, "got %d\n", ret); + + check_poll(client, event, AFD_POLL_WRITE | AFD_POLL_CONNECT | AFD_POLL_READ); + check_poll(server, event, AFD_POLL_CONNECT | AFD_POLL_READ); + + closesocket(client); + closesocket(server); + + /* Test shutdown. */ + + client = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP); + ret = connect(client, (struct sockaddr *)&addr, sizeof(addr)); + ok(!ret, "got error %u\n", WSAGetLastError()); + server = accept(listener, NULL, NULL); + ok(server != -1, "got error %u\n", WSAGetLastError()); + + ret = shutdown(client, SD_RECEIVE); + ok(!ret, "got error %u\n", WSAGetLastError()); + + check_poll(client, event, AFD_POLL_WRITE | AFD_POLL_CONNECT); + check_poll(server, event, AFD_POLL_WRITE | AFD_POLL_CONNECT); + + ret = shutdown(client, SD_SEND); + ok(!ret, "got error %u\n", WSAGetLastError()); + + check_poll(client, event, AFD_POLL_WRITE | AFD_POLL_CONNECT); + check_poll(server, event, AFD_POLL_WRITE | AFD_POLL_CONNECT | AFD_POLL_HUP); + + closesocket(client); + closesocket(server); + + /* Test shutdown with data in the pipe. */ + + client = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP); + ret = connect(client, (struct sockaddr *)&addr, sizeof(addr)); + ok(!ret, "got error %u\n", WSAGetLastError()); + server = accept(listener, NULL, NULL); + ok(server != -1, "got error %u\n", WSAGetLastError()); + + ret = send(client, "data", 5, 0); + ok(ret == 5, "got %d\n", ret); + + check_poll(client, event, AFD_POLL_WRITE | AFD_POLL_CONNECT); + check_poll(server, event, AFD_POLL_WRITE | AFD_POLL_CONNECT | AFD_POLL_READ); + + ret = shutdown(client, SD_SEND); + ok(!ret, "got error %u\n", WSAGetLastError()); + + check_poll(client, event, AFD_POLL_WRITE | AFD_POLL_CONNECT); + check_poll_todo(server, event, AFD_POLL_WRITE | AFD_POLL_CONNECT | AFD_POLL_READ | AFD_POLL_HUP); + + /* Test closing a socket while polling on it. Note that AFD_POLL_CLOSE + * is always returned, regardless of whether it's polled for. */ + + in_params->timeout = -1000 * 10000; + in_params->count = 1; + in_params->sockets[0].socket = client; + in_params->sockets[0].flags = 0; + + ret = NtDeviceIoControlFile((HANDLE)client, event, NULL, NULL, &io, + IOCTL_AFD_POLL, in_params, params_size, out_params, params_size); + ok(ret == STATUS_PENDING, "got %#x\n", ret); + + closesocket(client); + + ret = WaitForSingleObject(event, 100); + ok(!ret, "got %#x\n", ret); + ok(!io.Status, "got %#x\n", io.Status); + ok(io.Information == offsetof(struct afd_poll_params, sockets[1]), "got %#Ix\n", io.Information); + ok(out_params->count == 1, "got count %u\n", out_params->count); + ok(out_params->sockets[0].socket == client, "got socket %#Ix\n", out_params->sockets[0].socket); + ok(out_params->sockets[0].flags == AFD_POLL_CLOSE, + "got flags %#x\n", out_params->sockets[0].flags); + ok(!out_params->sockets[0].status, "got status %#x\n", out_params->sockets[0].status); + + closesocket(server); + + /* Test a failed connection. + * + * The following poll works even where the equivalent WSAPoll() call fails. + * However, it can take over 2 seconds to complete on the testbot. */ + + if (winetest_interactive) + { + client = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP); + set_blocking(client, FALSE); + + in_params->timeout = -10000 * 10000; + in_params->count = 1; + in_params->sockets[0].socket = client; + in_params->sockets[0].flags = ~0; + + ret = NtDeviceIoControlFile((HANDLE)client, event, NULL, NULL, &io, + IOCTL_AFD_POLL, in_params, params_size, out_params, params_size); + ok(ret == STATUS_PENDING, "got %#x\n", ret); + + addr.sin_port = 255; + ret = connect(client, (struct sockaddr *)&addr, sizeof(addr)); + ok(!ret || WSAGetLastError() == WSAEWOULDBLOCK, "got error %u\n", WSAGetLastError()); + + ret = WaitForSingleObject(event, 10000); + ok(!ret, "got %#x\n", ret); + ok(!io.Status, "got %#x\n", io.Status); + ok(io.Information == offsetof(struct afd_poll_params, sockets[1]), "got %#Ix\n", io.Information); + ok(out_params->count == 1, "got count %u\n", out_params->count); + ok(out_params->sockets[0].socket == client, "got socket %#Ix\n", out_params->sockets[0].socket); + ok(out_params->sockets[0].flags == AFD_POLL_CONNECT_ERR, "got flags %#x\n", out_params->sockets[0].flags); + ok(out_params->sockets[0].status == STATUS_CONNECTION_REFUSED, "got status %#x\n", out_params->sockets[0].status); + + closesocket(client); + } + + /* Test supplying multiple handles to the ioctl. */ + + len = sizeof(addr); + ret = getsockname(listener, (struct sockaddr *)&addr, &len); + ok(!ret, "got error %u\n", WSAGetLastError()); + + client = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP); + ret = connect(client, (struct sockaddr *)&addr, sizeof(addr)); + ok(!ret, "got error %u\n", WSAGetLastError()); + server = accept(listener, NULL, NULL); + ok(server != -1, "got error %u\n", WSAGetLastError()); + + in_params->count = 2; + in_params->sockets[0].socket = client; + in_params->sockets[0].flags = AFD_POLL_READ; + in_params->sockets[1].socket = server; + in_params->sockets[1].flags = AFD_POLL_READ; + + ret = NtDeviceIoControlFile((HANDLE)client, event, NULL, NULL, &io, + IOCTL_AFD_POLL, in_params, params_size, out_params, params_size); + ok(ret == STATUS_INVALID_PARAMETER, "got %#x\n", ret); + + params_size = offsetof(struct afd_poll_params, sockets[2]); + + ret = NtDeviceIoControlFile((HANDLE)client, event, NULL, NULL, &io, + IOCTL_AFD_POLL, in_params, params_size, out_params, params_size); + ok(ret == STATUS_PENDING, "got %#x\n", ret); + + ret = send(client, "data", 5, 0); + ok(ret == 5, "got %d\n", ret); + + ret = WaitForSingleObject(event, 100); + ok(!ret, "got %#x\n", ret); + ok(!io.Status, "got %#x\n", io.Status); + ok(io.Information == offsetof(struct afd_poll_params, sockets[1]), "got %#Ix\n", io.Information); + ok(out_params->count == 1, "got count %u\n", out_params->count); + ok(out_params->sockets[0].socket == server, "got socket %#Ix\n", out_params->sockets[0].socket); + ok(out_params->sockets[0].flags == AFD_POLL_READ, "got flags %#x\n", out_params->sockets[0].flags); + ok(!out_params->sockets[0].status, "got status %#x\n", out_params->sockets[0].status); + + in_params->count = 2; + in_params->sockets[0].socket = client; + in_params->sockets[0].flags = AFD_POLL_READ | AFD_POLL_WRITE; + in_params->sockets[1].socket = server; + in_params->sockets[1].flags = AFD_POLL_READ | AFD_POLL_WRITE; + + ret = NtDeviceIoControlFile((HANDLE)client, event, NULL, NULL, &io, + IOCTL_AFD_POLL, in_params, params_size, out_params, params_size); + ok(!ret, "got %#x\n", ret); + ok(!io.Status, "got %#x\n", io.Status); + ok(io.Information == offsetof(struct afd_poll_params, sockets[2]), "got %#Ix\n", io.Information); + ok(out_params->count == 2, "got count %u\n", out_params->count); + ok(out_params->sockets[0].socket == client, "got socket %#Ix\n", out_params->sockets[0].socket); + ok(out_params->sockets[0].flags == AFD_POLL_WRITE, "got flags %#x\n", out_params->sockets[0].flags); + ok(!out_params->sockets[0].status, "got status %#x\n", out_params->sockets[0].status); + ok(out_params->sockets[1].socket == server, "got socket %#Ix\n", out_params->sockets[1].socket); + ok(out_params->sockets[1].flags == (AFD_POLL_READ | AFD_POLL_WRITE), + "got flags %#x\n", out_params->sockets[1].flags); + ok(!out_params->sockets[1].status, "got status %#x\n", out_params->sockets[1].status); + + in_params->count = 2; + in_params->sockets[0].socket = client; + in_params->sockets[0].flags = AFD_POLL_READ | AFD_POLL_WRITE; + in_params->sockets[1].socket = server; + in_params->sockets[1].flags = AFD_POLL_READ | AFD_POLL_WRITE; + + ret = NtDeviceIoControlFile((HANDLE)client, event, NULL, NULL, &io, + IOCTL_AFD_POLL, in_params, params_size, out_params, params_size); + ok(!ret, "got %#x\n", ret); + ok(!io.Status, "got %#x\n", io.Status); + ok(io.Information == offsetof(struct afd_poll_params, sockets[2]), "got %#Ix\n", io.Information); + ok(out_params->count == 2, "got count %u\n", out_params->count); + ok(out_params->sockets[0].socket == client, "got socket %#Ix\n", out_params->sockets[0].socket); + ok(out_params->sockets[0].flags == AFD_POLL_WRITE, "got flags %#x\n", out_params->sockets[0].flags); + ok(!out_params->sockets[0].status, "got status %#x\n", out_params->sockets[0].status); + ok(out_params->sockets[1].socket == server, "got socket %#Ix\n", out_params->sockets[1].socket); + ok(out_params->sockets[1].flags == (AFD_POLL_READ | AFD_POLL_WRITE), + "got flags %#x\n", out_params->sockets[1].flags); + ok(!out_params->sockets[1].status, "got status %#x\n", out_params->sockets[1].status); + + /* Close a socket while polling on another. */ + + in_params->timeout = -100 * 10000; + in_params->count = 1; + in_params->sockets[0].socket = client; + in_params->sockets[0].flags = AFD_POLL_READ; + params_size = offsetof(struct afd_poll_params, sockets[1]); + + ret = NtDeviceIoControlFile((HANDLE)server, event, NULL, NULL, &io, + IOCTL_AFD_POLL, in_params, params_size, out_params, params_size); + ok(ret == STATUS_PENDING, "got %#x\n", ret); + + closesocket(server); + + ret = WaitForSingleObject(event, 1000); + ok(!ret, "got %#x\n", ret); + todo_wine ok(io.Status == STATUS_TIMEOUT, "got %#x\n", io.Status); + todo_wine ok(io.Information == offsetof(struct afd_poll_params, sockets[0]), "got %#Ix\n", io.Information); + todo_wine ok(!out_params->count, "got count %u\n", out_params->count); + + closesocket(client); + + closesocket(listener); + + /* Test UDP sockets. */ + + client = socket(AF_INET, SOCK_DGRAM, IPPROTO_UDP); + server = socket(AF_INET, SOCK_DGRAM, IPPROTO_UDP); + + check_poll(client, event, AFD_POLL_WRITE); + check_poll(server, event, AFD_POLL_WRITE); + + ret = bind(client, (const struct sockaddr *)&bind_addr, sizeof(bind_addr)); + ok(!ret, "got error %u\n", WSAGetLastError()); + len = sizeof(addr); + ret = getsockname(listener, (struct sockaddr *)&addr, &len); + ok(!ret, "got error %u\n", WSAGetLastError()); + + check_poll(client, event, AFD_POLL_WRITE); + check_poll(server, event, AFD_POLL_WRITE); + + in_params->timeout = -1000 * 10000; + in_params->count = 1; + in_params->sockets[0].socket = client; + in_params->sockets[0].flags = (~0) & ~AFD_POLL_WRITE; + params_size = offsetof(struct afd_poll_params, sockets[1]); + + ret = NtDeviceIoControlFile((HANDLE)client, event, NULL, NULL, &io, + IOCTL_AFD_POLL, in_params, params_size, out_params, params_size); + ok(ret == STATUS_PENDING, "got %#x\n", ret); + + ret = sendto(server, "data", 5, 0, (struct sockaddr *)&addr, sizeof(addr)); + ok(ret == 5, "got %d\n", ret); + + ret = WaitForSingleObject(event, 100); + ok(!ret, "got %#x\n", ret); + ok(!io.Status, "got %#x\n", io.Status); + ok(io.Information == offsetof(struct afd_poll_params, sockets[1]), "got %#Ix\n", io.Information); + ok(out_params->count == 1, "got count %u\n", out_params->count); + ok(out_params->sockets[0].socket == client, "got socket %#Ix\n", out_params->sockets[0].socket); + ok(out_params->sockets[0].flags == AFD_POLL_READ, "got flags %#x\n", out_params->sockets[0].flags); + ok(!out_params->sockets[0].status, "got status %#x\n", out_params->sockets[0].status); + + closesocket(client); + closesocket(server); + + /* Passing any invalid sockets yields STATUS_INVALID_HANDLE. + * + * Note however that WSAPoll() happily accepts invalid sockets. It seems + * user-side cached data is used: closing a handle with CloseHandle() before + * passing it to WSAPoll() yields ENOTSOCK. */ + + tcp_socketpair(&client, &server); + + in_params->count = 2; + in_params->sockets[0].socket = 0xabacab; + in_params->sockets[0].flags = AFD_POLL_READ | AFD_POLL_WRITE; + in_params->sockets[1].socket = client; + in_params->sockets[1].flags = AFD_POLL_READ | AFD_POLL_WRITE; + params_size = offsetof(struct afd_poll_params, sockets[2]); + + memset(&io, 0, sizeof(io)); + ret = NtDeviceIoControlFile((HANDLE)client, event, NULL, NULL, &io, + IOCTL_AFD_POLL, in_params, params_size, out_params, params_size); + ok(ret == STATUS_INVALID_HANDLE, "got %#x\n", ret); + todo_wine ok(!io.Status, "got %#x\n", io.Status); + ok(!io.Information, "got %#Ix\n", io.Information); + + /* Test passing the same handle twice. */ + + in_params->count = 3; + in_params->sockets[0].socket = client; + in_params->sockets[0].flags = AFD_POLL_READ | AFD_POLL_WRITE; + in_params->sockets[1].socket = client; + in_params->sockets[1].flags = AFD_POLL_READ | AFD_POLL_WRITE; + in_params->sockets[2].socket = client; + in_params->sockets[2].flags = AFD_POLL_READ | AFD_POLL_WRITE | AFD_POLL_CONNECT; + params_size = offsetof(struct afd_poll_params, sockets[3]); + + ret = NtDeviceIoControlFile((HANDLE)client, event, NULL, NULL, &io, + IOCTL_AFD_POLL, in_params, params_size, out_params, params_size); + ok(!ret, "got %#x\n", ret); + ok(!io.Status, "got %#x\n", io.Status); + ok(io.Information == offsetof(struct afd_poll_params, sockets[3]), "got %#Ix\n", io.Information); + ok(out_params->count == 3, "got count %u\n", out_params->count); + ok(out_params->sockets[0].socket == client, "got socket %#Ix\n", out_params->sockets[0].socket); + ok(out_params->sockets[0].flags == AFD_POLL_WRITE, "got flags %#x\n", out_params->sockets[0].flags); + ok(!out_params->sockets[0].status, "got status %#x\n", out_params->sockets[0].status); + ok(out_params->sockets[1].socket == client, "got socket %#Ix\n", out_params->sockets[1].socket); + ok(out_params->sockets[1].flags == AFD_POLL_WRITE, "got flags %#x\n", out_params->sockets[1].flags); + ok(!out_params->sockets[1].status, "got status %#x\n", out_params->sockets[1].status); + ok(out_params->sockets[2].socket == client, "got socket %#Ix\n", out_params->sockets[2].socket); + ok(out_params->sockets[2].flags == (AFD_POLL_WRITE | AFD_POLL_CONNECT), + "got flags %#x\n", out_params->sockets[2].flags); + ok(!out_params->sockets[2].status, "got status %#x\n", out_params->sockets[2].status); + + in_params->count = 2; + in_params->sockets[0].socket = client; + in_params->sockets[0].flags = AFD_POLL_READ; + in_params->sockets[1].socket = client; + in_params->sockets[1].flags = AFD_POLL_READ; + params_size = offsetof(struct afd_poll_params, sockets[2]); + + ret = NtDeviceIoControlFile((HANDLE)client, event, NULL, NULL, &io, + IOCTL_AFD_POLL, in_params, params_size, out_params, params_size); + ok(ret == STATUS_PENDING, "got %#x\n", ret); + + ret = send(server, "data", 5, 0); + ok(ret == 5, "got %d\n", ret); + + ret = WaitForSingleObject(event, 100); + ok(!ret, "got %#x\n", ret); + ok(!io.Status, "got %#x\n", io.Status); + ok(io.Information == offsetof(struct afd_poll_params, sockets[1]), "got %#Ix\n", io.Information); + ok(out_params->count == 1, "got count %u\n", out_params->count); + ok(out_params->sockets[0].socket == client, "got socket %#Ix\n", out_params->sockets[0].socket); + ok(out_params->sockets[0].flags == AFD_POLL_READ, "got flags %#x\n", out_params->sockets[0].flags); + ok(!out_params->sockets[0].status, "got status %#x\n", out_params->sockets[0].status); + + closesocket(client); + closesocket(server); + + CloseHandle(event); + free(large_buffer); +} + +static void test_poll_completion_port(void) +{ + struct afd_poll_params params = {0}; + LARGE_INTEGER zero = {{0}}; + SOCKET client, server; + ULONG_PTR key, value; + IO_STATUS_BLOCK io; + HANDLE event, port; + int ret; + + event = CreateEventW(NULL, TRUE, FALSE, NULL); + tcp_socketpair(&client, &server); + port = CreateIoCompletionPort((HANDLE)client, NULL, 0, 0); + + params.timeout = -100 * 10000; + params.count = 1; + params.sockets[0].socket = client; + params.sockets[0].flags = AFD_POLL_WRITE; + params.sockets[0].status = 0xdeadbeef; + + ret = NtDeviceIoControlFile((HANDLE)client, event, NULL, NULL, &io, + IOCTL_AFD_POLL, ¶ms, sizeof(params), ¶ms, sizeof(params)); + ok(!ret, "got %#x\n", ret); + + ret = NtRemoveIoCompletion(port, &key, &value, &io, &zero); + ok(ret == STATUS_TIMEOUT, "got %#x\n", ret); + + ret = NtDeviceIoControlFile((HANDLE)client, event, NULL, (void *)0xdeadbeef, &io, + IOCTL_AFD_POLL, ¶ms, sizeof(params), ¶ms, sizeof(params)); + ok(!ret, "got %#x\n", ret); + + ret = NtRemoveIoCompletion(port, &key, &value, &io, &zero); + ok(!ret, "got %#x\n", ret); + ok(!key, "got key %#Ix\n", key); + ok(value == 0xdeadbeef, "got value %#Ix\n", value); + + params.timeout = 0; + params.count = 1; + params.sockets[0].socket = client; + params.sockets[0].flags = AFD_POLL_READ; + params.sockets[0].status = 0xdeadbeef; + + ret = NtDeviceIoControlFile((HANDLE)client, event, NULL, (void *)0xdeadbeef, &io, + IOCTL_AFD_POLL, ¶ms, sizeof(params), ¶ms, sizeof(params)); + ok(!ret, "got %#x\n", ret); + + ret = NtRemoveIoCompletion(port, &key, &value, &io, &zero); + ok(!ret, "got %#x\n", ret); + ok(!key, "got key %#Ix\n", key); + ok(value == 0xdeadbeef, "got value %#Ix\n", value); + + /* Close a socket while polling on another. */ + + params.timeout = -100 * 10000; + params.count = 1; + params.sockets[0].socket = server; + params.sockets[0].flags = AFD_POLL_READ; + params.sockets[0].status = 0xdeadbeef; + + ret = NtDeviceIoControlFile((HANDLE)client, event, NULL, (void *)0xdeadbeef, &io, + IOCTL_AFD_POLL, ¶ms, sizeof(params), ¶ms, sizeof(params)); + ok(ret == STATUS_PENDING, "got %#x\n", ret); + + closesocket(client); + + ret = WaitForSingleObject(event, 1000); + ok(!ret, "got %#x\n", ret); + todo_wine ok(io.Status == STATUS_TIMEOUT, "got %#x\n", io.Status); + todo_wine ok(io.Information == offsetof(struct afd_poll_params, sockets[0]), "got %#Ix\n", io.Information); + todo_wine ok(!params.count, "got count %u\n", params.count); + + ret = NtRemoveIoCompletion(port, &key, &value, &io, &zero); + ok(!ret, "got %#x\n", ret); + ok(!key, "got key %#Ix\n", key); + ok(value == 0xdeadbeef, "got value %#Ix\n", value); + + CloseHandle(port); + closesocket(server); + CloseHandle(event); +} + static void test_recv(void) { const struct sockaddr_in bind_addr = {.sin_family = AF_INET, .sin_addr.s_addr = htonl(INADDR_LOOPBACK)}; @@ -435,6 +1116,8 @@ START_TEST(afd) WSAStartup(MAKEWORD(2, 2), &data); test_open_device(); + test_poll(); + test_poll_completion_port(); test_recv(); WSACleanup();