webservices: Use stream I/O for UDP and session-less TCP channel bindings.

Signed-off-by: Hans Leidekker <hans@codeweavers.com>
Signed-off-by: Alexandre Julliard <julliard@winehq.org>
This commit is contained in:
Hans Leidekker 2018-12-04 13:57:29 +01:00 committed by Alexandre Julliard
parent 3e326c9a19
commit 63b8da2ba2
1 changed files with 95 additions and 72 deletions

View File

@ -1144,33 +1144,9 @@ done:
return hr;
}
static void set_blocking( SOCKET socket, BOOL blocking )
{
ULONG state = !blocking;
ioctlsocket( socket, FIONBIO, &state );
}
static int sock_recv( SOCKET socket, char *buf, int len )
{
int count, ret;
if ((ret = recv( socket, buf, len, 0 )) <= 0) return ret;
len -= ret;
set_blocking( socket, FALSE );
for (;;)
{
if ((count = recv( socket, buf + ret, len, 0 )) <= 0) break;
ret += count;
len -= count;
}
set_blocking( socket, TRUE );
return ret;
}
static HRESULT receive_bytes( struct channel *channel, unsigned char *bytes, int len )
{
int count = sock_recv( channel->u.tcp.socket, (char *)bytes, len );
int count = recv( channel->u.tcp.socket, (char *)bytes, len, 0 );
if (count < 0) return HRESULT_FROM_WIN32( WSAGetLastError() );
if (count != len) return WS_E_INVALID_FORMAT;
return S_OK;
@ -1214,8 +1190,6 @@ static HRESULT send_message( struct channel *channel, WS_MESSAGE *msg )
HRESULT hr;
channel->msg = msg;
if ((hr = connect_channel( channel )) != S_OK) return hr;
WsGetMessageProperty( channel->msg, WS_MESSAGE_PROPERTY_BODY_WRITER, &writer, sizeof(writer), NULL );
WsGetWriterProperty( writer, WS_XML_WRITER_PROPERTY_BYTES, &buf, sizeof(buf), NULL );
@ -1232,7 +1206,7 @@ static HRESULT send_message( struct channel *channel, WS_MESSAGE *msg )
return send_message_http( channel->u.http.request, buf.bytes, buf.length );
case WS_TCP_CHANNEL_BINDING:
if (channel->encoding == WS_ENCODING_XML_BINARY_SESSION_1)
if (channel->type & WS_CHANNEL_TYPE_SESSION)
{
switch (channel->session_state)
{
@ -1249,10 +1223,10 @@ static HRESULT send_message( struct channel *channel, WS_MESSAGE *msg )
return WS_E_OTHER;
}
}
return send_bytes( channel->u.tcp.socket, buf.bytes, buf.length );
/* fall through */
case WS_UDP_CHANNEL_BINDING:
return send_bytes( channel->u.udp.socket, buf.bytes, buf.length );
return WsFlushWriter( writer, 0, NULL, NULL );
default:
ERR( "unhandled binding %u\n", channel->binding );
@ -1273,7 +1247,7 @@ HRESULT channel_send_message( WS_CHANNEL *handle, WS_MESSAGE *msg )
return E_INVALIDARG;
}
hr = send_message( channel, msg );
if ((hr = connect_channel( channel )) == S_OK) hr = send_message( channel, msg );
LeaveCriticalSection( &channel->cs );
return hr;
@ -1305,11 +1279,25 @@ static HRESULT CALLBACK dict_cb( void *state, const WS_XML_STRING *str, BOOL *fo
return hr;
}
static CALLBACK HRESULT write_callback( void *state, const WS_BYTES *buf, ULONG count,
const WS_ASYNC_CONTEXT *ctx, WS_ERROR *error )
{
SOCKET socket = *(SOCKET *)state;
if (send( socket, (const char *)buf->bytes, buf->length, 0 ) < 0)
{
TRACE( "send failed %u\n", WSAGetLastError() );
}
return S_OK;
}
static HRESULT init_writer( struct channel *channel )
{
WS_XML_WRITER_BUFFER_OUTPUT buf = {{WS_XML_WRITER_OUTPUT_TYPE_BUFFER}};
WS_XML_WRITER_STREAM_OUTPUT stream = {{WS_XML_WRITER_OUTPUT_TYPE_STREAM}};
WS_XML_WRITER_TEXT_ENCODING text = {{WS_XML_WRITER_ENCODING_TYPE_TEXT}, WS_CHARSET_UTF8};
WS_XML_WRITER_BINARY_ENCODING bin = {{WS_XML_WRITER_ENCODING_TYPE_BINARY}};
const WS_XML_WRITER_ENCODING *encoding;
const WS_XML_WRITER_OUTPUT *output;
WS_XML_WRITER_PROPERTY prop;
ULONG max_size = (1 << 17);
HRESULT hr;
@ -1322,19 +1310,33 @@ static HRESULT init_writer( struct channel *channel )
switch (channel->encoding)
{
case WS_ENCODING_XML_UTF8:
return WsSetOutput( channel->writer, &text.encoding, &buf.output, NULL, 0, NULL );
encoding = &text.encoding;
if (channel->binding == WS_UDP_CHANNEL_BINDING ||
(channel->binding == WS_TCP_CHANNEL_BINDING && !(channel->type & WS_CHANNEL_TYPE_SESSION)))
{
stream.writeCallback = write_callback;
stream.writeCallbackState = (channel->binding == WS_UDP_CHANNEL_BINDING) ?
&channel->u.udp.socket : &channel->u.tcp.socket;
output = &stream.output;
}
else output = &buf.output;
break;
case WS_ENCODING_XML_BINARY_SESSION_1:
bin.staticDictionary = (WS_XML_DICTIONARY *)&dict_builtin_static.dict;
/* fall through */
case WS_ENCODING_XML_BINARY_1:
return WsSetOutput( channel->writer, &bin.encoding, &buf.output, NULL, 0, NULL );
encoding = &bin.encoding;
output = &buf.output;
break;
default:
FIXME( "unhandled encoding %u\n", channel->encoding );
return WS_E_NOT_SUPPORTED;
}
return WsSetOutput( channel->writer, encoding, output, NULL, 0, NULL );
}
static HRESULT write_message( struct channel *channel, WS_MESSAGE *msg, const WS_ELEMENT_DESCRIPTION *desc,
@ -1378,6 +1380,7 @@ HRESULT WINAPI WsSendMessage( WS_CHANNEL *handle, WS_MESSAGE *msg, const WS_MESS
if ((hr = WsAddressMessage( msg, &channel->addr, NULL )) != S_OK) goto done;
if ((hr = message_set_action( msg, desc->action )) != S_OK) goto done;
if ((hr = connect_channel( channel )) != S_OK) goto done;
if ((hr = init_writer( channel )) != S_OK) goto done;
if ((hr = write_message( channel, msg, desc->bodyElementDescription, option, body, size )) != S_OK) goto done;
hr = send_message( channel, msg );
@ -1419,6 +1422,7 @@ HRESULT WINAPI WsSendReplyMessage( WS_CHANNEL *handle, WS_MESSAGE *msg, const WS
if ((hr = message_get_id( request, &req_id )) != S_OK) goto done;
if ((hr = message_set_request_id( msg, &req_id )) != S_OK) goto done;
if ((hr = connect_channel( channel )) != S_OK) goto done;
if ((hr = init_writer( channel )) != S_OK) goto done;
if ((hr = write_message( channel, msg, desc->bodyElementDescription, option, body, size )) != S_OK) goto done;
hr = send_message( channel, msg );
@ -1448,12 +1452,29 @@ static HRESULT resize_read_buffer( struct channel *channel, ULONG size )
return S_OK;
}
static CALLBACK HRESULT read_callback( void *state, void *buf, ULONG buflen, ULONG *retlen,
const WS_ASYNC_CONTEXT *ctx, WS_ERROR *error )
{
SOCKET socket = *(SOCKET *)state;
int ret;
if ((ret = recv( socket, buf, buflen, 0 )) >= 0) *retlen = ret;
else
{
TRACE( "recv failed %u\n", WSAGetLastError() );
*retlen = 0;
}
return S_OK;
}
static HRESULT init_reader( struct channel *channel )
{
WS_XML_READER_BUFFER_INPUT buf = {{WS_XML_READER_INPUT_TYPE_BUFFER}};
WS_XML_READER_STREAM_INPUT stream = {{WS_XML_READER_INPUT_TYPE_STREAM}};
WS_XML_READER_TEXT_ENCODING text = {{WS_XML_READER_ENCODING_TYPE_TEXT}};
WS_XML_READER_BINARY_ENCODING bin = {{WS_XML_READER_ENCODING_TYPE_BINARY}};
WS_XML_READER_ENCODING *encoding;
const WS_XML_READER_ENCODING *encoding;
const WS_XML_READER_INPUT *input;
HRESULT hr;
if (!channel->reader && (hr = WsCreateReader( NULL, 0, &channel->reader, NULL )) != S_OK) return hr;
@ -1463,6 +1484,21 @@ static HRESULT init_reader( struct channel *channel )
case WS_ENCODING_XML_UTF8:
text.charSet = WS_CHARSET_UTF8;
encoding = &text.encoding;
if (channel->binding == WS_UDP_CHANNEL_BINDING ||
(channel->binding == WS_TCP_CHANNEL_BINDING && !(channel->type & WS_CHANNEL_TYPE_SESSION)))
{
stream.readCallback = read_callback;
stream.readCallbackState = (channel->binding == WS_UDP_CHANNEL_BINDING) ?
&channel->u.udp.socket : &channel->u.tcp.socket;
input = &stream.input;
}
else
{
buf.encodedData = channel->read_buf;
buf.encodedDataSize = channel->read_size;
input = &buf.input;
}
break;
case WS_ENCODING_XML_BINARY_SESSION_1:
@ -1472,6 +1508,10 @@ static HRESULT init_reader( struct channel *channel )
case WS_ENCODING_XML_BINARY_1:
encoding = &bin.encoding;
buf.encodedData = channel->read_buf;
buf.encodedDataSize = channel->read_size;
input = &buf.input;
break;
default:
@ -1479,9 +1519,7 @@ static HRESULT init_reader( struct channel *channel )
return WS_E_NOT_SUPPORTED;
}
buf.encodedData = channel->read_buf;
buf.encodedDataSize = channel->read_size;
return WsSetInput( channel->reader, encoding, &buf.input, NULL, 0, NULL );
return WsSetInput( channel->reader, encoding, input, NULL, 0, NULL );
}
#define INITIAL_READ_BUFFER_SIZE 4096
@ -1515,26 +1553,6 @@ static HRESULT receive_message_http( struct channel *channel )
offset += bytes_read;
}
return init_reader( channel );
}
static HRESULT receive_message_unsized( struct channel *channel, SOCKET socket )
{
int bytes_read;
ULONG max_len;
HRESULT hr;
prop_get( channel->prop, channel->prop_count, WS_CHANNEL_PROPERTY_MAX_BUFFERED_MESSAGE_SIZE,
&max_len, sizeof(max_len) );
if ((hr = resize_read_buffer( channel, max_len )) != S_OK) return hr;
channel->read_size = 0;
if ((bytes_read = sock_recv( socket, channel->read_buf, max_len )) < 0)
{
return HRESULT_FROM_WIN32( WSAGetLastError() );
}
channel->read_size = bytes_read;
return S_OK;
}
@ -1549,7 +1567,7 @@ static HRESULT receive_message_sized( struct channel *channel, unsigned int size
channel->read_size = 0;
while (channel->read_size < size)
{
if ((bytes_read = sock_recv( channel->u.tcp.socket, channel->read_buf + offset, to_read )) < 0)
if ((bytes_read = recv( channel->u.tcp.socket, channel->read_buf + offset, to_read, 0 )) < 0)
{
return HRESULT_FROM_WIN32( WSAGetLastError() );
}
@ -1798,14 +1816,7 @@ static HRESULT receive_message_session( struct channel *channel )
memmove( channel->read_buf, channel->read_buf + size, channel->read_size );
}
return init_reader( channel );
}
static HRESULT receive_message_sock( struct channel *channel, SOCKET socket )
{
HRESULT hr;
if ((hr = receive_message_unsized( channel, socket )) != S_OK) return hr;
return init_reader( channel );
return S_OK;
}
static HRESULT receive_message_bytes( struct channel *channel )
@ -1819,7 +1830,7 @@ static HRESULT receive_message_bytes( struct channel *channel )
return receive_message_http( channel );
case WS_TCP_CHANNEL_BINDING:
if (channel->encoding == WS_ENCODING_XML_BINARY_SESSION_1)
if (channel->type & WS_CHANNEL_TYPE_SESSION)
{
switch (channel->session_state)
{
@ -1836,10 +1847,10 @@ static HRESULT receive_message_bytes( struct channel *channel )
return WS_E_OTHER;
}
}
return receive_message_sock( channel, channel->u.tcp.socket );
return S_OK; /* nothing to do, data is read through stream callback */
case WS_UDP_CHANNEL_BINDING:
return receive_message_sock( channel, channel->u.udp.socket );
return S_OK;
default:
ERR( "unhandled binding %u\n", channel->binding );
@ -1860,7 +1871,7 @@ HRESULT channel_receive_message( WS_CHANNEL *handle )
return E_INVALIDARG;
}
hr = receive_message_bytes( channel );
if ((hr = receive_message_bytes( channel )) == S_OK) hr = init_reader( channel );
LeaveCriticalSection( &channel->cs );
return hr;
@ -1901,6 +1912,8 @@ static HRESULT receive_message( struct channel *channel, WS_MESSAGE *msg, const
ULONG i;
if ((hr = receive_message_bytes( channel )) != S_OK) return hr;
if ((hr = init_reader( channel )) != S_OK) return hr;
for (i = 0; i < count; i++)
{
const WS_ELEMENT_DESCRIPTION *body = desc[i]->bodyElementDescription;
@ -2012,6 +2025,7 @@ static HRESULT request_reply( struct channel *channel, WS_MESSAGE *request,
if ((hr = WsAddressMessage( request, &channel->addr, NULL )) != S_OK) return hr;
if ((hr = message_set_action( request, request_desc->action )) != S_OK) return hr;
if ((hr = connect_channel( channel )) != S_OK) return hr;
if ((hr = init_writer( channel )) != S_OK) return hr;
if ((hr = write_message( channel, request, request_desc->bodyElementDescription, write_option, request_body,
request_size )) != S_OK) return hr;
@ -2141,7 +2155,8 @@ HRESULT WINAPI WsReadMessageStart( WS_CHANNEL *handle, WS_MESSAGE *msg, const WS
if ((hr = receive_message_bytes( channel )) == S_OK)
{
hr = WsReadEnvelopeStart( msg, channel->reader, NULL, NULL, NULL );
if ((hr = init_reader( channel )) == S_OK)
hr = WsReadEnvelopeStart( msg, channel->reader, NULL, NULL, NULL );
}
LeaveCriticalSection( &channel->cs );
@ -2202,6 +2217,7 @@ HRESULT WINAPI WsWriteMessageStart( WS_CHANNEL *handle, WS_MESSAGE *msg, const W
return E_INVALIDARG;
}
if ((hr = connect_channel( channel )) != S_OK) goto done;
if ((hr = init_writer( channel )) != S_OK) goto done;
if ((hr = WsAddressMessage( msg, &channel->addr, NULL )) != S_OK) goto done;
hr = WsWriteEnvelopeStart( msg, channel->writer, NULL, NULL, NULL );
@ -2235,13 +2251,20 @@ HRESULT WINAPI WsWriteMessageEnd( WS_CHANNEL *handle, WS_MESSAGE *msg, const WS_
return E_INVALIDARG;
}
if ((hr = WsWriteEnvelopeEnd( msg, NULL )) == S_OK) hr = send_message( channel, msg );
if ((hr = WsWriteEnvelopeEnd( msg, NULL )) == S_OK && (hr = connect_channel( channel ) == S_OK))
hr = send_message( channel, msg );
LeaveCriticalSection( &channel->cs );
TRACE( "returning %08x\n", hr );
return hr;
}
static void set_blocking( SOCKET socket, BOOL blocking )
{
ULONG state = !blocking;
ioctlsocket( socket, FIONBIO, &state );
}
static HRESULT sock_accept( SOCKET socket, HANDLE wait, HANDLE cancel, SOCKET *ret )
{
HANDLE handles[] = { wait, cancel };