From 51d348c8c274b6930d9a5007316beda9062204d2 Mon Sep 17 00:00:00 2001 From: Alexandre Julliard Date: Fri, 22 Sep 2017 15:04:34 +0200 Subject: [PATCH] ntdll: Add helpers to lock the user output buffer during file reads. Signed-off-by: Alexandre Julliard --- dlls/kernel32/tests/virtual.c | 12 --------- dlls/ntdll/file.c | 6 ++--- dlls/ntdll/ntdll_misc.h | 2 ++ dlls/ntdll/virtual.c | 50 +++++++++++++++++++++++++++++++++++ 4 files changed, 55 insertions(+), 15 deletions(-) diff --git a/dlls/kernel32/tests/virtual.c b/dlls/kernel32/tests/virtual.c index ec2489f97a1..d7d5124d625 100644 --- a/dlls/kernel32/tests/virtual.c +++ b/dlls/kernel32/tests/virtual.c @@ -1840,13 +1840,10 @@ static DWORD CALLBACK read_pipe( void *arg ) "%u: ConnectNamedPipe failed %u\n", args->index, GetLastError() ); success = ReadFile( args->pipe, args->base, args->size, &num_bytes, NULL ); - todo_wine_if (!args->index) - { ok( success, "%u: ReadFile failed %u\n", args->index, GetLastError() ); ok( num_bytes == sizeof(testdata), "%u: wrong number of bytes read %u\n", args->index, num_bytes ); ok( !memcmp( args->base, testdata, sizeof(testdata)), "%u: didn't receive expected data\n", args->index ); - } return 0; } @@ -2056,22 +2053,16 @@ static void test_write_watch(void) num_bytes = 0; success = GetOverlappedResult( readpipe, &overlapped, &num_bytes, TRUE ); - todo_wine_if (!i) - { ok( success, "%u: GetOverlappedResult failed %u\n", i, GetLastError() ); ok( num_bytes == sizeof(testdata), "%u: wrong number of bytes read %u\n", i, num_bytes ); ok( !memcmp( base, testdata, sizeof(testdata)), "%u: didn't receive expected data\n", i ); - } count = 64; memset( results, 0, sizeof(results) ); ret = pGetWriteWatch( WRITE_WATCH_FLAG_RESET, base, size, results, &count, &pagesize ); ok( !ret, "%u: GetWriteWatch failed %u\n", i, GetLastError() ); - todo_wine_if (!i) - { ok( count == 1, "%u: wrong count %lu\n", i, count ); ok( results[0] == base, "%u: wrong result %p\n", i, results[0] ); - } CloseHandle( readpipe ); CloseHandle( writepipe ); @@ -2120,11 +2111,8 @@ static void test_write_watch(void) memset( results, 0, sizeof(results) ); ret = pGetWriteWatch( WRITE_WATCH_FLAG_RESET, base, size, results, &count, &pagesize ); ok( !ret, "%u: GetWriteWatch failed %u\n", i, GetLastError() ); - todo_wine_if (!i) - { ok( count == 1, "%u: wrong count %lu\n", i, count ); ok( results[0] == base, "%u: wrong result %p\n", i, results[0] ); - } CloseHandle( readpipe ); CloseHandle( writepipe ); diff --git a/dlls/ntdll/file.c b/dlls/ntdll/file.c index f0f53610854..28a69ff832e 100644 --- a/dlls/ntdll/file.c +++ b/dlls/ntdll/file.c @@ -517,7 +517,7 @@ static NTSTATUS FILE_AsyncReadService( void *user, IO_STATUS_BLOCK *iosb, NTSTAT &needs_close, NULL, NULL ))) break; - result = read(fd, &fileio->buffer[fileio->already], fileio->count - fileio->already); + result = virtual_locked_read(fd, &fileio->buffer[fileio->already], fileio->count-fileio->already); if (needs_close) close( fd ); if (result < 0) @@ -869,7 +869,7 @@ NTSTATUS WINAPI NtReadFile(HANDLE hFile, HANDLE hEvent, if (offset && offset->QuadPart != FILE_USE_FILE_POINTER_POSITION) { /* async I/O doesn't make sense on regular files */ - while ((result = pread( unix_handle, buffer, length, offset->QuadPart )) == -1) + while ((result = virtual_locked_pread( unix_handle, buffer, length, offset->QuadPart )) == -1) { if (errno != EINTR) { @@ -911,7 +911,7 @@ NTSTATUS WINAPI NtReadFile(HANDLE hFile, HANDLE hEvent, for (;;) { - if ((result = read( unix_handle, (char *)buffer + total, length - total )) >= 0) + if ((result = virtual_locked_read( unix_handle, (char *)buffer + total, length - total )) >= 0) { total += result; if (!result || total == length) diff --git a/dlls/ntdll/ntdll_misc.h b/dlls/ntdll/ntdll_misc.h index aad45464a9d..907bbdd2d95 100644 --- a/dlls/ntdll/ntdll_misc.h +++ b/dlls/ntdll/ntdll_misc.h @@ -171,6 +171,8 @@ extern BOOL virtual_handle_stack_fault( void *addr ) DECLSPEC_HIDDEN; extern BOOL virtual_is_valid_code_address( const void *addr, SIZE_T size ) DECLSPEC_HIDDEN; extern NTSTATUS virtual_handle_fault( LPCVOID addr, DWORD err, BOOL on_signal_stack ) DECLSPEC_HIDDEN; extern unsigned int virtual_locked_server_call( void *req_ptr ) DECLSPEC_HIDDEN; +extern ssize_t virtual_locked_read( int fd, void *addr, size_t size ) DECLSPEC_HIDDEN; +extern ssize_t virtual_locked_pread( int fd, void *addr, size_t size, off_t offset ) DECLSPEC_HIDDEN; extern BOOL virtual_check_buffer_for_read( const void *ptr, SIZE_T size ) DECLSPEC_HIDDEN; extern BOOL virtual_check_buffer_for_write( void *ptr, SIZE_T size ) DECLSPEC_HIDDEN; extern SIZE_T virtual_uninterrupted_read_memory( const void *addr, void *buffer, SIZE_T size ) DECLSPEC_HIDDEN; diff --git a/dlls/ntdll/virtual.c b/dlls/ntdll/virtual.c index 68a61b55c7e..2cdcca8a599 100644 --- a/dlls/ntdll/virtual.c +++ b/dlls/ntdll/virtual.c @@ -1863,6 +1863,56 @@ unsigned int virtual_locked_server_call( void *req_ptr ) } +/*********************************************************************** + * virtual_locked_read + */ +ssize_t virtual_locked_read( int fd, void *addr, size_t size ) +{ + sigset_t sigset; + BOOL has_write_watch = FALSE; + int err = EFAULT; + + ssize_t ret = read( fd, addr, size ); + if (ret != -1 || errno != EFAULT) return ret; + + server_enter_uninterrupted_section( &csVirtual, &sigset ); + if (!check_write_access( addr, size, &has_write_watch )) + { + ret = read( fd, addr, size ); + err = errno; + if (has_write_watch) update_write_watches( addr, size, max( 0, ret )); + } + server_leave_uninterrupted_section( &csVirtual, &sigset ); + errno = err; + return ret; +} + + +/*********************************************************************** + * virtual_locked_pread + */ +ssize_t virtual_locked_pread( int fd, void *addr, size_t size, off_t offset ) +{ + sigset_t sigset; + BOOL has_write_watch = FALSE; + int err = EFAULT; + + ssize_t ret = pread( fd, addr, size, offset ); + if (ret != -1 || errno != EFAULT) return ret; + + server_enter_uninterrupted_section( &csVirtual, &sigset ); + if (!check_write_access( addr, size, &has_write_watch )) + { + ret = pread( fd, addr, size, offset ); + err = errno; + if (has_write_watch) update_write_watches( addr, size, max( 0, ret )); + } + server_leave_uninterrupted_section( &csVirtual, &sigset ); + errno = err; + return ret; +} + + /*********************************************************************** * virtual_is_valid_code_address