From bb3097f01e532583c8c78c9daf64ff29f8eb255c Mon Sep 17 00:00:00 2001 From: Hans Leidekker Date: Mon, 25 Mar 2013 13:26:37 +0100 Subject: [PATCH] winhttp: Add a read-ahead buffer to allow WinHttpQueryDataAvailable to return the right values in chunked mode. This is a port of wininet commit 3d02c42b39c7346a97c41974418a6d01a29f9b81. --- dlls/winhttp/request.c | 385 ++++++++++++++++++++------------- dlls/winhttp/tests/winhttp.c | 78 +++++++ dlls/winhttp/winhttp_private.h | 5 + 3 files changed, 320 insertions(+), 148 deletions(-) diff --git a/dlls/winhttp/request.c b/dlls/winhttp/request.c index e2c373458f6..f14f22b257f 100644 --- a/dlls/winhttp/request.c +++ b/dlls/winhttp/request.c @@ -944,7 +944,7 @@ static BOOL open_connection( request_t *request ) struct sockaddr *saddr; DWORD len; - if (netconn_connected( &request->netconn )) return TRUE; + if (netconn_connected( &request->netconn )) goto done; connect = request->connect; port = connect->serverport ? connect->serverport : (request->hdr.flags & WINHTTP_FLAG_SECURE ? 443 : 80); @@ -1002,6 +1002,9 @@ static BOOL open_connection( request_t *request ) send_callback( &request->hdr, WINHTTP_CALLBACK_STATUS_CONNECTED_TO_SERVER, addressW, strlenW(addressW) + 1 ); +done: + request->read_pos = request->read_size = 0; + request->read_chunked = FALSE; heap_free( addressW ); return TRUE; } @@ -1473,6 +1476,180 @@ static BOOL handle_authorization( request_t *request, DWORD status ) return FALSE; } +/* set the request content length based on the headers */ +static DWORD set_content_length( request_t *request ) +{ + WCHAR encoding[20]; + DWORD buflen; + + buflen = sizeof(request->content_length); + if (!query_headers( request, WINHTTP_QUERY_CONTENT_LENGTH|WINHTTP_QUERY_FLAG_NUMBER, + NULL, &request->content_length, &buflen, NULL )) + request->content_length = ~0u; + + buflen = sizeof(encoding); + if (query_headers( request, WINHTTP_QUERY_TRANSFER_ENCODING, NULL, encoding, &buflen, NULL ) && + !strcmpiW( encoding, chunkedW )) + { + request->content_length = ~0u; + request->read_chunked = TRUE; + } + return request->content_length; +} + +/* read some more data into the read buffer */ +static BOOL read_more_data( request_t *request, int maxlen ) +{ + int len; + + if (request->read_size && request->read_pos) + { + /* move existing data to the start of the buffer */ + memmove( request->read_buf, request->read_buf + request->read_pos, request->read_size ); + request->read_pos = 0; + } + if (maxlen == -1) maxlen = sizeof(request->read_buf); + if (!netconn_recv( &request->netconn, request->read_buf + request->read_size, + maxlen - request->read_size, 0, &len )) return FALSE; + request->read_size += len; + return TRUE; +} + +/* remove some amount of data from the read buffer */ +static void remove_data( request_t *request, int count ) +{ + if (!(request->read_size -= count)) request->read_pos = 0; + else request->read_pos += count; +} + +static BOOL read_line( request_t *request, char *buffer, DWORD *len ) +{ + int count, bytes_read, pos = 0; + + for (;;) + { + char *eol = memchr( request->read_buf + request->read_pos, '\n', request->read_size ); + if (eol) + { + count = eol - (request->read_buf + request->read_pos); + bytes_read = count + 1; + } + else count = bytes_read = request->read_size; + + count = min( count, *len - pos ); + memcpy( buffer + pos, request->read_buf + request->read_pos, count ); + pos += count; + remove_data( request, bytes_read ); + if (eol) break; + + if (!read_more_data( request, -1 )) return FALSE; + if (!request->read_size) + { + *len = 0; + TRACE("returning empty string\n"); + return FALSE; + } + } + if (pos < *len) + { + if (pos && buffer[pos - 1] == '\r') pos--; + *len = pos + 1; + } + buffer[*len - 1] = 0; + TRACE("returning %s\n", debugstr_a(buffer)); + return TRUE; +} + +/* discard data contents until we reach end of line */ +static BOOL discard_eol( request_t *request ) +{ + do + { + char *eol = memchr( request->read_buf + request->read_pos, '\n', request->read_size ); + if (eol) + { + remove_data( request, (eol + 1) - (request->read_buf + request->read_pos) ); + break; + } + request->read_pos = request->read_size = 0; /* discard everything */ + if (!read_more_data( request, -1 )) return FALSE; + } while (request->read_size); + return TRUE; +} + +/* read the size of the next chunk */ +static BOOL start_next_chunk( request_t *request ) +{ + DWORD chunk_size = 0; + + if (!request->content_length) return TRUE; + if (request->content_length == request->content_read) + { + /* read terminator for the previous chunk */ + if (!discard_eol( request )) return FALSE; + request->content_length = ~0u; + request->content_read = 0; + } + for (;;) + { + while (request->read_size) + { + char ch = request->read_buf[request->read_pos]; + if (ch >= '0' && ch <= '9') chunk_size = chunk_size * 16 + ch - '0'; + else if (ch >= 'a' && ch <= 'f') chunk_size = chunk_size * 16 + ch - 'a' + 10; + else if (ch >= 'A' && ch <= 'F') chunk_size = chunk_size * 16 + ch - 'A' + 10; + else if (ch == ';' || ch == '\r' || ch == '\n') + { + TRACE("reading %u byte chunk\n", chunk_size); + request->content_length = chunk_size; + request->content_read = 0; + if (!discard_eol( request )) return FALSE; + return TRUE; + } + remove_data( request, 1 ); + } + if (!read_more_data( request, -1 )) return FALSE; + if (!request->read_size) + { + request->content_length = request->content_read = 0; + return TRUE; + } + } +} + +/* return the size of data available to be read immediately */ +static DWORD get_available_data( request_t *request ) +{ + if (request->read_chunked && + (request->content_length == ~0u || request->content_length == request->content_read)) + return 0; + return min( request->read_size, request->content_length - request->content_read ); +} + +/* check if we have reached the end of the data to read */ +static BOOL end_of_read_data( request_t *request ) +{ + if (request->read_chunked) return (request->content_length == 0); + if (request->content_length == ~0u) return FALSE; + return (request->content_length == request->content_read); +} + +static BOOL refill_buffer( request_t *request ) +{ + int len = sizeof(request->read_buf); + + if (request->read_chunked && + (request->content_length == ~0u || request->content_length == request->content_read)) + { + if (!start_next_chunk( request )) return FALSE; + } + if (request->content_length != ~0u) len = min( len, request->content_length - request->content_read ); + if (len <= request->read_size) return TRUE; + if (!read_more_data( request, len )) return FALSE; + if (!request->read_size) request->content_length = request->content_read = 0; + return TRUE; +} + #define MAX_REPLY_LEN 1460 #define INITIAL_HEADER_BUFFER_LEN 512 @@ -1494,7 +1671,7 @@ static BOOL read_reply( request_t *request ) do { buflen = MAX_REPLY_LEN; - if (!netconn_get_next_line( &request->netconn, buffer, &buflen )) return FALSE; + if (!read_line( request, buffer, &buflen )) return FALSE; received_len += buflen; /* first line should look like 'HTTP/1.x nnn OK' where nnn is the status code */ @@ -1545,7 +1722,7 @@ static BOOL read_reply( request_t *request ) header_t *header; buflen = MAX_REPLY_LEN; - if (!netconn_get_next_line( &request->netconn, buffer, &buflen )) goto end; + if (!read_line( request, buffer, &buflen )) goto end; received_len += buflen; if (!*buffer) break; @@ -1576,101 +1753,6 @@ end: return TRUE; } -static BOOL receive_data( request_t *request, void *buffer, DWORD size, DWORD *read, BOOL async ) -{ - DWORD to_read; - int bytes_read; - - to_read = min( size, request->content_length - request->content_read ); - if (!netconn_recv( &request->netconn, buffer, to_read, async ? 0 : MSG_WAITALL, &bytes_read )) - { - if (bytes_read != to_read) - { - ERR("not all data received %d/%d\n", bytes_read, to_read); - } - /* always return success, even if the network layer returns an error */ - *read = 0; - return TRUE; - } - request->content_read += bytes_read; - *read = bytes_read; - return TRUE; -} - -static DWORD get_chunk_size( const char *buffer ) -{ - const char *p; - DWORD size = 0; - - for (p = buffer; *p; p++) - { - if (*p >= '0' && *p <= '9') size = size * 16 + *p - '0'; - else if (*p >= 'a' && *p <= 'f') size = size * 16 + *p - 'a' + 10; - else if (*p >= 'A' && *p <= 'F') size = size * 16 + *p - 'A' + 10; - else if (*p == ';') break; - } - return size; -} - -static BOOL receive_data_chunked( request_t *request, void *buffer, DWORD size, DWORD *read, BOOL async ) -{ - char reply[MAX_REPLY_LEN], *p = buffer; - DWORD buflen, to_read, to_write = size; - int bytes_read; - - *read = 0; - for (;;) - { - if (*read == size) break; - - if (request->content_length == ~0u) /* new chunk */ - { - buflen = sizeof(reply); - if (!netconn_get_next_line( &request->netconn, reply, &buflen )) break; - - if (!(request->content_length = get_chunk_size( reply ))) - { - /* zero sized chunk marks end of transfer; read any trailing headers and return */ - read_reply( request ); - break; - } - } - to_read = min( to_write, request->content_length - request->content_read ); - - if (!netconn_recv( &request->netconn, p, to_read, async ? 0 : MSG_WAITALL, &bytes_read )) - { - if (bytes_read != to_read) - { - ERR("Not all data received %d/%d\n", bytes_read, to_read); - } - /* always return success, even if the network layer returns an error */ - *read = 0; - break; - } - if (!bytes_read) break; - - request->content_read += bytes_read; - to_write -= bytes_read; - *read += bytes_read; - p += bytes_read; - - if (request->content_read == request->content_length) /* chunk complete */ - { - request->content_read = 0; - request->content_length = ~0u; - - buflen = sizeof(reply); - if (!netconn_get_next_line( &request->netconn, reply, &buflen )) - { - ERR("Malformed chunk\n"); - *read = 0; - break; - } - } - } - return TRUE; -} - static void finished_reading( request_t *request ) { static const WCHAR closeW[] = {'c','l','o','s','e',0}; @@ -1686,31 +1768,40 @@ static void finished_reading( request_t *request ) if (!strcmpiW( connection, closeW )) close = TRUE; } else if (!strcmpW( request->version, http1_0 )) close = TRUE; - if (close) close_connection( request ); - request->content_length = ~0u; - request->content_read = 0; } -static BOOL read_data( request_t *request, void *buffer, DWORD to_read, DWORD *read, BOOL async ) +static BOOL read_data( request_t *request, void *buffer, DWORD size, DWORD *read, BOOL async ) { - static const WCHAR chunked[] = {'c','h','u','n','k','e','d',0}; + BOOL ret = TRUE; + int len, bytes_read = 0; - BOOL ret; - WCHAR encoding[20]; - DWORD num_bytes, buflen = sizeof(encoding); - - if (query_headers( request, WINHTTP_QUERY_TRANSFER_ENCODING, NULL, encoding, &buflen, NULL ) && - !strcmpiW( encoding, chunked )) + if (request->read_chunked && + (request->content_length == ~0u || request->content_length == request->content_read)) { - ret = receive_data_chunked( request, buffer, to_read, &num_bytes, async ); + if (!start_next_chunk( request )) goto done; } - else - ret = receive_data( request, buffer, to_read, &num_bytes, async ); + if (request->content_length != ~0u) size = min( size, request->content_length - request->content_read ); + if (request->read_size) + { + bytes_read = min( request->read_size, size ); + memcpy( buffer, request->read_buf + request->read_pos, bytes_read ); + remove_data( request, bytes_read ); + } + if (size > bytes_read && (!bytes_read || !async)) + { + if ((ret = netconn_recv( &request->netconn, (char *)buffer + bytes_read, size - bytes_read, + async ? 0 : MSG_WAITALL, &len ))) + bytes_read += len; + } + +done: + request->content_read += bytes_read; + TRACE( "retrieved %u bytes (%u/%u)\n", bytes_read, request->content_read, request->content_length ); if (async) { - if (ret) send_callback( &request->hdr, WINHTTP_CALLBACK_STATUS_READ_COMPLETE, buffer, num_bytes ); + if (ret) send_callback( &request->hdr, WINHTTP_CALLBACK_STATUS_READ_COMPLETE, buffer, bytes_read ); else { WINHTTP_ASYNC_RESULT result; @@ -1719,11 +1810,8 @@ static BOOL read_data( request_t *request, void *buffer, DWORD to_read, DWORD *r send_callback( &request->hdr, WINHTTP_CALLBACK_STATUS_REQUEST_ERROR, &result, sizeof(result) ); } } - if (ret) - { - if (read) *read = num_bytes; - if (!num_bytes) finished_reading( request ); - } + if (read) *read = bytes_read; + if (!bytes_read && request->content_read == request->content_length) finished_reading( request ); return ret; } @@ -1733,7 +1821,7 @@ static void drain_content( request_t *request ) DWORD bytes_read; char buffer[2048]; - if (!request->content_length) + if (request->content_length == ~0u) { finished_reading( request ); return; @@ -1837,6 +1925,8 @@ static BOOL handle_redirect( request_t *request, DWORD status ) netconn_close( &request->netconn ); if (!(ret = netconn_init( &request->netconn ))) goto end; + request->read_pos = request->read_size = 0; + request->read_chunked = FALSE; } if (!(ret = add_host_header( request, WINHTTP_ADDREQ_FLAG_REPLACE ))) goto end; if (!(ret = open_connection( request ))) goto end; @@ -1885,10 +1975,7 @@ static BOOL receive_response( request_t *request, BOOL async ) query = WINHTTP_QUERY_STATUS_CODE | WINHTTP_QUERY_FLAG_NUMBER; if (!(ret = query_headers( request, query, NULL, &status, &size, NULL ))) break; - size = sizeof(DWORD); - query = WINHTTP_QUERY_CONTENT_LENGTH | WINHTTP_QUERY_FLAG_NUMBER; - if (!query_headers( request, query, NULL, &request->content_length, &size, NULL )) - request->content_length = ~0u; + set_content_length( request ); if (!(request->hdr.disable_flags & WINHTTP_DISABLE_COOKIES)) record_cookies( request ); @@ -1978,35 +2065,33 @@ BOOL WINAPI WinHttpReceiveResponse( HINTERNET hrequest, LPVOID reserved ) return ret; } -static BOOL query_data( request_t *request, LPDWORD available, BOOL async ) +static BOOL query_data_available( request_t *request, DWORD *available, BOOL async ) { - BOOL ret; - DWORD num_bytes; + BOOL ret = TRUE; + DWORD count; - if ((ret = netconn_query_data_available( &request->netconn, &num_bytes ))) + if (!(count = get_available_data( request ))) { - if (request->content_read < request->content_length) + if (end_of_read_data( request )) { - if (!num_bytes) - { - char buffer[4096]; - size_t to_read = min( sizeof(buffer), request->content_length - request->content_read ); - - ret = netconn_recv( &request->netconn, buffer, to_read, MSG_PEEK, (int *)&num_bytes ); - if (ret && !num_bytes) WARN("expected more data to be available\n"); - } - } - else if (num_bytes) - { - WARN("extra data available %u\n", num_bytes); - ret = FALSE; + if (available) *available = 0; + return TRUE; } } - TRACE("%u bytes available\n", num_bytes); + refill_buffer( request ); + count = get_available_data( request ); + if (count == sizeof(request->read_buf)) /* check if we have even more pending in the socket */ + { + DWORD extra; + if ((ret = netconn_query_data_available( &request->netconn, &extra ))) + { + count = min( count + extra, request->content_length - request->content_read ); + } + } if (async) { - if (ret) send_callback( &request->hdr, WINHTTP_CALLBACK_STATUS_DATA_AVAILABLE, &num_bytes, sizeof(DWORD) ); + if (ret) send_callback( &request->hdr, WINHTTP_CALLBACK_STATUS_DATA_AVAILABLE, &count, sizeof(count) ); else { WINHTTP_ASYNC_RESULT result; @@ -2015,14 +2100,18 @@ static BOOL query_data( request_t *request, LPDWORD available, BOOL async ) send_callback( &request->hdr, WINHTTP_CALLBACK_STATUS_REQUEST_ERROR, &result, sizeof(result) ); } } - if (ret && available) *available = num_bytes; + if (ret) + { + TRACE("%u bytes available\n", count); + if (available) *available = count; + } return ret; } -static void task_query_data( task_header_t *task ) +static void task_query_data_available( task_header_t *task ) { query_data_t *q = (query_data_t *)task; - query_data( q->hdr.request, q->available, TRUE ); + query_data_available( q->hdr.request, q->available, TRUE ); } /*********************************************************************** @@ -2053,14 +2142,14 @@ BOOL WINAPI WinHttpQueryDataAvailable( HINTERNET hrequest, LPDWORD available ) if (!(q = heap_alloc( sizeof(query_data_t) ))) return FALSE; q->hdr.request = request; - q->hdr.proc = task_query_data; + q->hdr.proc = task_query_data_available; q->available = available; addref_object( &request->hdr ); ret = queue_task( (task_header_t *)q ); } else - ret = query_data( request, available, FALSE ); + ret = query_data_available( request, available, FALSE ); release_object( &request->hdr ); return ret; diff --git a/dlls/winhttp/tests/winhttp.c b/dlls/winhttp/tests/winhttp.c index b03e1085cbb..03de01210cd 100644 --- a/dlls/winhttp/tests/winhttp.c +++ b/dlls/winhttp/tests/winhttp.c @@ -2792,6 +2792,83 @@ static void test_WinHttpGetProxyForUrl(void) WinHttpCloseHandle( session ); } +static void test_chunked_read(void) +{ + static const WCHAR host[] = {'t','e','s','t','.','w','i','n','e','h','q','.','o','r','g',0}; + static const WCHAR verb[] = {'/','t','e','s','t','c','h','u','n','k','e','d',0}; + static const WCHAR chunked[] = {'c','h','u','n','k','e','d',0}; + WCHAR header[32]; + DWORD len; + HINTERNET ses, con = NULL, req = NULL; + BOOL ret; + + trace( "starting chunked read test\n" ); + + ses = WinHttpOpen( test_useragent, 0, NULL, NULL, 0 ); + ok( ses != NULL, "WinHttpOpen failed with error %u\n", GetLastError() ); + if (!ses) goto done; + + con = WinHttpConnect( ses, host, 0, 0 ); + ok( con != NULL, "WinHttpConnect failed with error %u\n", GetLastError() ); + if (!con) goto done; + + req = WinHttpOpenRequest( con, NULL, verb, NULL, NULL, NULL, 0 ); + ok( req != NULL, "WinHttpOpenRequest failed with error %u\n", GetLastError() ); + if (!req) goto done; + + ret = WinHttpSendRequest( req, NULL, 0, NULL, 0, 0, 0 ); + ok( ret, "WinHttpSendRequest failed with error %u\n", GetLastError() ); + + ret = WinHttpReceiveResponse( req, NULL ); + ok( ret, "WinHttpReceiveResponse failed with error %u\n", GetLastError() ); + + header[0] = 0; + len = sizeof(header); + ret = WinHttpQueryHeaders( req, WINHTTP_QUERY_TRANSFER_ENCODING, NULL, header, &len, 0 ); + ok( ret, "failed to get TRANSFER_ENCODING header (error %u)\n", GetLastError() ); + ok( !lstrcmpW( header, chunked ), "wrong transfer encoding %s\n", wine_dbgstr_w(header) ); + trace( "transfer encoding: %s\n", wine_dbgstr_w(header) ); + + header[0] = 0; + len = sizeof(header); + SetLastError( 0xdeadbeef ); + ret = WinHttpQueryHeaders( req, WINHTTP_QUERY_CONTENT_LENGTH, NULL, &header, &len, 0 ); + ok( !ret, "unexpected CONTENT_LENGTH header %s\n", wine_dbgstr_w(header) ); + ok( GetLastError() == ERROR_WINHTTP_HEADER_NOT_FOUND, "wrong error %u\n", GetLastError() ); + + trace( "entering query loop\n" ); + for (;;) + { + len = 0xdeadbeef; + ret = WinHttpQueryDataAvailable( req, &len ); + ok( ret, "WinHttpQueryDataAvailable failed with error %u\n", GetLastError() ); + if (ret) ok( len != 0xdeadbeef, "WinHttpQueryDataAvailable return wrong length\n" ); + trace( "got %u available\n", len ); + if (len) + { + DWORD bytes_read; + char *buf = HeapAlloc( GetProcessHeap(), 0, len + 1 ); + + ret = WinHttpReadData( req, buf, len, &bytes_read ); + + buf[bytes_read] = 0; + trace( "WinHttpReadData -> %d %u\n", ret, bytes_read ); + ok( len == bytes_read, "only got %u of %u available\n", bytes_read, len ); + ok( buf[bytes_read - 1] == '\n', "received partial line '%s'\n", buf ); + + HeapFree( GetProcessHeap(), 0, buf ); + if (!bytes_read) break; + } + if (!len) break; + } + trace( "done\n" ); + +done: + if (req) WinHttpCloseHandle( req ); + if (con) WinHttpCloseHandle( con ); + if (ses) WinHttpCloseHandle( ses ); +} + START_TEST (winhttp) { static const WCHAR basicW[] = {'/','b','a','s','i','c',0}; @@ -2817,6 +2894,7 @@ START_TEST (winhttp) test_WinHttpDetectAutoProxyConfigUrl(); test_WinHttpGetIEProxyConfigForCurrentUser(); test_WinHttpGetProxyForUrl(); + test_chunked_read(); si.event = CreateEvent(NULL, 0, 0, NULL); si.port = 7532; diff --git a/dlls/winhttp/winhttp_private.h b/dlls/winhttp/winhttp_private.h index 963abecdc09..b19aadfec15 100644 --- a/dlls/winhttp/winhttp_private.h +++ b/dlls/winhttp/winhttp_private.h @@ -52,6 +52,7 @@ static const WCHAR headW[] = {'H','E','A','D',0}; static const WCHAR slashW[] = {'/',0}; static const WCHAR http1_0[] = {'H','T','T','P','/','1','.','0',0}; static const WCHAR http1_1[] = {'H','T','T','P','/','1','.','1',0}; +static const WCHAR chunkedW[] = {'c','h','u','n','k','e','d',0}; typedef struct _object_header_t object_header_t; @@ -163,6 +164,10 @@ typedef struct LPWSTR status_text; DWORD content_length; /* total number of bytes to be read (per chunk) */ DWORD content_read; /* bytes read so far */ + BOOL read_chunked; /* are we reading in chunked mode? */ + DWORD read_pos; /* current read position in read_buf */ + DWORD read_size; /* valid data size in read_buf */ + char read_buf[4096]; /* buffer for already read but not returned data */ header_t *headers; DWORD num_headers; WCHAR **accept_types;