ole32: Marshal the ORPCTHIS structure prefixed to the client data when doing ORPC calls.

This is done by putting the ORPCTHIS data into the buffer when calling
IRpcChannelBuffer::GetBuffer on the client side and then storing the
amount we increased the buffer in a structure stored in the Handle
field. This is done to present the correct Buffer pointer to the proxy
so that it writes its data after the ORPCTHIS data.

Unmarshal the data on the server side (during RPC_ExecuteCall) and make 
sure the data is consistent according to NDR rules. Also add several 
checks on the unmarshaled data that are specified by the DCOM draft 
specification.
This commit is contained in:
Rob Shearman 2006-12-19 19:35:35 +00:00 committed by Alexandre Julliard
parent e4fc45e0fe
commit 1dc5dec6e6
1 changed files with 215 additions and 2 deletions

View File

@ -113,6 +113,21 @@ struct dispatch_params
HRESULT hr; /* hresult (out) */ HRESULT hr; /* hresult (out) */
}; };
struct message_state
{
RPC_BINDING_HANDLE binding_handle;
ULONG prefix_data_len;
};
typedef struct
{
ULONG conformance; /* NDR */
GUID id;
ULONG size;
/* [size_is((size+7)&~7)] */ unsigned char data[1];
} WIRE_ORPC_EXTENT;
static HRESULT WINAPI RpcChannelBuffer_QueryInterface(LPRPCCHANNELBUFFER iface, REFIID riid, LPVOID *ppv) static HRESULT WINAPI RpcChannelBuffer_QueryInterface(LPRPCCHANNELBUFFER iface, REFIID riid, LPVOID *ppv)
{ {
*ppv = NULL; *ppv = NULL;
@ -164,11 +179,21 @@ static HRESULT WINAPI ServerRpcChannelBuffer_GetBuffer(LPRPCCHANNELBUFFER iface,
RpcChannelBuffer *This = (RpcChannelBuffer *)iface; RpcChannelBuffer *This = (RpcChannelBuffer *)iface;
RPC_MESSAGE *msg = (RPC_MESSAGE *)olemsg; RPC_MESSAGE *msg = (RPC_MESSAGE *)olemsg;
RPC_STATUS status; RPC_STATUS status;
struct message_state *message_state;
TRACE("(%p)->(%p,%s)\n", This, olemsg, debugstr_guid(riid)); TRACE("(%p)->(%p,%s)\n", This, olemsg, debugstr_guid(riid));
message_state = (struct message_state *)msg->Handle;
/* restore the binding handle and the real start of data */
msg->Handle = message_state->binding_handle;
msg->Buffer = (char *)msg->Buffer - message_state->prefix_data_len;
status = I_RpcGetBuffer(msg); status = I_RpcGetBuffer(msg);
/* save away the message state again */
msg->Handle = message_state;
message_state->prefix_data_len = 0;
TRACE("-- %ld\n", status); TRACE("-- %ld\n", status);
return HRESULT_FROM_WIN32(status); return HRESULT_FROM_WIN32(status);
@ -180,6 +205,8 @@ static HRESULT WINAPI ClientRpcChannelBuffer_GetBuffer(LPRPCCHANNELBUFFER iface,
RPC_MESSAGE *msg = (RPC_MESSAGE *)olemsg; RPC_MESSAGE *msg = (RPC_MESSAGE *)olemsg;
RPC_CLIENT_INTERFACE *cif; RPC_CLIENT_INTERFACE *cif;
RPC_STATUS status; RPC_STATUS status;
ORPCTHIS *orpcthis;
struct message_state *message_state;
TRACE("(%p)->(%p,%s)\n", This, olemsg, debugstr_guid(riid)); TRACE("(%p)->(%p,%s)\n", This, olemsg, debugstr_guid(riid));
@ -187,17 +214,51 @@ static HRESULT WINAPI ClientRpcChannelBuffer_GetBuffer(LPRPCCHANNELBUFFER iface,
if (!cif) if (!cif)
return E_OUTOFMEMORY; return E_OUTOFMEMORY;
message_state = HeapAlloc(GetProcessHeap(), 0, sizeof(*message_state));
if (!message_state)
{
HeapFree(GetProcessHeap(), 0, cif);
return E_OUTOFMEMORY;
}
cif->Length = sizeof(RPC_CLIENT_INTERFACE); cif->Length = sizeof(RPC_CLIENT_INTERFACE);
/* RPC interface ID = COM interface ID */ /* RPC interface ID = COM interface ID */
cif->InterfaceId.SyntaxGUID = *riid; cif->InterfaceId.SyntaxGUID = *riid;
/* COM objects always have a version of 0.0 */ /* COM objects always have a version of 0.0 */
cif->InterfaceId.SyntaxVersion.MajorVersion = 0; cif->InterfaceId.SyntaxVersion.MajorVersion = 0;
cif->InterfaceId.SyntaxVersion.MinorVersion = 0; cif->InterfaceId.SyntaxVersion.MinorVersion = 0;
msg->RpcInterfaceInformation = cif;
msg->Handle = This->bind; msg->Handle = This->bind;
msg->RpcInterfaceInformation = cif;
msg->BufferLength += FIELD_OFFSET(ORPCTHIS, extensions) + 4;
status = I_RpcGetBuffer(msg); status = I_RpcGetBuffer(msg);
message_state->prefix_data_len = 0;
message_state->binding_handle = This->bind;
msg->Handle = message_state;
if (status == RPC_S_OK)
{
orpcthis = (ORPCTHIS *)msg->Buffer;
msg->Buffer = (char *)msg->Buffer + FIELD_OFFSET(ORPCTHIS, extensions);
orpcthis->version.MajorVersion = COM_MAJOR_VERSION;
orpcthis->version.MinorVersion = COM_MINOR_VERSION;
orpcthis->flags = ORPCF_NULL;
orpcthis->reserved1 = 0;
orpcthis->cid = GUID_NULL; /* FIXME */
/* NDR representation of orpcthis->extensions */
*(DWORD *)msg->Buffer = 0; /* FIXME */
msg->Buffer = (char *)msg->Buffer + sizeof(DWORD);
/* store the prefixed data length so that we can restore the real buffer
* pointer in ClientRpcChannelBuffer_SendReceive. */
message_state->prefix_data_len = (char *)msg->Buffer - (char *)orpcthis;
msg->BufferLength -= message_state->prefix_data_len;
}
TRACE("-- %ld\n", status); TRACE("-- %ld\n", status);
return HRESULT_FROM_WIN32(status); return HRESULT_FROM_WIN32(status);
@ -264,6 +325,7 @@ static HRESULT WINAPI ClientRpcChannelBuffer_SendReceive(LPRPCCHANNELBUFFER ifac
struct dispatch_params *params; struct dispatch_params *params;
APARTMENT *apt = NULL; APARTMENT *apt = NULL;
IPID ipid; IPID ipid;
struct message_state *message_state;
TRACE("(%p) iMethod=%d\n", olemsg, olemsg->iMethod); TRACE("(%p) iMethod=%d\n", olemsg, olemsg->iMethod);
@ -278,6 +340,12 @@ static HRESULT WINAPI ClientRpcChannelBuffer_SendReceive(LPRPCCHANNELBUFFER ifac
params = HeapAlloc(GetProcessHeap(), HEAP_ZERO_MEMORY, sizeof(*params)); params = HeapAlloc(GetProcessHeap(), HEAP_ZERO_MEMORY, sizeof(*params));
if (!params) return E_OUTOFMEMORY; if (!params) return E_OUTOFMEMORY;
message_state = (struct message_state *)msg->Handle;
/* restore the binding handle and the real start of data */
msg->Handle = message_state->binding_handle;
msg->Buffer = (char *)msg->Buffer - message_state->prefix_data_len;
msg->BufferLength += message_state->prefix_data_len;
params->msg = olemsg; params->msg = olemsg;
params->status = RPC_S_OK; params->status = RPC_S_OK;
params->hr = S_OK; params->hr = S_OK;
@ -289,7 +357,7 @@ static HRESULT WINAPI ClientRpcChannelBuffer_SendReceive(LPRPCCHANNELBUFFER ifac
* a thread to process the RPC when this function is called indirectly * a thread to process the RPC when this function is called indirectly
* from DllMain */ * from DllMain */
RpcBindingInqObject(msg->Handle, &ipid); RpcBindingInqObject(message_state->binding_handle, &ipid);
hr = ipid_get_dispatch_params(&ipid, &apt, &params->stub, &params->chan); hr = ipid_get_dispatch_params(&ipid, &apt, &params->stub, &params->chan);
params->handle = ClientRpcChannelBuffer_GetEventHandle(This); params->handle = ClientRpcChannelBuffer_GetEventHandle(This);
if ((hr == S_OK) && !apt->multi_threaded) if ((hr == S_OK) && !apt->multi_threaded)
@ -337,6 +405,10 @@ static HRESULT WINAPI ClientRpcChannelBuffer_SendReceive(LPRPCCHANNELBUFFER ifac
} }
ClientRpcChannelBuffer_ReleaseEventHandle(This, params->handle); ClientRpcChannelBuffer_ReleaseEventHandle(This, params->handle);
/* save away the message state again */
msg->Handle = message_state;
message_state->prefix_data_len = 0;
if (hr == S_OK) hr = params->hr; if (hr == S_OK) hr = params->hr;
status = params->status; status = params->status;
@ -364,11 +436,21 @@ static HRESULT WINAPI ServerRpcChannelBuffer_FreeBuffer(LPRPCCHANNELBUFFER iface
{ {
RPC_MESSAGE *msg = (RPC_MESSAGE *)olemsg; RPC_MESSAGE *msg = (RPC_MESSAGE *)olemsg;
RPC_STATUS status; RPC_STATUS status;
struct message_state *message_state;
TRACE("(%p)\n", msg); TRACE("(%p)\n", msg);
message_state = (struct message_state *)msg->Handle;
/* restore the binding handle and the real start of data */
msg->Handle = message_state->binding_handle;
msg->Buffer = (char *)msg->Buffer - message_state->prefix_data_len;
msg->BufferLength += message_state->prefix_data_len;
message_state->prefix_data_len = 0;
status = I_RpcFreeBuffer(msg); status = I_RpcFreeBuffer(msg);
msg->Handle = message_state;
TRACE("-- %ld\n", status); TRACE("-- %ld\n", status);
return HRESULT_FROM_WIN32(status); return HRESULT_FROM_WIN32(status);
@ -378,13 +460,21 @@ static HRESULT WINAPI ClientRpcChannelBuffer_FreeBuffer(LPRPCCHANNELBUFFER iface
{ {
RPC_MESSAGE *msg = (RPC_MESSAGE *)olemsg; RPC_MESSAGE *msg = (RPC_MESSAGE *)olemsg;
RPC_STATUS status; RPC_STATUS status;
struct message_state *message_state;
TRACE("(%p)\n", msg); TRACE("(%p)\n", msg);
message_state = (struct message_state *)msg->Handle;
/* restore the binding handle and the real start of data */
msg->Handle = message_state->binding_handle;
msg->Buffer = (char *)msg->Buffer - message_state->prefix_data_len;
msg->BufferLength += message_state->prefix_data_len;
status = I_RpcFreeBuffer(msg); status = I_RpcFreeBuffer(msg);
HeapFree(GetProcessHeap(), 0, msg->RpcInterfaceInformation); HeapFree(GetProcessHeap(), 0, msg->RpcInterfaceInformation);
msg->RpcInterfaceInformation = NULL; msg->RpcInterfaceInformation = NULL;
HeapFree(GetProcessHeap(), 0, message_state);
TRACE("-- %ld\n", status); TRACE("-- %ld\n", status);
@ -519,11 +609,134 @@ HRESULT RPC_CreateServerChannel(IRpcChannelBuffer **chan)
return S_OK; return S_OK;
} }
/* unmarshals ORPCTHIS according to NDR rules, but doesn't allocate any memory */
static HRESULT unmarshal_ORPCTHIS(RPC_MESSAGE *msg, ORPCTHIS *orpcthis,
ORPC_EXTENT_ARRAY *orpc_ext_array, WIRE_ORPC_EXTENT **first_wire_orpc_extent)
{
const char *end = (char *)msg->Buffer + msg->BufferLength;
*first_wire_orpc_extent = NULL;
if (msg->BufferLength < FIELD_OFFSET(ORPCTHIS, extensions) + 4)
{
ERR("invalid buffer length\n");
return RPC_E_INVALID_HEADER;
}
memcpy(orpcthis, msg->Buffer, FIELD_OFFSET(ORPCTHIS, extensions));
msg->Buffer = (char *)msg->Buffer + FIELD_OFFSET(ORPCTHIS, extensions);
if ((const char *)msg->Buffer + sizeof(DWORD) > end)
return RPC_E_INVALID_HEADER;
if (*(DWORD *)msg->Buffer)
orpcthis->extensions = orpc_ext_array;
else
orpcthis->extensions = NULL;
msg->Buffer = (char *)msg->Buffer + sizeof(DWORD);
if (orpcthis->extensions)
{
DWORD pointer_id;
DWORD i;
memcpy(orpcthis->extensions, msg->Buffer, FIELD_OFFSET(ORPC_EXTENT_ARRAY, extent));
msg->Buffer = (char *)msg->Buffer + FIELD_OFFSET(ORPC_EXTENT_ARRAY, extent);
if ((const char *)msg->Buffer + 2 * sizeof(DWORD) > end)
return RPC_E_INVALID_HEADER;
pointer_id = *(DWORD *)msg->Buffer;
msg->Buffer = (char *)msg->Buffer + sizeof(DWORD);
orpcthis->extensions->extent = NULL;
if (pointer_id)
{
WIRE_ORPC_EXTENT *wire_orpc_extent;
/* conformance */
if (*(DWORD *)msg->Buffer != ((orpcthis->extensions->size+1)&~1))
return RPC_S_INVALID_BOUND;
msg->Buffer = (char *)msg->Buffer + sizeof(DWORD);
/* arbritary limit for security (don't know what native does) */
if (orpcthis->extensions->size > 256)
{
ERR("too many extensions: %ld\n", orpcthis->extensions->size);
return RPC_S_INVALID_BOUND;
}
*first_wire_orpc_extent = wire_orpc_extent = (WIRE_ORPC_EXTENT *)msg->Buffer;
for (i = 0; i < ((orpcthis->extensions->size+1)&~1); i++)
{
if ((const char *)&wire_orpc_extent->data[0] > end)
return RPC_S_INVALID_BOUND;
if (wire_orpc_extent->conformance != ((wire_orpc_extent->size+7)&~7))
return RPC_S_INVALID_BOUND;
if ((const char *)&wire_orpc_extent->data[wire_orpc_extent->conformance] > end)
return RPC_S_INVALID_BOUND;
TRACE("size %u, guid %s\n", wire_orpc_extent->size, debugstr_guid(&wire_orpc_extent->id));
wire_orpc_extent = (WIRE_ORPC_EXTENT *)&wire_orpc_extent->data[wire_orpc_extent->conformance];
}
msg->Buffer = wire_orpc_extent;
}
}
if ((orpcthis->version.MajorVersion != COM_MAJOR_VERSION) ||
(orpcthis->version.MinorVersion > COM_MINOR_VERSION))
{
ERR("COM version {%d, %d} not supported\n",
orpcthis->version.MajorVersion, orpcthis->version.MinorVersion);
return RPC_E_VERSION_MISMATCH;
}
if (orpcthis->flags & ~(ORPCF_LOCAL|ORPCF_RESERVED1|ORPCF_RESERVED2|ORPCF_RESERVED3|ORPCF_RESERVED4))
{
ERR("invalid flags 0x%lx\n", orpcthis->flags & ~(ORPCF_LOCAL|ORPCF_RESERVED1|ORPCF_RESERVED2|ORPCF_RESERVED3|ORPCF_RESERVED4));
return RPC_E_INVALID_HEADER;
}
return S_OK;
}
void RPC_ExecuteCall(struct dispatch_params *params) void RPC_ExecuteCall(struct dispatch_params *params)
{ {
struct message_state *message_state;
RPC_MESSAGE *msg = (RPC_MESSAGE *)params->msg;
char *original_buffer = msg->Buffer;
ORPCTHIS orpcthis;
ORPC_EXTENT_ARRAY orpc_ext_array;
WIRE_ORPC_EXTENT *first_wire_orpc_extent;
params->hr = unmarshal_ORPCTHIS(msg, &orpcthis, &orpc_ext_array, &first_wire_orpc_extent);
if (params->hr != S_OK)
goto exit;
message_state = HeapAlloc(GetProcessHeap(), 0, sizeof(*message_state));
if (!message_state)
{
params->hr = E_OUTOFMEMORY;
goto exit;
}
message_state->prefix_data_len = original_buffer - (char *)msg->Buffer;
message_state->binding_handle = msg->Handle;
msg->Handle = message_state;
msg->BufferLength -= message_state->prefix_data_len;
/* invoke the method */
params->hr = IRpcStubBuffer_Invoke(params->stub, params->msg, params->chan); params->hr = IRpcStubBuffer_Invoke(params->stub, params->msg, params->chan);
message_state = (struct message_state *)msg->Handle;
msg->Handle = message_state->binding_handle;
msg->Buffer = (char *)msg->Buffer - message_state->prefix_data_len;
msg->BufferLength += message_state->prefix_data_len;
HeapFree(GetProcessHeap(), 0, message_state);
exit:
IRpcStubBuffer_Release(params->stub); IRpcStubBuffer_Release(params->stub);
IRpcChannelBuffer_Release(params->chan); IRpcChannelBuffer_Release(params->chan);
if (params->handle) SetEvent(params->handle); if (params->handle) SetEvent(params->handle);