diff --git a/dlls/ole32/rpc.c b/dlls/ole32/rpc.c index 26a139ed99e..6c7d65ea8ae 100644 --- a/dlls/ole32/rpc.c +++ b/dlls/ole32/rpc.c @@ -113,6 +113,21 @@ struct dispatch_params HRESULT hr; /* hresult (out) */ }; +struct message_state +{ + RPC_BINDING_HANDLE binding_handle; + ULONG prefix_data_len; +}; + +typedef struct +{ + ULONG conformance; /* NDR */ + GUID id; + ULONG size; + /* [size_is((size+7)&~7)] */ unsigned char data[1]; +} WIRE_ORPC_EXTENT; + + static HRESULT WINAPI RpcChannelBuffer_QueryInterface(LPRPCCHANNELBUFFER iface, REFIID riid, LPVOID *ppv) { *ppv = NULL; @@ -164,11 +179,21 @@ static HRESULT WINAPI ServerRpcChannelBuffer_GetBuffer(LPRPCCHANNELBUFFER iface, RpcChannelBuffer *This = (RpcChannelBuffer *)iface; RPC_MESSAGE *msg = (RPC_MESSAGE *)olemsg; RPC_STATUS status; + struct message_state *message_state; TRACE("(%p)->(%p,%s)\n", This, olemsg, debugstr_guid(riid)); + message_state = (struct message_state *)msg->Handle; + /* restore the binding handle and the real start of data */ + msg->Handle = message_state->binding_handle; + msg->Buffer = (char *)msg->Buffer - message_state->prefix_data_len; + status = I_RpcGetBuffer(msg); + /* save away the message state again */ + msg->Handle = message_state; + message_state->prefix_data_len = 0; + TRACE("-- %ld\n", status); return HRESULT_FROM_WIN32(status); @@ -180,6 +205,8 @@ static HRESULT WINAPI ClientRpcChannelBuffer_GetBuffer(LPRPCCHANNELBUFFER iface, RPC_MESSAGE *msg = (RPC_MESSAGE *)olemsg; RPC_CLIENT_INTERFACE *cif; RPC_STATUS status; + ORPCTHIS *orpcthis; + struct message_state *message_state; TRACE("(%p)->(%p,%s)\n", This, olemsg, debugstr_guid(riid)); @@ -187,17 +214,51 @@ static HRESULT WINAPI ClientRpcChannelBuffer_GetBuffer(LPRPCCHANNELBUFFER iface, if (!cif) return E_OUTOFMEMORY; + message_state = HeapAlloc(GetProcessHeap(), 0, sizeof(*message_state)); + if (!message_state) + { + HeapFree(GetProcessHeap(), 0, cif); + return E_OUTOFMEMORY; + } + cif->Length = sizeof(RPC_CLIENT_INTERFACE); /* RPC interface ID = COM interface ID */ cif->InterfaceId.SyntaxGUID = *riid; /* COM objects always have a version of 0.0 */ cif->InterfaceId.SyntaxVersion.MajorVersion = 0; cif->InterfaceId.SyntaxVersion.MinorVersion = 0; - msg->RpcInterfaceInformation = cif; msg->Handle = This->bind; + msg->RpcInterfaceInformation = cif; + + msg->BufferLength += FIELD_OFFSET(ORPCTHIS, extensions) + 4; status = I_RpcGetBuffer(msg); + message_state->prefix_data_len = 0; + message_state->binding_handle = This->bind; + msg->Handle = message_state; + + if (status == RPC_S_OK) + { + orpcthis = (ORPCTHIS *)msg->Buffer; + msg->Buffer = (char *)msg->Buffer + FIELD_OFFSET(ORPCTHIS, extensions); + + orpcthis->version.MajorVersion = COM_MAJOR_VERSION; + orpcthis->version.MinorVersion = COM_MINOR_VERSION; + orpcthis->flags = ORPCF_NULL; + orpcthis->reserved1 = 0; + orpcthis->cid = GUID_NULL; /* FIXME */ + + /* NDR representation of orpcthis->extensions */ + *(DWORD *)msg->Buffer = 0; /* FIXME */ + msg->Buffer = (char *)msg->Buffer + sizeof(DWORD); + + /* store the prefixed data length so that we can restore the real buffer + * pointer in ClientRpcChannelBuffer_SendReceive. */ + message_state->prefix_data_len = (char *)msg->Buffer - (char *)orpcthis; + msg->BufferLength -= message_state->prefix_data_len; + } + TRACE("-- %ld\n", status); return HRESULT_FROM_WIN32(status); @@ -264,6 +325,7 @@ static HRESULT WINAPI ClientRpcChannelBuffer_SendReceive(LPRPCCHANNELBUFFER ifac struct dispatch_params *params; APARTMENT *apt = NULL; IPID ipid; + struct message_state *message_state; TRACE("(%p) iMethod=%d\n", olemsg, olemsg->iMethod); @@ -278,6 +340,12 @@ static HRESULT WINAPI ClientRpcChannelBuffer_SendReceive(LPRPCCHANNELBUFFER ifac params = HeapAlloc(GetProcessHeap(), HEAP_ZERO_MEMORY, sizeof(*params)); if (!params) return E_OUTOFMEMORY; + message_state = (struct message_state *)msg->Handle; + /* restore the binding handle and the real start of data */ + msg->Handle = message_state->binding_handle; + msg->Buffer = (char *)msg->Buffer - message_state->prefix_data_len; + msg->BufferLength += message_state->prefix_data_len; + params->msg = olemsg; params->status = RPC_S_OK; params->hr = S_OK; @@ -289,7 +357,7 @@ static HRESULT WINAPI ClientRpcChannelBuffer_SendReceive(LPRPCCHANNELBUFFER ifac * a thread to process the RPC when this function is called indirectly * from DllMain */ - RpcBindingInqObject(msg->Handle, &ipid); + RpcBindingInqObject(message_state->binding_handle, &ipid); hr = ipid_get_dispatch_params(&ipid, &apt, ¶ms->stub, ¶ms->chan); params->handle = ClientRpcChannelBuffer_GetEventHandle(This); if ((hr == S_OK) && !apt->multi_threaded) @@ -337,6 +405,10 @@ static HRESULT WINAPI ClientRpcChannelBuffer_SendReceive(LPRPCCHANNELBUFFER ifac } ClientRpcChannelBuffer_ReleaseEventHandle(This, params->handle); + /* save away the message state again */ + msg->Handle = message_state; + message_state->prefix_data_len = 0; + if (hr == S_OK) hr = params->hr; status = params->status; @@ -364,11 +436,21 @@ static HRESULT WINAPI ServerRpcChannelBuffer_FreeBuffer(LPRPCCHANNELBUFFER iface { RPC_MESSAGE *msg = (RPC_MESSAGE *)olemsg; RPC_STATUS status; + struct message_state *message_state; TRACE("(%p)\n", msg); + message_state = (struct message_state *)msg->Handle; + /* restore the binding handle and the real start of data */ + msg->Handle = message_state->binding_handle; + msg->Buffer = (char *)msg->Buffer - message_state->prefix_data_len; + msg->BufferLength += message_state->prefix_data_len; + message_state->prefix_data_len = 0; + status = I_RpcFreeBuffer(msg); + msg->Handle = message_state; + TRACE("-- %ld\n", status); return HRESULT_FROM_WIN32(status); @@ -378,13 +460,21 @@ static HRESULT WINAPI ClientRpcChannelBuffer_FreeBuffer(LPRPCCHANNELBUFFER iface { RPC_MESSAGE *msg = (RPC_MESSAGE *)olemsg; RPC_STATUS status; + struct message_state *message_state; TRACE("(%p)\n", msg); + message_state = (struct message_state *)msg->Handle; + /* restore the binding handle and the real start of data */ + msg->Handle = message_state->binding_handle; + msg->Buffer = (char *)msg->Buffer - message_state->prefix_data_len; + msg->BufferLength += message_state->prefix_data_len; + status = I_RpcFreeBuffer(msg); HeapFree(GetProcessHeap(), 0, msg->RpcInterfaceInformation); msg->RpcInterfaceInformation = NULL; + HeapFree(GetProcessHeap(), 0, message_state); TRACE("-- %ld\n", status); @@ -519,11 +609,134 @@ HRESULT RPC_CreateServerChannel(IRpcChannelBuffer **chan) return S_OK; } +/* unmarshals ORPCTHIS according to NDR rules, but doesn't allocate any memory */ +static HRESULT unmarshal_ORPCTHIS(RPC_MESSAGE *msg, ORPCTHIS *orpcthis, + ORPC_EXTENT_ARRAY *orpc_ext_array, WIRE_ORPC_EXTENT **first_wire_orpc_extent) +{ + const char *end = (char *)msg->Buffer + msg->BufferLength; + + *first_wire_orpc_extent = NULL; + + if (msg->BufferLength < FIELD_OFFSET(ORPCTHIS, extensions) + 4) + { + ERR("invalid buffer length\n"); + return RPC_E_INVALID_HEADER; + } + + memcpy(orpcthis, msg->Buffer, FIELD_OFFSET(ORPCTHIS, extensions)); + msg->Buffer = (char *)msg->Buffer + FIELD_OFFSET(ORPCTHIS, extensions); + + if ((const char *)msg->Buffer + sizeof(DWORD) > end) + return RPC_E_INVALID_HEADER; + + if (*(DWORD *)msg->Buffer) + orpcthis->extensions = orpc_ext_array; + else + orpcthis->extensions = NULL; + + msg->Buffer = (char *)msg->Buffer + sizeof(DWORD); + + if (orpcthis->extensions) + { + DWORD pointer_id; + DWORD i; + + memcpy(orpcthis->extensions, msg->Buffer, FIELD_OFFSET(ORPC_EXTENT_ARRAY, extent)); + msg->Buffer = (char *)msg->Buffer + FIELD_OFFSET(ORPC_EXTENT_ARRAY, extent); + + if ((const char *)msg->Buffer + 2 * sizeof(DWORD) > end) + return RPC_E_INVALID_HEADER; + + pointer_id = *(DWORD *)msg->Buffer; + msg->Buffer = (char *)msg->Buffer + sizeof(DWORD); + orpcthis->extensions->extent = NULL; + + if (pointer_id) + { + WIRE_ORPC_EXTENT *wire_orpc_extent; + + /* conformance */ + if (*(DWORD *)msg->Buffer != ((orpcthis->extensions->size+1)&~1)) + return RPC_S_INVALID_BOUND; + + msg->Buffer = (char *)msg->Buffer + sizeof(DWORD); + + /* arbritary limit for security (don't know what native does) */ + if (orpcthis->extensions->size > 256) + { + ERR("too many extensions: %ld\n", orpcthis->extensions->size); + return RPC_S_INVALID_BOUND; + } + + *first_wire_orpc_extent = wire_orpc_extent = (WIRE_ORPC_EXTENT *)msg->Buffer; + for (i = 0; i < ((orpcthis->extensions->size+1)&~1); i++) + { + if ((const char *)&wire_orpc_extent->data[0] > end) + return RPC_S_INVALID_BOUND; + if (wire_orpc_extent->conformance != ((wire_orpc_extent->size+7)&~7)) + return RPC_S_INVALID_BOUND; + if ((const char *)&wire_orpc_extent->data[wire_orpc_extent->conformance] > end) + return RPC_S_INVALID_BOUND; + TRACE("size %u, guid %s\n", wire_orpc_extent->size, debugstr_guid(&wire_orpc_extent->id)); + wire_orpc_extent = (WIRE_ORPC_EXTENT *)&wire_orpc_extent->data[wire_orpc_extent->conformance]; + } + msg->Buffer = wire_orpc_extent; + } + } + + if ((orpcthis->version.MajorVersion != COM_MAJOR_VERSION) || + (orpcthis->version.MinorVersion > COM_MINOR_VERSION)) + { + ERR("COM version {%d, %d} not supported\n", + orpcthis->version.MajorVersion, orpcthis->version.MinorVersion); + return RPC_E_VERSION_MISMATCH; + } + + if (orpcthis->flags & ~(ORPCF_LOCAL|ORPCF_RESERVED1|ORPCF_RESERVED2|ORPCF_RESERVED3|ORPCF_RESERVED4)) + { + ERR("invalid flags 0x%lx\n", orpcthis->flags & ~(ORPCF_LOCAL|ORPCF_RESERVED1|ORPCF_RESERVED2|ORPCF_RESERVED3|ORPCF_RESERVED4)); + return RPC_E_INVALID_HEADER; + } + + return S_OK; +} void RPC_ExecuteCall(struct dispatch_params *params) { + struct message_state *message_state; + RPC_MESSAGE *msg = (RPC_MESSAGE *)params->msg; + char *original_buffer = msg->Buffer; + ORPCTHIS orpcthis; + ORPC_EXTENT_ARRAY orpc_ext_array; + WIRE_ORPC_EXTENT *first_wire_orpc_extent; + + params->hr = unmarshal_ORPCTHIS(msg, &orpcthis, &orpc_ext_array, &first_wire_orpc_extent); + if (params->hr != S_OK) + goto exit; + + message_state = HeapAlloc(GetProcessHeap(), 0, sizeof(*message_state)); + if (!message_state) + { + params->hr = E_OUTOFMEMORY; + goto exit; + } + + message_state->prefix_data_len = original_buffer - (char *)msg->Buffer; + message_state->binding_handle = msg->Handle; + msg->Handle = message_state; + msg->BufferLength -= message_state->prefix_data_len; + + /* invoke the method */ + params->hr = IRpcStubBuffer_Invoke(params->stub, params->msg, params->chan); + message_state = (struct message_state *)msg->Handle; + msg->Handle = message_state->binding_handle; + msg->Buffer = (char *)msg->Buffer - message_state->prefix_data_len; + msg->BufferLength += message_state->prefix_data_len; + HeapFree(GetProcessHeap(), 0, message_state); + +exit: IRpcStubBuffer_Release(params->stub); IRpcChannelBuffer_Release(params->chan); if (params->handle) SetEvent(params->handle);