diff --git a/dlls/wininet/internet.h b/dlls/wininet/internet.h index 9fbb74f2d33..0161935a966 100644 --- a/dlls/wininet/internet.h +++ b/dlls/wininet/internet.h @@ -98,6 +98,9 @@ typedef struct char *ssl_buf; char *extra_buf; size_t extra_len; + char *peek_msg; + char *peek_msg_mem; + size_t peek_len; DWORD security_flags; BOOL mask_errors; diff --git a/dlls/wininet/netconnection.c b/dlls/wininet/netconnection.c index 53308b9f03f..4abf7f195ad 100644 --- a/dlls/wininet/netconnection.c +++ b/dlls/wininet/netconnection.c @@ -759,8 +759,15 @@ void free_netconn(netconn_t *netconn) pSSL_shutdown(netconn->ssl_s); pSSL_free(netconn->ssl_s); #else + heap_free(netconn->peek_msg_mem); + netconn->peek_msg_mem = NULL; + netconn->peek_msg = NULL; + netconn->peek_len = 0; heap_free(netconn->ssl_buf); netconn->ssl_buf = NULL; + heap_free(netconn->extra_buf); + netconn->extra_buf = NULL; + netconn->extra_len = 0; DeleteSecurityContext(&netconn->ssl_ctx); #endif } @@ -1204,6 +1211,101 @@ DWORD NETCON_send(netconn_t *connection, const void *msg, size_t len, int flags, } } +#ifndef SONAME_LIBSSL +static BOOL read_ssl_chunk(netconn_t *conn, void *buf, SIZE_T buf_size, SIZE_T *ret_size, BOOL *eof) +{ + const SIZE_T ssl_buf_size = conn->ssl_sizes.cbHeader+conn->ssl_sizes.cbMaximumMessage+conn->ssl_sizes.cbTrailer; + SecBuffer bufs[4]; + SecBufferDesc buf_desc = {SECBUFFER_VERSION, sizeof(bufs)/sizeof(*bufs), bufs}; + SSIZE_T size, buf_len; + int i; + SECURITY_STATUS res; + + assert(conn->extra_len < ssl_buf_size); + + if(conn->extra_len) { + memcpy(conn->ssl_buf, conn->extra_buf, conn->extra_len); + buf_len = conn->extra_len; + conn->extra_len = 0; + heap_free(conn->extra_buf); + conn->extra_buf = NULL; + }else { + buf_len = recv(conn->socket, conn->ssl_buf+conn->extra_len, ssl_buf_size-conn->extra_len, 0); + if(buf_len < 0) { + WARN("recv failed\n"); + return FALSE; + } + + if(!buf_len) { + *eof = TRUE; + return TRUE; + } + } + + *ret_size = 0; + *eof = FALSE; + + do { + memset(bufs, 0, sizeof(bufs)); + bufs[0].BufferType = SECBUFFER_DATA; + bufs[0].cbBuffer = buf_len; + bufs[0].pvBuffer = conn->ssl_buf; + + res = DecryptMessage(&conn->ssl_ctx, &buf_desc, 0, NULL); + switch(res) { + case SEC_E_OK: + break; + case SEC_I_CONTEXT_EXPIRED: + TRACE("context expired\n"); + *eof = TRUE; + return TRUE; + case SEC_E_INCOMPLETE_MESSAGE: + assert(buf_len < ssl_buf_size); + + size = recv(conn->socket, conn->ssl_buf+buf_len, ssl_buf_size-buf_len, 0); + if(size < 1) + return FALSE; + + buf_len += size; + continue; + default: + WARN("failed: %08x\n", res); + return FALSE; + } + } while(res != SEC_E_OK); + + for(i=0; i < sizeof(bufs)/sizeof(*bufs); i++) { + if(bufs[i].BufferType == SECBUFFER_DATA) { + size = min(buf_size, bufs[i].cbBuffer); + memcpy(buf, bufs[i].pvBuffer, size); + if(size < bufs[i].cbBuffer) { + assert(!conn->peek_len); + conn->peek_msg_mem = conn->peek_msg = heap_alloc(bufs[i].cbBuffer - size); + if(!conn->peek_msg) + return FALSE; + conn->peek_len = bufs[i].cbBuffer-size; + memcpy(conn->peek_msg, (char*)bufs[i].pvBuffer+size, conn->peek_len); + } + + *ret_size = size; + } + } + + for(i=0; i < sizeof(bufs)/sizeof(*bufs); i++) { + if(bufs[i].BufferType == SECBUFFER_EXTRA) { + conn->extra_buf = heap_alloc(bufs[i].cbBuffer); + if(!conn->extra_buf) + return FALSE; + + conn->extra_len = bufs[i].cbBuffer; + memcpy(conn->extra_buf, bufs[i].pvBuffer, conn->extra_len); + } + } + + return TRUE; +} +#endif + /****************************************************************************** * NETCON_recv * Basically calls 'recv()' unless we should use SSL @@ -1236,8 +1338,46 @@ DWORD NETCON_recv(netconn_t *connection, void *buf, size_t len, int flags, int * return *recvd > 0 ? ERROR_SUCCESS : ERROR_INTERNET_CONNECTION_ABORTED; #else - FIXME("not supported on this platform\n"); - return ERROR_NOT_SUPPORTED; + SIZE_T size = 0, cread; + BOOL res, eof; + + if(connection->peek_msg) { + size = min(len, connection->peek_len); + memcpy(buf, connection->peek_msg, size); + connection->peek_len -= size; + connection->peek_msg += size; + + if(!connection->peek_len) { + heap_free(connection->peek_msg_mem); + connection->peek_msg_mem = connection->peek_msg = NULL; + } + /* check if we have enough data from the peek buffer */ + if(!(flags & MSG_WAITALL) || size == len) { + *recvd = size; + return ERROR_SUCCESS; + } + } + + do { + res = read_ssl_chunk(connection, (BYTE*)buf+size, len-size, &cread, &eof); + if(!res) { + WARN("read_ssl_chunk failed\n"); + if(!size) + return ERROR_INTERNET_CONNECTION_ABORTED; + break; + } + + if(eof) { + TRACE("EOF\n"); + break; + } + + size += cread; + }while(!size || ((flags & MSG_WAITALL) && size < len)); + + TRACE("received %ld bytes\n", size); + *recvd = size; + return ERROR_SUCCESS; #endif } }