diff --git a/dlls/ws2_32/socket.c b/dlls/ws2_32/socket.c index bfc4bb6c09d..f5958e86e9f 100644 --- a/dlls/ws2_32/socket.c +++ b/dlls/ws2_32/socket.c @@ -510,17 +510,6 @@ static const int ws_socktype_map[][2] = {FROM_PROTOCOL_INFO, FROM_PROTOCOL_INFO}, }; -static const int ws_poll_map[][2] = -{ - MAP_OPTION( POLLERR ), - MAP_OPTION( POLLHUP ), - MAP_OPTION( POLLNVAL ), - MAP_OPTION( POLLWRNORM ), - MAP_OPTION( POLLWRBAND ), - MAP_OPTION( POLLRDNORM ), - { WS_POLLRDBAND, POLLPRI } -}; - UINT sock_get_error( int err ) { switch(err) @@ -988,40 +977,6 @@ convert_socktype_u2w(int unixsocktype) { return -1; } -static int convert_poll_w2u(int events) -{ - int i, ret; - for (i = ret = 0; events && i < ARRAY_SIZE(ws_poll_map); i++) - { - if (ws_poll_map[i][0] & events) - { - ret |= ws_poll_map[i][1]; - events &= ~ws_poll_map[i][0]; - } - } - - if (events) - FIXME("Unsupported WSAPoll() flags 0x%x\n", events); - return ret; -} - -static int convert_poll_u2w(int events) -{ - int i, ret; - for (i = ret = 0; events && i < ARRAY_SIZE(ws_poll_map); i++) - { - if (ws_poll_map[i][1] & events) - { - ret |= ws_poll_map[i][0]; - events &= ~ws_poll_map[i][1]; - } - } - - if (events) - FIXME("Unsupported poll() flags 0x%x\n", events); - return ret; -} - static int set_ipx_packettype(int sock, int ptype) { #ifdef HAS_IPX @@ -3531,65 +3486,120 @@ static unsigned int afd_poll_flag_to_win32( unsigned int flags ) /*********************************************************************** - * WSAPoll + * WSAPoll (ws2_32.@) */ -int WINAPI WSAPoll(WSAPOLLFD *wfds, ULONG count, int timeout) +int WINAPI WSAPoll( WSAPOLLFD *fds, ULONG count, int timeout ) { - int i, ret; - struct pollfd *ufds; + struct afd_poll_params *params; + ULONG params_size, i, j; + SOCKET poll_socket = 0; + IO_STATUS_BLOCK io; + HANDLE sync_event; + int ret_count = 0; + NTSTATUS status; if (!count) { SetLastError(WSAEINVAL); return SOCKET_ERROR; } - if (!wfds) + if (!fds) { SetLastError(WSAEFAULT); return SOCKET_ERROR; } - if (!(ufds = HeapAlloc(GetProcessHeap(), 0, count * sizeof(ufds[0])))) + if (!(sync_event = get_sync_event())) return -1; + + params_size = offsetof( struct afd_poll_params, sockets[count] ); + if (!(params = HeapAlloc( GetProcessHeap(), HEAP_ZERO_MEMORY, params_size ))) { SetLastError(WSAENOBUFS); return SOCKET_ERROR; } - for (i = 0; i < count; i++) - { - ufds[i].fd = get_sock_fd(wfds[i].fd, 0, NULL); - ufds[i].events = convert_poll_w2u(wfds[i].events); - ufds[i].revents = 0; - } + params->timeout = (timeout >= 0 ? timeout * -10000 : TIMEOUT_INFINITE); - ret = do_poll(ufds, count, timeout); - - for (i = 0; i < count; i++) + for (i = 0; i < count; ++i) { - if (ufds[i].fd != -1) + unsigned int flags = AFD_POLL_HUP | AFD_POLL_RESET | AFD_POLL_CONNECT_ERR; + + if ((INT_PTR)fds[i].fd < 0 || !socket_list_find( fds[i].fd )) { - release_sock_fd(wfds[i].fd, ufds[i].fd); - if (ufds[i].revents & POLLHUP) - { - /* Check if the socket still exists */ - int fd = get_sock_fd(wfds[i].fd, 0, NULL); - if (fd != -1) - { - wfds[i].revents = WS_POLLHUP; - release_sock_fd(wfds[i].fd, fd); - } - else - wfds[i].revents = WS_POLLNVAL; - } - else - wfds[i].revents = convert_poll_u2w(ufds[i].revents); + fds[i].revents = WS_POLLNVAL; + continue; } - else - wfds[i].revents = WS_POLLNVAL; + + poll_socket = fds[i].fd; + params->sockets[params->count].socket = fds[i].fd; + + if (fds[i].events & WS_POLLRDNORM) + flags |= AFD_POLL_ACCEPT | AFD_POLL_READ; + if (fds[i].events & WS_POLLRDBAND) + flags |= AFD_POLL_OOB; + if (fds[i].events & WS_POLLWRNORM) + flags |= AFD_POLL_WRITE; + params->sockets[params->count].flags = flags; + ++params->count; + + fds[i].revents = 0; } - HeapFree(GetProcessHeap(), 0, ufds); - return ret; + if (!poll_socket) + { + SetLastError( WSAENOTSOCK ); + HeapFree( GetProcessHeap(), 0, params ); + return -1; + } + + status = NtDeviceIoControlFile( (HANDLE)poll_socket, sync_event, NULL, NULL, &io, IOCTL_AFD_POLL, + params, params_size, params, params_size ); + if (status == STATUS_PENDING) + { + if (WaitForSingleObject( sync_event, INFINITE ) == WAIT_FAILED) + { + HeapFree( GetProcessHeap(), 0, params ); + return -1; + } + status = io.u.Status; + } + if (!status) + { + for (i = 0; i < count; ++i) + { + for (j = 0; j < params->count; ++j) + { + if (fds[i].fd == params->sockets[j].socket) + { + unsigned int revents = 0; + + if (params->sockets[j].flags & (AFD_POLL_ACCEPT | AFD_POLL_READ)) + revents |= WS_POLLRDNORM; + if (params->sockets[j].flags & AFD_POLL_OOB) + revents |= WS_POLLRDBAND; + if (params->sockets[j].flags & AFD_POLL_WRITE) + revents |= WS_POLLWRNORM; + if (params->sockets[j].flags & AFD_POLL_HUP) + revents |= WS_POLLHUP; + if (params->sockets[j].flags & (AFD_POLL_RESET | AFD_POLL_CONNECT_ERR)) + revents |= WS_POLLERR; + if (params->sockets[j].flags & AFD_POLL_CLOSE) + revents |= WS_POLLNVAL; + + fds[i].revents = revents & (fds[i].events | WS_POLLHUP | WS_POLLERR | WS_POLLNVAL); + + if (fds[i].revents) + ++ret_count; + } + } + } + } + if (status == STATUS_TIMEOUT) status = STATUS_SUCCESS; + + HeapFree( GetProcessHeap(), 0, params ); + + SetLastError( NtStatusToWSAError( status ) ); + return status ? -1 : ret_count; } diff --git a/dlls/ws2_32/tests/sock.c b/dlls/ws2_32/tests/sock.c index b171e6c03e1..7b92d0b4e84 100644 --- a/dlls/ws2_32/tests/sock.c +++ b/dlls/ws2_32/tests/sock.c @@ -6064,7 +6064,7 @@ static void test_WSAPoll(void) WSASetLastError(0xdeadbeef); ret = pWSAPoll(fds, 2, 0); ok(!ret, "got %d\n", ret); - todo_wine ok(!WSAGetLastError(), "got error %u\n", WSAGetLastError()); + ok(!WSAGetLastError(), "got error %u\n", WSAGetLastError()); ok(fds[0].revents == POLLNVAL, "got events %#x\n", fds[0].revents); ok(!fds[1].revents, "got events %#x\n", fds[1].revents); @@ -6077,7 +6077,7 @@ static void test_WSAPoll(void) WSASetLastError(0xdeadbeef); ret = pWSAPoll(fds, 2, 0); ok(!ret, "got %d\n", ret); - todo_wine ok(!WSAGetLastError(), "got error %u\n", WSAGetLastError()); + ok(!WSAGetLastError(), "got error %u\n", WSAGetLastError()); ok(!fds[0].revents, "got events %#x\n", fds[0].revents); ok(fds[1].revents == POLLNVAL, "got events %#x\n", fds[1].revents); @@ -6102,7 +6102,7 @@ static void test_WSAPoll(void) fds[1].revents = 0xdead; WSASetLastError(0xdeadbeef); ret = pWSAPoll(fds, 2, 0); - todo_wine ok(ret == -1, "got %d\n", ret); + ok(ret == -1, "got %d\n", ret); ok(WSAGetLastError() == WSAENOTSOCK, "got error %u\n", WSAGetLastError()); ok(fds[0].revents == POLLNVAL, "got events %#x\n", fds[0].revents); ok(fds[1].revents == POLLNVAL, "got events %#x\n", fds[1].revents); @@ -6222,8 +6222,8 @@ static void test_WSAPoll(void) ok(ret == 1, "got %d\n", ret); check_poll(client, POLLWRNORM); - check_poll_mask_todo(server, POLLRDNORM | POLLRDBAND, POLLRDNORM); - check_poll_todo(server, POLLWRNORM | POLLRDNORM); + check_poll_mask(server, POLLRDNORM | POLLRDBAND, POLLRDNORM); + check_poll(server, POLLWRNORM | POLLRDNORM); buffer[0] = 0xcc; ret = recv(server, buffer, 1, 0); @@ -6261,8 +6261,8 @@ static void test_WSAPoll(void) closesocket(client); - check_poll_mask_todo(server, 0, POLLHUP); - check_poll_todo(server, POLLWRNORM | POLLHUP); + check_poll_mask(server, 0, POLLHUP); + check_poll(server, POLLWRNORM | POLLHUP); closesocket(server);