diff --git a/dlls/ntdll/tests/file.c b/dlls/ntdll/tests/file.c index e299560eb30..f6a9e7c280b 100644 --- a/dlls/ntdll/tests/file.c +++ b/dlls/ntdll/tests/file.c @@ -1107,6 +1107,50 @@ static void test_iocp_fileio(HANDLE h) } CloseHandle( hPipeClt ); + + /* test associating a completion port with a handle after an async is queued */ + hPipeSrv = CreateNamedPipeA( pipe_name, PIPE_ACCESS_INBOUND | FILE_FLAG_OVERLAPPED, PIPE_TYPE_MESSAGE | PIPE_READMODE_MESSAGE | PIPE_WAIT, 4, 1024, 1024, 1000, NULL ); + ok( hPipeSrv != INVALID_HANDLE_VALUE, "Cannot create named pipe\n" ); + if (hPipeSrv == INVALID_HANDLE_VALUE ) + return; + hPipeClt = CreateFileA( pipe_name, GENERIC_WRITE, 0, NULL, OPEN_EXISTING, FILE_FLAG_NO_BUFFERING | FILE_FLAG_OVERLAPPED, NULL ); + ok( hPipeClt != INVALID_HANDLE_VALUE, "Cannot connect to pipe\n" ); + if (hPipeClt != INVALID_HANDLE_VALUE) + { + OVERLAPPED o = {0,}; + BYTE send_buf[TEST_BUF_LEN], recv_buf[TEST_BUF_LEN]; + DWORD read; + long count; + + memset( send_buf, 0, TEST_BUF_LEN ); + memset( recv_buf, 0xde, TEST_BUF_LEN ); + count = get_pending_msgs(h); + ok( !count, "Unexpected msg count: %ld\n", count ); + ReadFile( hPipeSrv, recv_buf, TEST_BUF_LEN, &read, &o); + + U(iosb).Status = 0xdeadbeef; + res = pNtSetInformationFile( hPipeSrv, &iosb, &fci, sizeof(fci), FileCompletionInformation ); + ok( res == STATUS_SUCCESS, "NtSetInformationFile failed: %x\n", res ); + ok( U(iosb).Status == STATUS_SUCCESS, "iosb.Status invalid: %x\n", U(iosb).Status ); + count = get_pending_msgs(h); + ok( !count, "Unexpected msg count: %ld\n", count ); + + WriteFile( hPipeClt, send_buf, TEST_BUF_LEN, &read, NULL ); + + if (get_msg(h)) + { + ok( completionKey == CKEY_SECOND, "Invalid completion key: %lx\n", completionKey ); + ok( ioSb.Information == 3, "Invalid ioSb.Information: %ld\n", ioSb.Information ); + ok( U(ioSb).Status == STATUS_SUCCESS, "Invalid ioSb.Status: %x\n", U(ioSb).Status); + ok( completionValue == (ULONG_PTR)&o, "Invalid completion value: %lx\n", completionValue ); + ok( !memcmp( send_buf, recv_buf, TEST_BUF_LEN ), "Receive buffer (%x %x %x) did not match send buffer (%x %x %x)\n", recv_buf[0], recv_buf[1], recv_buf[2], send_buf[0], send_buf[1], send_buf[2] ); + } + count = get_pending_msgs(h); + ok( !count, "Unexpected msg count: %ld\n", count ); + } + + CloseHandle( hPipeSrv ); + CloseHandle( hPipeClt ); } static void test_file_basic_information(void) diff --git a/server/async.c b/server/async.c index dd28dfffccd..843a02bbc87 100644 --- a/server/async.c +++ b/server/async.c @@ -42,8 +42,6 @@ struct async struct timeout_user *timeout; unsigned int timeout_status; /* status to report upon timeout */ struct event *event; - struct completion *completion; - apc_param_t comp_key; async_data_t data; /* data for async I/O call */ }; @@ -75,10 +73,13 @@ struct async_queue { struct object obj; /* object header */ struct fd *fd; /* file descriptor owning this queue */ + struct completion *completion; /* completion associated with a recently closed file descriptor */ + apc_param_t comp_key; /* completion key associated with a recently closed file descriptor */ struct list queue; /* queue of async objects */ }; static void async_queue_dump( struct object *obj, int verbose ); +static void async_queue_destroy( struct object *obj ); static const struct object_ops async_queue_ops = { @@ -97,7 +98,7 @@ static const struct object_ops async_queue_ops = no_lookup_name, /* lookup_name */ no_open_file, /* open_file */ no_close_handle, /* close_handle */ - no_destroy /* destroy */ + async_queue_destroy /* destroy */ }; @@ -123,7 +124,6 @@ static void async_destroy( struct object *obj ) if (async->timeout) remove_timeout_user( async->timeout ); if (async->event) release_object( async->event ); - if (async->completion) release_object( async->completion ); release_object( async->queue ); release_object( async->thread ); } @@ -135,6 +135,13 @@ static void async_queue_dump( struct object *obj, int verbose ) fprintf( stderr, "Async queue fd=%p\n", async_queue->fd ); } +static void async_queue_destroy( struct object *obj ) +{ + struct async_queue *async_queue = (struct async_queue *)obj; + assert( obj->ops == &async_queue_ops ); + if (async_queue->completion) release_object( async_queue->completion ); +} + /* notifies client thread of new status of its async request */ void async_terminate( struct async *async, unsigned int status ) { @@ -178,6 +185,7 @@ struct async_queue *create_async_queue( struct fd *fd ) if (queue) { queue->fd = fd; + queue->completion = NULL; list_init( &queue->queue ); } return queue; @@ -187,6 +195,7 @@ struct async_queue *create_async_queue( struct fd *fd ) void free_async_queue( struct async_queue *queue ) { if (!queue) return; + if (queue->fd) queue->completion = fd_get_completion( queue->fd, &queue->comp_key ); queue->fd = NULL; async_wake_up( queue, STATUS_HANDLES_CLOSED ); release_object( queue ); @@ -213,8 +222,6 @@ struct async *create_async( struct thread *thread, struct async_queue *queue, co async->data = *data; async->timeout = NULL; async->queue = (struct async_queue *)grab_object( queue ); - async->completion = NULL; - if (queue->fd) async->completion = fd_get_completion( queue->fd, &async->comp_key ); list_add_tail( &queue->queue, &async->queue_entry ); grab_object( async ); @@ -233,6 +240,24 @@ void async_set_timeout( struct async *async, timeout_t timeout, unsigned int sta async->timeout_status = status; } +static void add_async_completion( struct async_queue *queue, apc_param_t cvalue, unsigned int status, + unsigned int information ) +{ + if (queue->fd) + { + apc_param_t ckey; + struct completion *completion = fd_get_completion( queue->fd, &ckey ); + + if (completion) + { + add_completion( completion, ckey, cvalue, status, information ); + release_object( completion ); + } + } + else if (queue->completion) add_completion( queue->completion, queue->comp_key, + cvalue, status, information ); +} + /* store the result of the client-side async callback */ void async_set_result( struct object *obj, unsigned int status, unsigned int total, client_ptr_t apc ) { @@ -258,8 +283,7 @@ void async_set_result( struct object *obj, unsigned int status, unsigned int tot if (async->timeout) remove_timeout_user( async->timeout ); async->timeout = NULL; async->status = status; - if (async->completion && async->data.cvalue) - add_completion( async->completion, async->comp_key, async->data.cvalue, status, total ); + if (async->data.cvalue) add_async_completion( async->queue, async->data.cvalue, status, total ); if (apc) { apc_call_t data;