diff --git a/dlls/ws2_32/tests/afd.c b/dlls/ws2_32/tests/afd.c index 184c7e0725a..cf27b4003c5 100644 --- a/dlls/ws2_32/tests/afd.c +++ b/dlls/ws2_32/tests/afd.c @@ -76,6 +76,21 @@ static void set_blocking(SOCKET s, ULONG blocking) ok(!ret, "got error %u\n", WSAGetLastError()); } +/* Set the linger timeout to zero and close the socket. This will trigger an + * RST on the connection on Windows as well as on Unix systems. */ +static void close_with_rst(SOCKET s) +{ + static const struct linger linger = {.l_onoff = 1}; + int ret; + + SetLastError(0xdeadbeef); + ret = setsockopt(s, SOL_SOCKET, SO_LINGER, (const char *)&linger, sizeof(linger)); + ok(!ret, "got %d\n", ret); + ok(!GetLastError(), "got error %lu\n", GetLastError()); + + closesocket(s); +} + static void test_open_device(void) { OBJECT_BASIC_INFORMATION info; @@ -142,7 +157,8 @@ static void check_poll_(int line, SOCKET s, HANDLE event, int mask, int expect, ok_(__FILE__, line)(out_params.count == 1, "got count %u\n", out_params.count); ok_(__FILE__, line)(out_params.sockets[0].socket == s, "got socket %#Ix\n", out_params.sockets[0].socket); todo_wine_if (todo) ok_(__FILE__, line)(out_params.sockets[0].flags == expect, "got flags %#x\n", out_params.sockets[0].flags); - ok_(__FILE__, line)(!out_params.sockets[0].status, "got status %#x\n", out_params.sockets[0].status); + todo_wine_if (expect & AFD_POLL_RESET) + ok_(__FILE__, line)(!out_params.sockets[0].status, "got status %#x\n", out_params.sockets[0].status); } static void test_poll(void) @@ -1311,6 +1327,50 @@ static void test_poll_completion_port(void) CloseHandle(event); } +static void test_poll_reset(void) +{ + char in_buffer[offsetof(struct afd_poll_params, sockets[3])]; + char out_buffer[offsetof(struct afd_poll_params, sockets[3])]; + struct afd_poll_params *in_params = (struct afd_poll_params *)in_buffer; + struct afd_poll_params *out_params = (struct afd_poll_params *)out_buffer; + SOCKET client, server; + IO_STATUS_BLOCK io; + ULONG params_size; + HANDLE event; + int ret; + + memset(in_buffer, 0, sizeof(in_buffer)); + memset(out_buffer, 0, sizeof(out_buffer)); + event = CreateEventW(NULL, TRUE, FALSE, NULL); + tcp_socketpair(&client, &server); + + in_params->timeout = -1000 * 10000; + in_params->count = 1; + in_params->sockets[0].socket = client; + in_params->sockets[0].flags = ~(AFD_POLL_WRITE | AFD_POLL_CONNECT); + params_size = offsetof(struct afd_poll_params, sockets[1]); + + ret = NtDeviceIoControlFile((HANDLE)client, event, NULL, NULL, &io, + IOCTL_AFD_POLL, in_params, params_size, out_params, params_size); + ok(ret == STATUS_PENDING, "got %#x\n", ret); + + close_with_rst(server); + + ret = WaitForSingleObject(event, 100); + ok(!ret, "got %#x\n", ret); + ok(!io.Status, "got %#lx\n", io.Status); + ok(io.Information == offsetof(struct afd_poll_params, sockets[1]), "got %#Ix\n", io.Information); + ok(out_params->count == 1, "got count %u\n", out_params->count); + ok(out_params->sockets[0].socket == client, "got socket %#Ix\n", out_params->sockets[0].socket); + todo_wine ok(out_params->sockets[0].flags == AFD_POLL_RESET, "got flags %#x\n", out_params->sockets[0].flags); + ok(!out_params->sockets[0].status, "got status %#x\n", out_params->sockets[0].status); + + check_poll_todo(client, event, AFD_POLL_WRITE | AFD_POLL_CONNECT | AFD_POLL_RESET); + + closesocket(client); + CloseHandle(event); +} + static void test_recv(void) { const struct sockaddr_in bind_addr = {.sin_family = AF_INET, .sin_addr.s_addr = htonl(INADDR_LOOPBACK)}; @@ -1914,6 +1974,56 @@ static void test_get_events(void) CloseHandle(event); } +static void test_get_events_reset(void) +{ + struct afd_get_events_params params; + SOCKET client, server; + IO_STATUS_BLOCK io; + unsigned int i; + HANDLE event; + int ret; + + event = CreateEventW(NULL, TRUE, FALSE, NULL); + + tcp_socketpair(&client, &server); + + ret = WSAEventSelect(client, event, FD_ACCEPT | FD_CONNECT | FD_CLOSE | FD_OOB | FD_READ | FD_WRITE); + ok(!ret, "got error %lu\n", GetLastError()); + + close_with_rst(server); + + memset(¶ms, 0xcc, sizeof(params)); + memset(&io, 0xcc, sizeof(io)); + ret = NtDeviceIoControlFile((HANDLE)client, NULL, NULL, NULL, &io, + IOCTL_AFD_GET_EVENTS, NULL, 0, ¶ms, sizeof(params)); + ok(!ret, "got %#x\n", ret); + todo_wine ok(params.flags == (AFD_POLL_RESET | AFD_POLL_CONNECT | AFD_POLL_WRITE), "got flags %#x\n", params.flags); + for (i = 0; i < ARRAY_SIZE(params.status); ++i) + ok(!params.status[i], "got status[%u] %#x\n", i, params.status[i]); + + closesocket(client); + + tcp_socketpair(&client, &server); + + ret = WSAEventSelect(server, event, FD_ACCEPT | FD_CONNECT | FD_CLOSE | FD_OOB | FD_READ | FD_WRITE); + ok(!ret, "got error %lu\n", GetLastError()); + + close_with_rst(client); + + memset(¶ms, 0xcc, sizeof(params)); + memset(&io, 0xcc, sizeof(io)); + ret = NtDeviceIoControlFile((HANDLE)server, NULL, NULL, NULL, &io, + IOCTL_AFD_GET_EVENTS, NULL, 0, ¶ms, sizeof(params)); + ok(!ret, "got %#x\n", ret); + todo_wine ok(params.flags == (AFD_POLL_RESET | AFD_POLL_WRITE), "got flags %#x\n", params.flags); + for (i = 0; i < ARRAY_SIZE(params.status); ++i) + ok(!params.status[i], "got status[%u] %#x\n", i, params.status[i]); + + closesocket(server); + + CloseHandle(event); +} + static void test_bind(void) { const struct sockaddr_in6 bind_addr6 = {.sin6_family = AF_INET6, .sin6_addr.s6_words = {0, 0, 0, 0, 0, 0, 0, htons(1)}}; @@ -2255,9 +2365,11 @@ START_TEST(afd) test_poll(); test_poll_exclusive(); test_poll_completion_port(); + test_poll_reset(); test_recv(); test_event_select(); test_get_events(); + test_get_events_reset(); test_bind(); test_getsockname(); diff --git a/dlls/ws2_32/tests/sock.c b/dlls/ws2_32/tests/sock.c index 7bc2b2474ca..17719bfaacd 100644 --- a/dlls/ws2_32/tests/sock.c +++ b/dlls/ws2_32/tests/sock.c @@ -202,6 +202,21 @@ static void tcp_socketpair(SOCKET *src, SOCKET *dst) tcp_socketpair_flags(src, dst, WSA_FLAG_OVERLAPPED); } +/* Set the linger timeout to zero and close the socket. This will trigger an + * RST on the connection on Windows as well as on Unix systems. */ +static void close_with_rst(SOCKET s) +{ + static const struct linger linger = {.l_onoff = 1}; + int ret; + + SetLastError(0xdeadbeef); + ret = setsockopt(s, SOL_SOCKET, SO_LINGER, (const char *)&linger, sizeof(linger)); + ok(!ret, "got %d\n", ret); + ok(!GetLastError(), "got error %lu\n", GetLastError()); + + closesocket(s); +} + #define check_poll(a, b) check_poll_(__LINE__, a, POLLRDNORM | POLLRDBAND | POLLWRNORM, b, FALSE) #define check_poll_todo(a, b) check_poll_(__LINE__, a, POLLRDNORM | POLLRDBAND | POLLWRNORM, b, TRUE) #define check_poll_mask(a, b, c) check_poll_(__LINE__, a, b, c, FALSE) @@ -5302,7 +5317,9 @@ static void check_events_(int line, struct event_test_ctx *ctx, else { WSANETWORKEVENTS events; + unsigned int i; + memset(&events, 0xcc, sizeof(events)); ret = WaitForSingleObject(ctx->event, timeout); if (flag1 | flag2) todo_wine_if (todo_event && ret) ok_(__FILE__, line)(!ret, "event wait timed out\n"); @@ -5311,7 +5328,16 @@ static void check_events_(int line, struct event_test_ctx *ctx, ret = WSAEnumNetworkEvents(ctx->socket, ctx->event, &events); ok_(__FILE__, line)(!ret, "failed to get events, error %u\n", WSAGetLastError()); todo_wine_if (todo_event) - ok_(__FILE__, line)(events.lNetworkEvents == (flag1 | flag2), "got events %#lx\n", events.lNetworkEvents); + ok_(__FILE__, line)(events.lNetworkEvents == LOWORD(flag1 | flag2), "got events %#lx\n", events.lNetworkEvents); + for (i = 0; i < ARRAY_SIZE(events.iErrorCode); ++i) + { + if ((1u << i) == LOWORD(flag1) && (events.lNetworkEvents & LOWORD(flag1))) + todo_wine_if (HIWORD(flag1)) ok_(__FILE__, line)(events.iErrorCode[i] == HIWORD(flag1), + "got error code %d for event %#x\n", events.iErrorCode[i], 1u << i); + if ((1u << i) == LOWORD(flag2) && (events.lNetworkEvents & LOWORD(flag2))) + ok_(__FILE__, line)(events.iErrorCode[i] == HIWORD(flag2), + "got error code %d for event %#x\n", events.iErrorCode[i], 1u << i); + } } } @@ -6114,6 +6140,28 @@ static void test_close_events(struct event_test_ctx *ctx) check_events(ctx, FD_CLOSE, 0, 200); closesocket(server); + + /* Trigger RST. */ + + tcp_socketpair(&client, &server); + + select_events(ctx, server, FD_ACCEPT | FD_CLOSE | FD_CONNECT | FD_OOB | FD_READ); + + close_with_rst(client); + + check_events_todo_msg(ctx, MAKELONG(FD_CLOSE, WSAECONNABORTED), 0, 200); + check_events(ctx, 0, 0, 0); + select_events(ctx, server, FD_ACCEPT | FD_CLOSE | FD_CONNECT | FD_OOB | FD_READ); + if (ctx->is_message) + check_events_todo(ctx, MAKELONG(FD_CLOSE, WSAECONNABORTED), 0, 200); + check_events(ctx, 0, 0, 0); + select_events(ctx, server, 0); + select_events(ctx, server, FD_ACCEPT | FD_CLOSE | FD_CONNECT | FD_OOB | FD_READ); + if (ctx->is_message) + check_events_todo(ctx, MAKELONG(FD_CLOSE, WSAECONNABORTED), 0, 200); + check_events(ctx, 0, 0, 0); + + closesocket(server); } static void test_events(void) @@ -6552,7 +6600,6 @@ static void test_WSARecv(void) WSABUF bufs[2]; WSAOVERLAPPED ov; DWORD bytesReturned, flags, id; - struct linger ling; struct sockaddr_in addr; int iret, len; DWORD dwret; @@ -6621,19 +6668,13 @@ static void test_WSARecv(void) if (!event) goto end; - ling.l_onoff = 1; - ling.l_linger = 0; - iret = setsockopt (src, SOL_SOCKET, SO_LINGER, (char *) &ling, sizeof(ling)); - ok(!iret, "Failed to set linger %ld\n", GetLastError()); - iret = WSARecv(dest, bufs, 1, NULL, &flags, &ov, NULL); ok(iret == SOCKET_ERROR && GetLastError() == ERROR_IO_PENDING, "WSARecv failed - %d error %ld\n", iret, GetLastError()); iret = WSARecv(dest, bufs, 1, &bytesReturned, &flags, &ov, NULL); ok(iret == SOCKET_ERROR && GetLastError() == ERROR_IO_PENDING, "WSARecv failed - %d error %ld\n", iret, GetLastError()); - closesocket(src); - src = INVALID_SOCKET; + close_with_rst(src); dwret = WaitForSingleObject(ov.hEvent, 1000); ok(dwret == WAIT_OBJECT_0, "Waiting for disconnect event failed with %ld + errno %ld\n", dwret, GetLastError()); @@ -9239,7 +9280,6 @@ static void test_completion_port(void) char buf[1024]; WSABUF bufs; DWORD num_bytes, flags; - struct linger ling; int iret; BOOL bret; ULONG_PTR key; @@ -9260,11 +9300,6 @@ static void test_completion_port(void) bufs.buf = buf; flags = 0; - ling.l_onoff = 1; - ling.l_linger = 0; - iret = setsockopt (src, SOL_SOCKET, SO_LINGER, (char *) &ling, sizeof(ling)); - ok(!iret, "Failed to set linger %ld\n", GetLastError()); - io_port = CreateIoCompletionPort( (HANDLE)dest, io_port, 125, 0 ); ok(io_port != NULL, "Failed to create completion port %lu\n", GetLastError()); @@ -9276,8 +9311,7 @@ static void test_completion_port(void) Sleep(100); - closesocket(src); - src = INVALID_SOCKET; + close_with_rst(src); SetLastError(0xdeadbeef); key = 0xdeadbeef; @@ -9314,18 +9348,12 @@ static void test_completion_port(void) bufs.buf = buf; flags = 0; - ling.l_onoff = 1; - ling.l_linger = 0; - iret = setsockopt (src, SOL_SOCKET, SO_LINGER, (char *) &ling, sizeof(ling)); - ok(!iret, "Failed to set linger %ld\n", GetLastError()); - io_port = CreateIoCompletionPort((HANDLE)dest, io_port, 125, 0); ok(io_port != NULL, "failed to create completion port %lu\n", GetLastError()); set_blocking(dest, FALSE); - closesocket(src); - src = INVALID_SOCKET; + close_with_rst(src); Sleep(100); @@ -9432,17 +9460,11 @@ static void test_completion_port(void) flags = 0; memset(&ov, 0, sizeof(ov)); - ling.l_onoff = 1; - ling.l_linger = 0; - iret = setsockopt (src, SOL_SOCKET, SO_LINGER, (char *) &ling, sizeof(ling)); - ok(!iret, "Failed to set linger %ld\n", GetLastError()); - io_port = CreateIoCompletionPort((HANDLE)dest, io_port, 125, 0); ok(io_port != NULL, "failed to create completion port %lu\n", GetLastError()); set_blocking(dest, FALSE); - closesocket(src); - src = INVALID_SOCKET; + close_with_rst(src); FD_ZERO(&fds_recv); FD_SET(dest, &fds_recv); @@ -12441,6 +12463,76 @@ static void test_sockopt_validity(void) CloseHandle(file); } +static void test_tcp_reset(void) +{ + static const struct timeval select_timeout; + fd_set readfds, writefds, exceptfds; + OVERLAPPED overlapped = {0}; + SOCKET client, server; + DWORD size, flags = 0; + int ret, len, error; + char buffer[10]; + WSABUF wsabuf; + + overlapped.hEvent = CreateEventW(NULL, TRUE, FALSE, NULL); + + tcp_socketpair(&client, &server); + + wsabuf.buf = buffer; + wsabuf.len = sizeof(buffer); + WSASetLastError(0xdeadbeef); + size = 0xdeadbeef; + ret = WSARecv(client, &wsabuf, 1, &size, &flags, &overlapped, NULL); + ok(ret == -1, "got %d\n", ret); + ok(WSAGetLastError() == ERROR_IO_PENDING, "got error %u\n", WSAGetLastError()); + + close_with_rst(server); + + ret = WaitForSingleObject(overlapped.hEvent, 1000); + ok(!ret, "wait failed\n"); + ret = GetOverlappedResult((HANDLE)client, &overlapped, &size, FALSE); + todo_wine ok(!ret, "expected failure\n"); + todo_wine ok(GetLastError() == ERROR_NETNAME_DELETED, "got error %lu\n", GetLastError()); + ok(!size, "got size %lu\n", size); + todo_wine ok((NTSTATUS)overlapped.Internal == STATUS_CONNECTION_RESET, "got status %#lx\n", (NTSTATUS)overlapped.Internal); + + len = sizeof(error); + ret = getsockopt(client, SOL_SOCKET, SO_ERROR, (char *)&error, &len); + ok(!ret, "got error %u\n", WSAGetLastError()); + todo_wine ok(!error, "got error %u\n", error); + + wsabuf.buf = buffer; + wsabuf.len = sizeof(buffer); + WSASetLastError(0xdeadbeef); + size = 0xdeadbeef; + ret = WSARecv(client, &wsabuf, 1, &size, &flags, &overlapped, NULL); + todo_wine ok(ret == -1, "got %d\n", ret); + todo_wine ok(WSAGetLastError() == WSAECONNRESET, "got error %u\n", WSAGetLastError()); + + check_poll_todo(client, POLLERR | POLLHUP | POLLWRNORM); + + FD_ZERO(&readfds); + FD_ZERO(&writefds); + FD_ZERO(&exceptfds); + FD_SET(client, &readfds); + FD_SET(client, &writefds); + FD_SET(client, &exceptfds); + ret = select(0, &readfds, &writefds, &exceptfds, &select_timeout); + ok(ret == 2, "got %d\n", ret); + ok(FD_ISSET(client, &readfds), "FD should be set\n"); + ok(FD_ISSET(client, &writefds), "FD should be set\n"); + ok(!FD_ISSET(client, &exceptfds), "FD should be set\n"); + + FD_ZERO(&exceptfds); + FD_SET(client, &exceptfds); + ret = select(0, NULL, NULL, &exceptfds, &select_timeout); + ok(!ret, "got %d\n", ret); + ok(!FD_ISSET(client, &exceptfds), "FD should be set\n"); + + closesocket(server); + CloseHandle(overlapped.hEvent); +} + START_TEST( sock ) { int i; @@ -12514,6 +12606,7 @@ START_TEST( sock ) test_simultaneous_async_recv(); test_empty_recv(); test_timeout(); + test_tcp_reset(); /* this is an io heavy test, do it at the end so the kernel doesn't start dropping packets */ test_send();