From 34fea20cd317418442cef6eb5d58e4283fac5197 Mon Sep 17 00:00:00 2001 From: Paul Gofman Date: Mon, 20 Sep 2021 19:17:20 +0300 Subject: [PATCH] winhttp: Limit recursion for synchronous callback calls. Fixes a regression in Hitman 2, Death Stranding introduced by commit be5acd1c07e093c3b4fe079bff3db74f300ea83b. Signed-off-by: Paul Gofman Signed-off-by: Hans Leidekker Signed-off-by: Alexandre Julliard --- dlls/winhttp/request.c | 11 ++- dlls/winhttp/session.c | 4 +- dlls/winhttp/tests/notification.c | 156 ++++++++++++++++++++++++++++++ dlls/winhttp/winhttp_private.h | 1 + 4 files changed, 167 insertions(+), 5 deletions(-) diff --git a/dlls/winhttp/request.c b/dlls/winhttp/request.c index f2cb24d4486..ed3b41a3155 100644 --- a/dlls/winhttp/request.c +++ b/dlls/winhttp/request.c @@ -2825,6 +2825,11 @@ static DWORD query_data_ready( struct request *request ) return count; } +static BOOL skip_async_queue( struct request *request ) +{ + return request->hdr.recursion_count < 3 && (end_of_read_data( request ) || query_data_ready( request )); +} + static DWORD query_data_available( struct request *request, DWORD *available, BOOL async ) { DWORD ret = ERROR_SUCCESS, count = 0; @@ -2889,8 +2894,7 @@ BOOL WINAPI WinHttpQueryDataAvailable( HINTERNET hrequest, LPDWORD available ) return FALSE; } - if ((async = request->connect->hdr.flags & WINHTTP_FLAG_ASYNC) && !end_of_read_data( request ) - && !query_data_ready( request )) + if ((async = request->connect->hdr.flags & WINHTTP_FLAG_ASYNC) && !skip_async_queue( request )) { struct query_data *q; @@ -2947,8 +2951,7 @@ BOOL WINAPI WinHttpReadData( HINTERNET hrequest, LPVOID buffer, DWORD to_read, L return FALSE; } - if ((async = request->connect->hdr.flags & WINHTTP_FLAG_ASYNC) && !end_of_read_data( request ) - && !query_data_ready( request )) + if ((async = request->connect->hdr.flags & WINHTTP_FLAG_ASYNC) && !skip_async_queue( request )) { struct read_data *r; diff --git a/dlls/winhttp/session.c b/dlls/winhttp/session.c index 659a1ec8cee..3c8cb0fa992 100644 --- a/dlls/winhttp/session.c +++ b/dlls/winhttp/session.c @@ -48,8 +48,10 @@ void send_callback( struct object_header *hdr, DWORD status, void *info, DWORD b { if (hdr->callback && (hdr->notify_mask & status)) { - TRACE("%p, 0x%08x, %p, %u\n", hdr, status, info, buflen); + TRACE("%p, 0x%08x, %p, %u, %u\n", hdr, status, info, buflen, hdr->recursion_count); + InterlockedIncrement( &hdr->recursion_count ); hdr->callback( hdr->handle, hdr->context, status, info, buflen ); + InterlockedDecrement( &hdr->recursion_count ); TRACE("returning from 0x%08x callback\n", status); } } diff --git a/dlls/winhttp/tests/notification.c b/dlls/winhttp/tests/notification.c index 4cfffe6687e..e6e7e0b21e7 100644 --- a/dlls/winhttp/tests/notification.c +++ b/dlls/winhttp/tests/notification.c @@ -1212,6 +1212,161 @@ static void test_persistent_connection(int port) CloseHandle( info.wait ); } +struct test_recursion_context +{ + HANDLE request; + HANDLE wait; + LONG recursion_count, max_recursion_query, max_recursion_read; + BOOL read_from_callback; + BOOL have_sync_callback; +}; + +/* The limit is 128 before Win7 and 3 on newer Windows. */ +#define TEST_RECURSION_LIMIT 128 + +static void CALLBACK test_recursion_callback( HINTERNET handle, DWORD_PTR context_ptr, + DWORD status, void *buffer, DWORD buflen ) +{ + struct test_recursion_context *context = (struct test_recursion_context *)context_ptr; + DWORD err; + BOOL ret; + BYTE b; + + switch (status) + { + case WINHTTP_CALLBACK_STATUS_SENDREQUEST_COMPLETE: + case WINHTTP_CALLBACK_STATUS_HEADERS_AVAILABLE: + SetEvent( context->wait ); + break; + + case WINHTTP_CALLBACK_STATUS_DATA_AVAILABLE: + if (!context->read_from_callback) + { + SetEvent( context->wait ); + break; + } + + if (!*(DWORD *)buffer) + { + SetEvent( context->wait ); + break; + } + + ok(context->recursion_count < TEST_RECURSION_LIMIT, + "Got unexpected context->recursion_count %u, thread %#x.\n", + context->recursion_count, GetCurrentThreadId()); + context->max_recursion_query = max( context->max_recursion_query, context->recursion_count ); + InterlockedIncrement( &context->recursion_count ); + ret = WinHttpReadData( context->request, &b, 1, NULL ); + err = GetLastError(); + ok(ret, "Failed to read data, GetLastError() %u.\n", err); + ok(err == ERROR_SUCCESS || err == ERROR_IO_PENDING, "Got unexpected err %u.\n", err); + if (err == ERROR_SUCCESS) + context->have_sync_callback = TRUE; + InterlockedDecrement( &context->recursion_count ); + break; + + case WINHTTP_CALLBACK_STATUS_READ_COMPLETE: + if (!buflen) + { + SetEvent( context->wait ); + break; + } + ok(context->recursion_count < TEST_RECURSION_LIMIT, + "Got unexpected context->recursion_count %u, thread %#x.\n", + context->recursion_count, GetCurrentThreadId()); + context->max_recursion_read = max( context->max_recursion_read, context->recursion_count ); + context->read_from_callback = TRUE; + InterlockedIncrement( &context->recursion_count ); + ret = WinHttpQueryDataAvailable( context->request, NULL ); + err = GetLastError(); + ok(ret, "Failed to query data available, GetLastError() %u.\n", err); + ok(err == ERROR_SUCCESS || err == ERROR_IO_PENDING, "Got unexpected err %u.\n", err); + if (err == ERROR_SUCCESS) + context->have_sync_callback = TRUE; + InterlockedDecrement( &context->recursion_count ); + break; + } +} + +static void test_recursion(void) +{ + struct test_recursion_context context; + HANDLE session, connection, request; + DWORD size, status, err; + BOOL ret; + BYTE b; + + memset( &context, 0, sizeof(context) ); + + context.wait = CreateEventW( NULL, FALSE, FALSE, NULL ); + + session = WinHttpOpen( L"winetest", 0, NULL, NULL, WINHTTP_FLAG_ASYNC ); + ok(!!session, "Failed to open session, GetLastError() %u.\n", GetLastError()); + + WinHttpSetStatusCallback( session, test_recursion_callback, WINHTTP_CALLBACK_FLAG_ALL_NOTIFICATIONS, 0 ); + + connection = WinHttpConnect( session, L"test.winehq.org", 0, 0 ); + ok(!!connection, "Failed to open a connection, GetLastError() %u.\n", GetLastError()); + + request = WinHttpOpenRequest( connection, NULL, L"/tests/hello.html", NULL, NULL, NULL, 0 ); + ok(!!request, "Failed to open a request, GetLastError() %u.\n", GetLastError()); + + context.request = request; + ret = WinHttpSendRequest( request, NULL, 0, NULL, 0, 0, (DWORD_PTR)&context ); + err = GetLastError(); + if (!ret && (err == ERROR_WINHTTP_CANNOT_CONNECT || err == ERROR_WINHTTP_TIMEOUT)) + { + skip("Connection failed, skipping\n"); + WinHttpSetStatusCallback( session, NULL, WINHTTP_CALLBACK_FLAG_ALL_NOTIFICATIONS, 0 ); + WinHttpCloseHandle( request ); + WinHttpCloseHandle( connection ); + WinHttpCloseHandle( session ); + CloseHandle( context.wait ); + return; + } + ok(ret, "Failed to send request, GetLastError() %u.\n", GetLastError()); + + WaitForSingleObject( context.wait, INFINITE ); + + ret = WinHttpReceiveResponse( request, NULL ); + ok(ret, "Failed to receive response, GetLastError() %u.\n", GetLastError()); + + WaitForSingleObject( context.wait, INFINITE ); + + size = sizeof(status); + ret = WinHttpQueryHeaders( request, WINHTTP_QUERY_STATUS_CODE | WINHTTP_QUERY_FLAG_NUMBER, NULL, + &status, &size, NULL ); + ok(ret, "Request failed, GetLastError() %u.\n", GetLastError()); + ok(status == 200, "Request failed unexpectedly, status %u.\n", status); + + ret = WinHttpQueryDataAvailable( request, NULL ); + ok(ret, "Failed to query data available, GetLastError() %u.\n", GetLastError()); + + WaitForSingleObject( context.wait, INFINITE ); + + ret = WinHttpReadData( request, &b, 1, NULL ); + ok(ret, "Failed to read data, GetLastError() %u.\n", GetLastError()); + + WaitForSingleObject( context.wait, INFINITE ); + if (context.have_sync_callback) + { + ok(context.max_recursion_query >= 2, "Got unexpected max_recursion_query %u.\n", context.max_recursion_query); + ok(context.max_recursion_read >= 2, "Got unexpected max_recursion_read %u.\n", context.max_recursion_read); + } + else + { + skip("No sync callbacks.\n"); + } + + WinHttpSetStatusCallback( session, NULL, WINHTTP_CALLBACK_FLAG_ALL_NOTIFICATIONS, 0 ); + + WinHttpCloseHandle( request ); + WinHttpCloseHandle( connection ); + WinHttpCloseHandle( session ); + CloseHandle( context.wait ); +} + START_TEST (notification) { HMODULE mod = GetModuleHandleA( "winhttp.dll" ); @@ -1230,6 +1385,7 @@ START_TEST (notification) test_redirect(); test_async(); test_websocket(); + test_recursion(); si.event = CreateEventW( NULL, 0, 0, NULL ); si.port = 7533; diff --git a/dlls/winhttp/winhttp_private.h b/dlls/winhttp/winhttp_private.h index 069951d4811..291a38e7bdd 100644 --- a/dlls/winhttp/winhttp_private.h +++ b/dlls/winhttp/winhttp_private.h @@ -49,6 +49,7 @@ struct object_header LONG refs; WINHTTP_STATUS_CALLBACK callback; DWORD notify_mask; + LONG recursion_count; struct list entry; };