From a20c4e11dfe4d89fc0f7184616386e56354995df Mon Sep 17 00:00:00 2001 From: Alexandre Julliard Date: Fri, 22 Sep 2017 14:58:09 +0200 Subject: [PATCH] ntdll: Add a helper to lock the user output buffer during a server call. Signed-off-by: Alexandre Julliard --- dlls/kernel32/tests/virtual.c | 10 ++++------ dlls/ntdll/file.c | 6 +++--- dlls/ntdll/ntdll_misc.h | 2 ++ dlls/ntdll/server.c | 17 ++++++++++++++--- dlls/ntdll/virtual.c | 25 +++++++++++++++++++++++++ 5 files changed, 48 insertions(+), 12 deletions(-) diff --git a/dlls/kernel32/tests/virtual.c b/dlls/kernel32/tests/virtual.c index 1d611cdacff..ec2489f97a1 100644 --- a/dlls/kernel32/tests/virtual.c +++ b/dlls/kernel32/tests/virtual.c @@ -1840,7 +1840,7 @@ static DWORD CALLBACK read_pipe( void *arg ) "%u: ConnectNamedPipe failed %u\n", args->index, GetLastError() ); success = ReadFile( args->pipe, args->base, args->size, &num_bytes, NULL ); - todo_wine + todo_wine_if (!args->index) { ok( success, "%u: ReadFile failed %u\n", args->index, GetLastError() ); ok( num_bytes == sizeof(testdata), "%u: wrong number of bytes read %u\n", args->index, num_bytes ); @@ -2056,7 +2056,7 @@ static void test_write_watch(void) num_bytes = 0; success = GetOverlappedResult( readpipe, &overlapped, &num_bytes, TRUE ); - todo_wine + todo_wine_if (!i) { ok( success, "%u: GetOverlappedResult failed %u\n", i, GetLastError() ); ok( num_bytes == sizeof(testdata), "%u: wrong number of bytes read %u\n", i, num_bytes ); @@ -2067,7 +2067,7 @@ static void test_write_watch(void) memset( results, 0, sizeof(results) ); ret = pGetWriteWatch( WRITE_WATCH_FLAG_RESET, base, size, results, &count, &pagesize ); ok( !ret, "%u: GetWriteWatch failed %u\n", i, GetLastError() ); - todo_wine + todo_wine_if (!i) { ok( count == 1, "%u: wrong count %lu\n", i, count ); ok( results[0] == base, "%u: wrong result %p\n", i, results[0] ); @@ -2076,7 +2076,6 @@ static void test_write_watch(void) CloseHandle( readpipe ); CloseHandle( writepipe ); CloseHandle( overlapped.hEvent ); - if (!success) break; /* don't try message mode if byte mode already doesn't work */ } for (i = 0; i < 2; i++) @@ -2121,7 +2120,7 @@ static void test_write_watch(void) memset( results, 0, sizeof(results) ); ret = pGetWriteWatch( WRITE_WATCH_FLAG_RESET, base, size, results, &count, &pagesize ); ok( !ret, "%u: GetWriteWatch failed %u\n", i, GetLastError() ); - todo_wine + todo_wine_if (!i) { ok( count == 1, "%u: wrong count %lu\n", i, count ); ok( results[0] == base, "%u: wrong result %p\n", i, results[0] ); @@ -2130,7 +2129,6 @@ static void test_write_watch(void) CloseHandle( readpipe ); CloseHandle( writepipe ); CloseHandle( thread ); - if (!count) break; /* don't try message mode if byte mode already doesn't work */ } GetTempPathA( MAX_PATH, path ); diff --git a/dlls/ntdll/file.c b/dlls/ntdll/file.c index 0381e558ff6..f0f53610854 100644 --- a/dlls/ntdll/file.c +++ b/dlls/ntdll/file.c @@ -438,7 +438,7 @@ static NTSTATUS irp_completion( void *user, IO_STATUS_BLOCK *io, NTSTATUS status { req->user_arg = wine_server_client_ptr( async ); wine_server_set_reply( req, async->buffer, async->size ); - status = wine_server_call( req ); + status = virtual_locked_server_call( req ); information = reply->size; } SERVER_END_REQ; @@ -577,7 +577,7 @@ static NTSTATUS server_read_file( HANDLE handle, HANDLE event, PIO_APC_ROUTINE a req->async = server_async( handle, &async->io, event, apc, apc_context, io ); req->pos = offset ? offset->QuadPart : 0; wine_server_set_reply( req, buffer, size ); - status = wine_server_call( req ); + status = virtual_locked_server_call( req ); wait_handle = wine_server_ptr_handle( reply->wait ); options = reply->options; if (wait_handle && status != STATUS_PENDING) @@ -1540,7 +1540,7 @@ static NTSTATUS server_ioctl_file( HANDLE handle, HANDLE event, if ((code & 3) != METHOD_BUFFERED) wine_server_add_data( req, out_buffer, out_size ); wine_server_set_reply( req, out_buffer, out_size ); - status = wine_server_call( req ); + status = virtual_locked_server_call( req ); wait_handle = wine_server_ptr_handle( reply->wait ); options = reply->options; if (wait_handle && status != STATUS_PENDING) diff --git a/dlls/ntdll/ntdll_misc.h b/dlls/ntdll/ntdll_misc.h index 8dae676d079..aad45464a9d 100644 --- a/dlls/ntdll/ntdll_misc.h +++ b/dlls/ntdll/ntdll_misc.h @@ -88,6 +88,7 @@ extern void DECLSPEC_NORETURN abort_thread( int status ) DECLSPEC_HIDDEN; extern void DECLSPEC_NORETURN terminate_thread( int status ) DECLSPEC_HIDDEN; extern void DECLSPEC_NORETURN exit_thread( int status ) DECLSPEC_HIDDEN; extern sigset_t server_block_set DECLSPEC_HIDDEN; +extern unsigned int server_call_unlocked( void *req_ptr ) DECLSPEC_HIDDEN; extern void server_enter_uninterrupted_section( RTL_CRITICAL_SECTION *cs, sigset_t *sigset ) DECLSPEC_HIDDEN; extern void server_leave_uninterrupted_section( RTL_CRITICAL_SECTION *cs, sigset_t *sigset ) DECLSPEC_HIDDEN; extern unsigned int server_select( const select_op_t *select_op, data_size_t size, @@ -169,6 +170,7 @@ extern void virtual_clear_thread_stack(void) DECLSPEC_HIDDEN; extern BOOL virtual_handle_stack_fault( void *addr ) DECLSPEC_HIDDEN; extern BOOL virtual_is_valid_code_address( const void *addr, SIZE_T size ) DECLSPEC_HIDDEN; extern NTSTATUS virtual_handle_fault( LPCVOID addr, DWORD err, BOOL on_signal_stack ) DECLSPEC_HIDDEN; +extern unsigned int virtual_locked_server_call( void *req_ptr ) DECLSPEC_HIDDEN; extern BOOL virtual_check_buffer_for_read( const void *ptr, SIZE_T size ) DECLSPEC_HIDDEN; extern BOOL virtual_check_buffer_for_write( void *ptr, SIZE_T size ) DECLSPEC_HIDDEN; extern SIZE_T virtual_uninterrupted_read_memory( const void *addr, void *buffer, SIZE_T size ) DECLSPEC_HIDDEN; diff --git a/dlls/ntdll/server.c b/dlls/ntdll/server.c index c3b878e35ad..1e84fbf418e 100644 --- a/dlls/ntdll/server.c +++ b/dlls/ntdll/server.c @@ -277,6 +277,19 @@ static inline unsigned int wait_reply( struct __server_request_info *req ) } +/*********************************************************************** + * server_call_unlocked + */ +unsigned int server_call_unlocked( void *req_ptr ) +{ + struct __server_request_info * const req = req_ptr; + unsigned int ret; + + if ((ret = send_request( req ))) return ret; + return wait_reply( req ); +} + + /*********************************************************************** * wine_server_call (NTDLL.@) * @@ -301,13 +314,11 @@ static inline unsigned int wait_reply( struct __server_request_info *req ) */ unsigned int wine_server_call( void *req_ptr ) { - struct __server_request_info * const req = req_ptr; sigset_t old_set; unsigned int ret; pthread_sigmask( SIG_BLOCK, &server_block_set, &old_set ); - ret = send_request( req ); - if (!ret) ret = wait_reply( req ); + ret = server_call_unlocked( req_ptr ); pthread_sigmask( SIG_SETMASK, &old_set, NULL ); return ret; } diff --git a/dlls/ntdll/virtual.c b/dlls/ntdll/virtual.c index d36dd7d8b05..68a61b55c7e 100644 --- a/dlls/ntdll/virtual.c +++ b/dlls/ntdll/virtual.c @@ -1838,6 +1838,31 @@ static NTSTATUS check_write_access( void *base, size_t size, BOOL *has_write_wat } +/*********************************************************************** + * virtual_locked_server_call + */ +unsigned int virtual_locked_server_call( void *req_ptr ) +{ + struct __server_request_info * const req = req_ptr; + sigset_t sigset; + void *addr = req->reply_data; + data_size_t size = req->u.req.request_header.reply_size; + BOOL has_write_watch = FALSE; + unsigned int ret = STATUS_ACCESS_VIOLATION; + + if (!size) return wine_server_call( req_ptr ); + + server_enter_uninterrupted_section( &csVirtual, &sigset ); + if (!(ret = check_write_access( addr, size, &has_write_watch ))) + { + ret = server_call_unlocked( req ); + if (has_write_watch) update_write_watches( addr, size, wine_server_reply_size( req )); + } + server_leave_uninterrupted_section( &csVirtual, &sigset ); + return ret; +} + + /*********************************************************************** * virtual_is_valid_code_address