From 5990f19bdc174032a4c6b459204c32ba0f1f3354 Mon Sep 17 00:00:00 2001 From: Guillaume Charifi Date: Thu, 23 Sep 2021 00:48:03 -0500 Subject: [PATCH] ntdll: Implement exclusive flag for IOCTL_AFD_POLL. Signed-off-by: Guillaume Charifi Signed-off-by: Zebediah Figura Signed-off-by: Alexandre Julliard --- dlls/ntdll/unix/socket.c | 2 +- dlls/ws2_32/tests/afd.c | 10 +++--- include/wine/server_protocol.h | 4 +-- server/protocol.def | 1 + server/request.h | 1 + server/sock.c | 66 ++++++++++++++++++++++++++++++++-- server/trace.c | 3 +- 7 files changed, 75 insertions(+), 12 deletions(-) diff --git a/dlls/ntdll/unix/socket.c b/dlls/ntdll/unix/socket.c index c1a29f5d5e0..75c54295c20 100644 --- a/dlls/ntdll/unix/socket.c +++ b/dlls/ntdll/unix/socket.c @@ -741,7 +741,6 @@ static NTSTATUS sock_poll( HANDLE handle, HANDLE event, PIO_APC_ROUTINE apc, voi params->padding[0], params->padding[1], params->padding[2], params->sockets[0].socket, params->sockets[0].flags ); - if (params->exclusive) FIXME( "ignoring exclusive flag\n" ); if (params->padding[0]) FIXME( "padding[0] is %#x\n", params->padding[0] ); if (params->padding[1]) FIXME( "padding[1] is %#x\n", params->padding[1] ); if (params->padding[2]) FIXME( "padding[2] is %#x\n", params->padding[2] ); @@ -782,6 +781,7 @@ static NTSTATUS sock_poll( HANDLE handle, HANDLE event, PIO_APC_ROUTINE apc, voi SERVER_START_REQ( poll_socket ) { req->async = server_async( handle, &async->io, event, apc, apc_user, iosb_client_ptr(io) ); + req->exclusive = !!params->exclusive; req->timeout = params->timeout; wine_server_add_data( req, input, params->count * sizeof(*input) ); wine_server_set_reply( req, async->sockets, params->count * sizeof(async->sockets[0]) ); diff --git a/dlls/ws2_32/tests/afd.c b/dlls/ws2_32/tests/afd.c index 647b7e41f63..9e888f1fcb9 100644 --- a/dlls/ws2_32/tests/afd.c +++ b/dlls/ws2_32/tests/afd.c @@ -867,7 +867,7 @@ static void test_poll_exclusive(void) ok(ret == STATUS_PENDING, "got %#x\n", ret); ret = WaitForSingleObject(events[0], 100); - todo_wine ok(ret == STATUS_SUCCESS, "got %#x\n", ret); + ok(ret == STATUS_SUCCESS, "got %#x\n", ret); ret = WaitForSingleObject(events[1], 100); ok(ret == STATUS_TIMEOUT, "got %#x\n", ret); @@ -927,7 +927,7 @@ static void test_poll_exclusive(void) ok(ret == STATUS_PENDING, "got %#x\n", ret); ret = WaitForSingleObject(events[1], 100); - todo_wine ok(ret == STATUS_SUCCESS, "got %#x\n", ret); + ok(ret == STATUS_SUCCESS, "got %#x\n", ret); ret = WaitForSingleObject(events[2], 100); ok(ret == STATUS_TIMEOUT, "got %#x\n", ret); @@ -997,7 +997,7 @@ static void test_poll_exclusive(void) ok(ret == STATUS_TIMEOUT, "got %#x\n", ret); ret = WaitForSingleObject(events[2], 100); - todo_wine ok(ret == STATUS_SUCCESS, "got %#x\n", ret); + ok(ret == STATUS_SUCCESS, "got %#x\n", ret); ret = WaitForSingleObject(events[3], 100); ok(ret == STATUS_TIMEOUT, "got %#x\n", ret); @@ -1048,7 +1048,7 @@ static void test_poll_exclusive(void) CloseHandle(thrd); ret = WaitForSingleObject(events[0], 100); - todo_wine ok(ret == STATUS_SUCCESS, "got %#x\n", ret); + ok(ret == STATUS_SUCCESS, "got %#x\n", ret); CancelIo((HANDLE)ctl_sock); @@ -1073,7 +1073,7 @@ static void test_poll_exclusive(void) ok(ret == STATUS_PENDING, "got %#x\n", ret); ret = WaitForSingleObject(events[0], 100); - todo_wine ok(ret == STATUS_SUCCESS, "got %#x\n", ret); + ok(ret == STATUS_SUCCESS, "got %#x\n", ret); ret = WaitForSingleObject(events[1], 100); ok(ret == STATUS_TIMEOUT, "got %#x\n", ret); diff --git a/include/wine/server_protocol.h b/include/wine/server_protocol.h index e9c12cac070..02c90ea9011 100644 --- a/include/wine/server_protocol.h +++ b/include/wine/server_protocol.h @@ -1762,7 +1762,7 @@ struct poll_socket_output struct poll_socket_request { struct request_header __header; - char __pad_12[4]; + int exclusive; async_data_t async; timeout_t timeout; /* VARARG(sockets,poll_socket_input); */ @@ -6257,7 +6257,7 @@ union generic_reply /* ### protocol_version begin ### */ -#define SERVER_PROTOCOL_VERSION 731 +#define SERVER_PROTOCOL_VERSION 732 /* ### protocol_version end ### */ diff --git a/server/protocol.def b/server/protocol.def index a5030fcf813..a7a4b7cc957 100644 --- a/server/protocol.def +++ b/server/protocol.def @@ -1451,6 +1451,7 @@ struct poll_socket_output /* Perform an async poll on a socket */ @REQ(poll_socket) + int exclusive; /* is the poll exclusive? */ async_data_t async; /* async I/O parameters */ timeout_t timeout; /* timeout */ VARARG(sockets,poll_socket_input); /* list of sockets to poll */ diff --git a/server/request.h b/server/request.h index 6fe708e065f..ff1744c1a80 100644 --- a/server/request.h +++ b/server/request.h @@ -1045,6 +1045,7 @@ C_ASSERT( sizeof(struct recv_socket_request) == 64 ); C_ASSERT( FIELD_OFFSET(struct recv_socket_reply, wait) == 8 ); C_ASSERT( FIELD_OFFSET(struct recv_socket_reply, options) == 12 ); C_ASSERT( sizeof(struct recv_socket_reply) == 16 ); +C_ASSERT( FIELD_OFFSET(struct poll_socket_request, exclusive) == 12 ); C_ASSERT( FIELD_OFFSET(struct poll_socket_request, async) == 16 ); C_ASSERT( FIELD_OFFSET(struct poll_socket_request, timeout) == 56 ); C_ASSERT( sizeof(struct poll_socket_request) == 64 ); diff --git a/server/sock.c b/server/sock.c index e426019d558..7b00cb3f4f4 100644 --- a/server/sock.c +++ b/server/sock.c @@ -128,6 +128,7 @@ struct poll_req struct async *async; struct iosb *iosb; struct timeout_user *timeout; + int exclusive; unsigned int count; struct poll_socket_output *output; struct @@ -205,6 +206,7 @@ struct sock struct list accept_list; /* list of pending accept requests */ struct accept_req *accept_recv_req; /* pending accept-into request which will recv on this socket */ struct connect_req *connect_req; /* pending connection request */ + struct poll_req *main_poll; /* main poll */ union win_sockaddr addr; /* socket name */ int addr_len; /* socket name length */ unsigned int rcvbuf; /* advisory recv buffer size */ @@ -231,6 +233,7 @@ static int sock_get_poll_events( struct fd *fd ); static void sock_poll_event( struct fd *fd, int event ); static enum server_fd_type sock_get_fd_type( struct fd *fd ); static void sock_ioctl( struct fd *fd, ioctl_code_t code, struct async *async ); +static void sock_cancel_async( struct fd *fd, struct async *async ); static void sock_queue_async( struct fd *fd, struct async *async, int type, int count ); static void sock_reselect_async( struct fd *fd, struct async_queue *queue ); @@ -274,7 +277,7 @@ static const struct fd_ops sock_fd_ops = default_fd_get_file_info, /* get_file_info */ no_fd_get_volume_info, /* get_volume_info */ sock_ioctl, /* ioctl */ - default_fd_cancel_async, /* cancel_async */ + sock_cancel_async, /* cancel_async */ sock_queue_async, /* queue_async */ sock_reselect_async /* reselect_async */ }; @@ -834,6 +837,16 @@ static int get_poll_flags( struct sock *sock, int event ) static void complete_async_poll( struct poll_req *req, unsigned int status ) { + unsigned int i; + + for (i = 0; i < req->count; ++i) + { + struct sock *sock = req->sockets[i].sock; + + if (sock->main_poll == req) + sock->main_poll = NULL; + } + /* pass 0 as result; client will set actual result size */ async_request_complete( req->async, status, 0, req->count * sizeof(*req->output), req->output ); } @@ -1222,6 +1235,29 @@ static enum server_fd_type sock_get_fd_type( struct fd *fd ) return FD_TYPE_SOCKET; } +static void sock_cancel_async( struct fd *fd, struct async *async ) +{ + struct poll_req *req; + + LIST_FOR_EACH_ENTRY( req, &poll_list, struct poll_req, entry ) + { + unsigned int i; + + if (req->async != async) + continue; + + for (i = 0; i < req->count; i++) + { + struct sock *sock = req->sockets[i].sock; + + if (sock->main_poll == req) + sock->main_poll = NULL; + } + } + + async_terminate( async, STATUS_CANCELLED ); +} + static void sock_queue_async( struct fd *fd, struct async *async, int type, int count ) { struct sock *sock = get_fd_user( fd ); @@ -1383,6 +1419,7 @@ static struct sock *create_socket(void) sock->ifchange_obj = NULL; sock->accept_recv_req = NULL; sock->connect_req = NULL; + sock->main_poll = NULL; memset( &sock->addr, 0, sizeof(sock->addr) ); sock->addr_len = 0; sock->rd_shutdown = 0; @@ -2840,7 +2877,27 @@ static int poll_single_socket( struct sock *sock, int mask ) return get_poll_flags( sock, pollfd.revents ) & mask; } -static void poll_socket( struct sock *poll_sock, struct async *async, timeout_t timeout, +static void handle_exclusive_poll(struct poll_req *req) +{ + unsigned int i; + + for (i = 0; i < req->count; ++i) + { + struct sock *sock = req->sockets[i].sock; + struct poll_req *main_poll = sock->main_poll; + + if (main_poll && main_poll->exclusive && req->exclusive) + { + complete_async_poll( main_poll, STATUS_SUCCESS ); + main_poll = NULL; + } + + if (!main_poll) + sock->main_poll = req; + } +} + +static void poll_socket( struct sock *poll_sock, struct async *async, int exclusive, timeout_t timeout, unsigned int count, const struct poll_socket_input *input ) { struct poll_socket_output *output; @@ -2881,11 +2938,14 @@ static void poll_socket( struct sock *poll_sock, struct async *async, timeout_t req->sockets[i].flags = input[i].flags; } + req->exclusive = exclusive; req->count = count; req->async = (struct async *)grab_object( async ); req->iosb = async_get_iosb( async ); req->output = output; + handle_exclusive_poll(req); + list_add_tail( &poll_list, &req->entry ); async_set_completion_callback( async, free_poll_req, req ); queue_async( &poll_sock->poll_q, async ); @@ -3287,7 +3347,7 @@ DECL_HANDLER(poll_socket) if ((async = create_request_async( sock->fd, get_fd_comp_flags( sock->fd ), &req->async ))) { - poll_socket( sock, async, req->timeout, count, input ); + poll_socket( sock, async, req->exclusive, req->timeout, count, input ); reply->wait = async_handoff( async, NULL, 0 ); reply->options = get_fd_options( sock->fd ); release_object( async ); diff --git a/server/trace.c b/server/trace.c index 6e380fc4be5..11e3eb25da6 100644 --- a/server/trace.c +++ b/server/trace.c @@ -2122,7 +2122,8 @@ static void dump_recv_socket_reply( const struct recv_socket_reply *req ) static void dump_poll_socket_request( const struct poll_socket_request *req ) { - dump_async_data( " async=", &req->async ); + fprintf( stderr, " exclusive=%d", req->exclusive ); + dump_async_data( ", async=", &req->async ); dump_timeout( ", timeout=", &req->timeout ); dump_varargs_poll_socket_input( ", sockets=", cur_size ); }