rpcrt4: Move the receiving of an individual fragment to a separate function.

This commit is contained in:
Rob Shearman 2008-01-21 10:31:55 +00:00 committed by Alexandre Julliard
parent 5f077bab07
commit 2b0d3b7400
1 changed files with 131 additions and 57 deletions

View File

@ -683,29 +683,51 @@ RPC_STATUS RPCRT4_Send(RpcConnection *Connection, RpcPktHdr *Header,
return r; return r;
} }
/* validates version and frag_len fields */
RPC_STATUS RPCRT4_ValidateCommonHeader(const RpcPktCommonHdr *hdr)
{
DWORD hdr_length;
/* verify if the header really makes sense */
if (hdr->rpc_ver != RPC_VER_MAJOR ||
hdr->rpc_ver_minor != RPC_VER_MINOR)
{
WARN("unhandled packet version\n");
return RPC_S_PROTOCOL_ERROR;
}
hdr_length = RPCRT4_GetHeaderSize((const RpcPktHdr*)hdr);
if (hdr_length == 0)
{
WARN("header length == 0\n");
return RPC_S_PROTOCOL_ERROR;
}
if (hdr->frag_len < hdr_length)
{
WARN("bad frag length %d\n", hdr->frag_len);
return RPC_S_PROTOCOL_ERROR;
}
return RPC_S_OK;
}
/*********************************************************************** /***********************************************************************
* RPCRT4_Receive (internal) * RPCRT4_receive_fragment (internal)
* *
* Receive a packet from connection and merge the fragments. * Receive a fragment from a connection.
*/ */
RPC_STATUS RPCRT4_Receive(RpcConnection *Connection, RpcPktHdr **Header, RPC_STATUS RPCRT4_receive_fragment(RpcConnection *Connection, RpcPktHdr **Header, void **Payload)
PRPC_MESSAGE pMsg)
{ {
RPC_STATUS status; RPC_STATUS status;
DWORD hdr_length; DWORD hdr_length;
LONG dwRead; LONG dwRead;
unsigned short first_flag;
unsigned long data_length;
unsigned long buffer_length;
unsigned long auth_length;
unsigned char *auth_data = NULL;
RpcPktCommonHdr common_hdr; RpcPktCommonHdr common_hdr;
*Header = NULL; *Header = NULL;
*Payload = NULL;
TRACE("(%p, %p, %p)\n", Connection, Header, pMsg); TRACE("(%p, %p, %p)\n", Connection, Header, Payload);
RPCRT4_SetThreadCurrentConnection(Connection);
/* read packet common header */ /* read packet common header */
dwRead = rpcrt4_conn_read(Connection, &common_hdr, sizeof(common_hdr)); dwRead = rpcrt4_conn_read(Connection, &common_hdr, sizeof(common_hdr));
@ -715,13 +737,8 @@ RPC_STATUS RPCRT4_Receive(RpcConnection *Connection, RpcPktHdr **Header,
goto fail; goto fail;
} }
/* verify if the header really makes sense */ status = RPCRT4_ValidateCommonHeader(&common_hdr);
if (common_hdr.rpc_ver != RPC_VER_MAJOR || if (status != RPC_S_OK) goto fail;
common_hdr.rpc_ver_minor != RPC_VER_MINOR) {
WARN("unhandled packet version\n");
status = RPC_S_PROTOCOL_ERROR;
goto fail;
}
hdr_length = RPCRT4_GetHeaderSize((RpcPktHdr*)&common_hdr); hdr_length = RPCRT4_GetHeaderSize((RpcPktHdr*)&common_hdr);
if (hdr_length == 0) { if (hdr_length == 0) {
@ -741,8 +758,70 @@ RPC_STATUS RPCRT4_Receive(RpcConnection *Connection, RpcPktHdr **Header,
goto fail; goto fail;
} }
if (common_hdr.frag_len - hdr_length)
{
*Payload = HeapAlloc(GetProcessHeap(), 0, common_hdr.frag_len - hdr_length);
if (!*Payload)
{
status = RPC_S_OUT_OF_RESOURCES;
goto fail;
}
dwRead = rpcrt4_conn_read(Connection, *Payload, common_hdr.frag_len - hdr_length);
if (dwRead != common_hdr.frag_len - hdr_length)
{
WARN("bad data length, %d/%d\n", dwRead, common_hdr.frag_len - hdr_length);
status = RPC_S_CALL_FAILED;
goto fail;
}
}
else
*Payload = NULL;
/* success */
status = RPC_S_OK;
fail:
if (status != RPC_S_OK) {
RPCRT4_FreeHeader(*Header);
*Header = NULL;
HeapFree(GetProcessHeap(), 0, *Payload);
*Payload = NULL;
}
return status;
}
/***********************************************************************
* RPCRT4_Receive (internal)
*
* Receive a packet from connection and merge the fragments.
*/
RPC_STATUS RPCRT4_Receive(RpcConnection *Connection, RpcPktHdr **Header,
PRPC_MESSAGE pMsg)
{
RPC_STATUS status;
DWORD hdr_length;
unsigned short first_flag;
unsigned long data_length;
unsigned long buffer_length;
unsigned long auth_length;
unsigned char *auth_data = NULL;
RpcPktHdr *CurrentHeader;
void *payload = NULL;
*Header = NULL;
TRACE("(%p, %p, %p)\n", Connection, Header, pMsg);
RPCRT4_SetThreadCurrentConnection(Connection);
status = RPCRT4_receive_fragment(Connection, Header, &payload);
if (status != RPC_S_OK) goto fail;
hdr_length = RPCRT4_GetHeaderSize(*Header);
/* read packet body */ /* read packet body */
switch (common_hdr.ptype) { switch ((*Header)->common.ptype) {
case PKT_RESPONSE: case PKT_RESPONSE:
pMsg->BufferLength = (*Header)->response.alloc_hint; pMsg->BufferLength = (*Header)->response.alloc_hint;
break; break;
@ -750,7 +829,7 @@ RPC_STATUS RPCRT4_Receive(RpcConnection *Connection, RpcPktHdr **Header,
pMsg->BufferLength = (*Header)->request.alloc_hint; pMsg->BufferLength = (*Header)->request.alloc_hint;
break; break;
default: default:
pMsg->BufferLength = common_hdr.frag_len - hdr_length - RPC_AUTH_VERIFIER_LEN(&common_hdr); pMsg->BufferLength = (*Header)->common.frag_len - hdr_length - RPC_AUTH_VERIFIER_LEN(&(*Header)->common);
} }
TRACE("buffer length = %u\n", pMsg->BufferLength); TRACE("buffer length = %u\n", pMsg->BufferLength);
@ -763,30 +842,31 @@ RPC_STATUS RPCRT4_Receive(RpcConnection *Connection, RpcPktHdr **Header,
} }
first_flag = RPC_FLG_FIRST; first_flag = RPC_FLG_FIRST;
auth_length = common_hdr.auth_len; auth_length = (*Header)->common.auth_len;
if (auth_length) { if (auth_length) {
auth_data = HeapAlloc(GetProcessHeap(), 0, RPC_AUTH_VERIFIER_LEN(&common_hdr)); auth_data = HeapAlloc(GetProcessHeap(), 0, RPC_AUTH_VERIFIER_LEN(&(*Header)->common));
if (!auth_data) { if (!auth_data) {
status = RPC_S_OUT_OF_RESOURCES; status = RPC_S_OUT_OF_RESOURCES;
goto fail; goto fail;
} }
} }
CurrentHeader = *Header;
buffer_length = 0; buffer_length = 0;
while (TRUE) while (TRUE)
{ {
unsigned int header_auth_len = RPC_AUTH_VERIFIER_LEN(&(*Header)->common); unsigned int header_auth_len = RPC_AUTH_VERIFIER_LEN(&CurrentHeader->common);
/* verify header fields */ /* verify header fields */
if (((*Header)->common.frag_len < hdr_length) || if ((CurrentHeader->common.frag_len < hdr_length) ||
((*Header)->common.frag_len - hdr_length < header_auth_len)) { (CurrentHeader->common.frag_len - hdr_length < header_auth_len)) {
WARN("frag_len %d too small for hdr_length %d and auth_len %d\n", WARN("frag_len %d too small for hdr_length %d and auth_len %d\n",
(*Header)->common.frag_len, hdr_length, header_auth_len); CurrentHeader->common.frag_len, hdr_length, CurrentHeader->common.auth_len);
status = RPC_S_PROTOCOL_ERROR; status = RPC_S_PROTOCOL_ERROR;
goto fail; goto fail;
} }
if ((*Header)->common.auth_len != auth_length) { if ((CurrentHeader->common.flags & RPC_FLG_FIRST) != first_flag) {
WARN("auth_len header field changed from %ld to %d\n", WARN("auth_len header field changed from %ld to %d\n",
auth_length, (*Header)->common.auth_len); auth_length, (*Header)->common.auth_len);
status = RPC_S_PROTOCOL_ERROR; status = RPC_S_PROTOCOL_ERROR;
@ -799,7 +879,7 @@ RPC_STATUS RPCRT4_Receive(RpcConnection *Connection, RpcPktHdr **Header,
goto fail; goto fail;
} }
data_length = (*Header)->common.frag_len - hdr_length - header_auth_len; data_length = CurrentHeader->common.frag_len - hdr_length - header_auth_len;
if (data_length + buffer_length > pMsg->BufferLength) { if (data_length + buffer_length > pMsg->BufferLength) {
TRACE("allocation hint exceeded, new buffer length = %ld\n", TRACE("allocation hint exceeded, new buffer length = %ld\n",
data_length + buffer_length); data_length + buffer_length);
@ -808,17 +888,11 @@ RPC_STATUS RPCRT4_Receive(RpcConnection *Connection, RpcPktHdr **Header,
if (status != RPC_S_OK) goto fail; if (status != RPC_S_OK) goto fail;
} }
if (data_length == 0) dwRead = 0; else memcpy((unsigned char *)pMsg->Buffer + buffer_length, payload, data_length);
dwRead = rpcrt4_conn_read(Connection,
(unsigned char *)pMsg->Buffer + buffer_length, data_length);
if (dwRead != data_length) {
WARN("bad data length, %d/%ld\n", dwRead, data_length);
status = RPC_S_CALL_FAILED;
goto fail;
}
if (header_auth_len) { if (header_auth_len) {
if (header_auth_len < sizeof(RpcAuthVerifier)) { if (header_auth_len < sizeof(RpcAuthVerifier) ||
header_auth_len > RPC_AUTH_VERIFIER_LEN(&(*Header)->common)) {
WARN("bad auth verifier length %d\n", header_auth_len); WARN("bad auth verifier length %d\n", header_auth_len);
status = RPC_S_PROTOCOL_ERROR; status = RPC_S_PROTOCOL_ERROR;
goto fail; goto fail;
@ -829,22 +903,16 @@ RPC_STATUS RPCRT4_Receive(RpcConnection *Connection, RpcPktHdr **Header,
* however, the details of how this is done is very sketchy in the * however, the details of how this is done is very sketchy in the
* DCE/RPC spec. for all other packet types that have authentication * DCE/RPC spec. for all other packet types that have authentication
* verifier data then it is just duplicated in all the fragments */ * verifier data then it is just duplicated in all the fragments */
dwRead = rpcrt4_conn_read(Connection, auth_data, header_auth_len); memcpy(auth_data, (unsigned char *)payload + data_length, header_auth_len);
if (dwRead != header_auth_len) {
WARN("bad authentication data length, %d/%d\n", dwRead,
header_auth_len);
status = RPC_S_CALL_FAILED;
goto fail;
}
/* these packets are handled specially, not by the generic SecurePacket /* these packets are handled specially, not by the generic SecurePacket
* function */ * function */
if ((common_hdr.ptype != PKT_BIND) && if (((*Header)->common.ptype != PKT_BIND) &&
(common_hdr.ptype != PKT_BIND_ACK) && ((*Header)->common.ptype != PKT_BIND_ACK) &&
(common_hdr.ptype != PKT_AUTH3)) ((*Header)->common.ptype != PKT_AUTH3))
{ {
status = RPCRT4_SecurePacket(Connection, SECURE_PACKET_RECEIVE, status = RPCRT4_SecurePacket(Connection, SECURE_PACKET_RECEIVE,
*Header, hdr_length, CurrentHeader, hdr_length,
(unsigned char *)pMsg->Buffer + buffer_length, data_length, (unsigned char *)pMsg->Buffer + buffer_length, data_length,
(RpcAuthVerifier *)auth_data, (RpcAuthVerifier *)auth_data,
auth_data + sizeof(RpcAuthVerifier), auth_data + sizeof(RpcAuthVerifier),
@ -854,16 +922,19 @@ RPC_STATUS RPCRT4_Receive(RpcConnection *Connection, RpcPktHdr **Header,
} }
buffer_length += data_length; buffer_length += data_length;
if (!((*Header)->common.flags & RPC_FLG_LAST)) { if (!(CurrentHeader->common.flags & RPC_FLG_LAST)) {
TRACE("next header\n"); TRACE("next header\n");
/* read the header of next packet */ if (*Header != CurrentHeader)
dwRead = rpcrt4_conn_read(Connection, *Header, hdr_length); {
if (dwRead != hdr_length) { RPCRT4_FreeHeader(CurrentHeader);
WARN("invalid packet header size (%d)\n", dwRead); CurrentHeader = NULL;
status = RPC_S_CALL_FAILED;
goto fail;
} }
HeapFree(GetProcessHeap(), 0, payload);
payload = NULL;
status = RPCRT4_receive_fragment(Connection, &CurrentHeader, &payload);
if (status != RPC_S_OK) goto fail;
first_flag = 0; first_flag = 0;
} else { } else {
@ -873,7 +944,7 @@ RPC_STATUS RPCRT4_Receive(RpcConnection *Connection, RpcPktHdr **Header,
pMsg->BufferLength = buffer_length; pMsg->BufferLength = buffer_length;
/* respond to authorization request */ /* respond to authorization request */
if (common_hdr.ptype == PKT_BIND_ACK && auth_length > sizeof(RpcAuthVerifier)) if ((*Header)->common.ptype == PKT_BIND_ACK && auth_length > sizeof(RpcAuthVerifier))
{ {
status = RPCRT_AuthorizeConnection(Connection, status = RPCRT_AuthorizeConnection(Connection,
auth_data + sizeof(RpcAuthVerifier), auth_data + sizeof(RpcAuthVerifier),
@ -887,11 +958,14 @@ RPC_STATUS RPCRT4_Receive(RpcConnection *Connection, RpcPktHdr **Header,
fail: fail:
RPCRT4_SetThreadCurrentConnection(NULL); RPCRT4_SetThreadCurrentConnection(NULL);
if (CurrentHeader != *Header)
RPCRT4_FreeHeader(CurrentHeader);
if (status != RPC_S_OK) { if (status != RPC_S_OK) {
RPCRT4_FreeHeader(*Header); RPCRT4_FreeHeader(*Header);
*Header = NULL; *Header = NULL;
} }
HeapFree(GetProcessHeap(), 0, auth_data); HeapFree(GetProcessHeap(), 0, auth_data);
HeapFree(GetProcessHeap(), 0, payload);
return status; return status;
} }