diff --git a/dlls/wininet/netconnection.c b/dlls/wininet/netconnection.c index 4c979243341..fa07a9bdb11 100644 --- a/dlls/wininet/netconnection.c +++ b/dlls/wininet/netconnection.c @@ -86,6 +86,11 @@ #define RESPONSE_TIMEOUT 30 /* FROM internet.c */ +#ifdef MSG_DONTWAIT +#define WINE_MSG_DONTWAIT MSG_DONTWAIT +#else +#define WINE_MSG_DONTWAIT 0 +#endif WINE_DEFAULT_DEBUG_CHANNEL(wininet); @@ -755,37 +760,53 @@ DWORD NETCON_send(netconn_t *connection, const void *msg, size_t len, int flags, } } -static BOOL read_ssl_chunk(netconn_t *conn, void *buf, SIZE_T buf_size, SIZE_T *ret_size, BOOL *eof) +static BOOL read_ssl_chunk(netconn_t *conn, void *buf, SIZE_T buf_size, blocking_mode_t mode, SIZE_T *ret_size, BOOL *eof) { const SIZE_T ssl_buf_size = conn->ssl_sizes.cbHeader+conn->ssl_sizes.cbMaximumMessage+conn->ssl_sizes.cbTrailer; SecBuffer bufs[4]; SecBufferDesc buf_desc = {SECBUFFER_VERSION, sizeof(bufs)/sizeof(*bufs), bufs}; - SSIZE_T size, buf_len; + SSIZE_T size, buf_len = 0; + blocking_mode_t tmp_mode; int i; SECURITY_STATUS res; assert(conn->extra_len < ssl_buf_size); + /* BLOCKING_WAITALL is handled by caller */ + if(mode == BLOCKING_WAITALL) + mode = BLOCKING_ALLOW; + if(conn->extra_len) { memcpy(conn->ssl_buf, conn->extra_buf, conn->extra_len); buf_len = conn->extra_len; conn->extra_len = 0; heap_free(conn->extra_buf); conn->extra_buf = NULL; - }else { - buf_len = recv(conn->socket, conn->ssl_buf+conn->extra_len, ssl_buf_size-conn->extra_len, 0); - if(buf_len < 0) { - WARN("recv failed\n"); - return FALSE; - } - - if(!buf_len) { - *eof = TRUE; - return TRUE; - } } - *ret_size = 0; + tmp_mode = buf_len ? BLOCKING_DISALLOW : mode; + set_socket_blocking(conn->socket, tmp_mode); + size = recv(conn->socket, conn->ssl_buf+buf_len, ssl_buf_size-buf_len, tmp_mode == BLOCKING_ALLOW ? 0 : WINE_MSG_DONTWAIT); + if(size < 0) { + if(!buf_len) { + if(errno == EAGAIN || errno == EWOULDBLOCK) { + TRACE("would block\n"); + return WSAEWOULDBLOCK; + } + WARN("recv failed\n"); + return ERROR_INTERNET_CONNECTION_ABORTED; + } + }else { + buf_len += size; + } + + *ret_size = buf_len; + + if(!buf_len) { + *eof = TRUE; + return ERROR_SUCCESS; + } + *eof = FALSE; do { @@ -801,19 +822,34 @@ static BOOL read_ssl_chunk(netconn_t *conn, void *buf, SIZE_T buf_size, SIZE_T * case SEC_I_CONTEXT_EXPIRED: TRACE("context expired\n"); *eof = TRUE; - return TRUE; + return ERROR_SUCCESS; case SEC_E_INCOMPLETE_MESSAGE: assert(buf_len < ssl_buf_size); - size = recv(conn->socket, conn->ssl_buf+buf_len, ssl_buf_size-buf_len, 0); - if(size < 1) - return FALSE; + set_socket_blocking(conn->socket, mode); + size = recv(conn->socket, conn->ssl_buf+buf_len, ssl_buf_size-buf_len, mode == BLOCKING_ALLOW ? 0 : WINE_MSG_DONTWAIT); + if(size < 1) { + if(size < 0 && (errno == EAGAIN || errno == EWOULDBLOCK)) { + TRACE("would block\n"); + + /* FIXME: Optimize extra_buf usage. */ + conn->extra_buf = heap_alloc(buf_len); + if(!conn->extra_buf) + return ERROR_NOT_ENOUGH_MEMORY; + + conn->extra_len = buf_len; + memcpy(conn->extra_buf, conn->ssl_buf, conn->extra_len); + return WSAEWOULDBLOCK; + } + + return ERROR_INTERNET_CONNECTION_ABORTED; + } buf_len += size; continue; default: WARN("failed: %08x\n", res); - return FALSE; + return ERROR_INTERNET_CONNECTION_ABORTED; } } while(res != SEC_E_OK); @@ -825,7 +861,7 @@ static BOOL read_ssl_chunk(netconn_t *conn, void *buf, SIZE_T buf_size, SIZE_T * assert(!conn->peek_len); conn->peek_msg_mem = conn->peek_msg = heap_alloc(bufs[i].cbBuffer - size); if(!conn->peek_msg) - return FALSE; + return ERROR_NOT_ENOUGH_MEMORY; conn->peek_len = bufs[i].cbBuffer-size; memcpy(conn->peek_msg, (char*)bufs[i].pvBuffer+size, conn->peek_len); } @@ -838,14 +874,14 @@ static BOOL read_ssl_chunk(netconn_t *conn, void *buf, SIZE_T buf_size, SIZE_T * if(bufs[i].BufferType == SECBUFFER_EXTRA) { conn->extra_buf = heap_alloc(bufs[i].cbBuffer); if(!conn->extra_buf) - return FALSE; + return ERROR_NOT_ENOUGH_MEMORY; conn->extra_len = bufs[i].cbBuffer; memcpy(conn->extra_buf, bufs[i].pvBuffer, conn->extra_len); } } - return TRUE; + return ERROR_SUCCESS; } /****************************************************************************** @@ -867,9 +903,7 @@ DWORD NETCON_recv(netconn_t *connection, void *buf, size_t len, blocking_mode_t case BLOCKING_ALLOW: break; case BLOCKING_DISALLOW: -#ifdef MSG_DONTWAIT - flags = MSG_DONTWAIT; -#endif + flags = WINE_MSG_DONTWAIT; break; case BLOCKING_WAITALL: flags = MSG_WAITALL; @@ -883,7 +917,8 @@ DWORD NETCON_recv(netconn_t *connection, void *buf, size_t len, blocking_mode_t else { SIZE_T size = 0, cread; - BOOL res, eof; + BOOL eof; + DWORD res; if(connection->peek_msg) { size = min(len, connection->peek_len); @@ -900,18 +935,19 @@ DWORD NETCON_recv(netconn_t *connection, void *buf, size_t len, blocking_mode_t *recvd = size; return ERROR_SUCCESS; } + + mode = BLOCKING_DISALLOW; } - if(mode == BLOCKING_DISALLOW) - return WSAEWOULDBLOCK; /* FIXME: We can do better */ - set_socket_blocking(connection->socket, BLOCKING_ALLOW); - do { - res = read_ssl_chunk(connection, (BYTE*)buf+size, len-size, &cread, &eof); - if(!res) { - WARN("read_ssl_chunk failed\n"); - if(!size) - return ERROR_INTERNET_CONNECTION_ABORTED; + res = read_ssl_chunk(connection, (BYTE*)buf+size, len-size, mode, &cread, &eof); + if(res != ERROR_SUCCESS) { + if(res == WSAEWOULDBLOCK) { + if(size) + res = ERROR_SUCCESS; + }else { + WARN("read_ssl_chunk failed\n"); + } break; } @@ -925,7 +961,7 @@ DWORD NETCON_recv(netconn_t *connection, void *buf, size_t len, blocking_mode_t TRACE("received %ld bytes\n", size); *recvd = size; - return ERROR_SUCCESS; + return res; } }