diff --git a/dlls/ole32/compobj.c b/dlls/ole32/compobj.c index dcae4f8edcb..4a88390dfa7 100644 --- a/dlls/ole32/compobj.c +++ b/dlls/ole32/compobj.c @@ -3296,6 +3296,26 @@ HRESULT WINAPI CoGetObject(LPCWSTR pszName, BIND_OPTS *pBindOptions, return hr; } +/*********************************************************************** + * CoRegisterChannelHook [OLE32.@] + * + * Registers a process-wide hook that is called during ORPC calls. + * + * PARAMS + * guidExtension [I] GUID of the channel hook to register. + * pChannelHook [I] Channel hook object to register. + * + * RETURNS + * Success: S_OK. + * Failure: HRESULT code. + */ +HRESULT WINAPI CoRegisterChannelHook(REFGUID guidExtension, IChannelHook *pChannelHook) +{ + TRACE("(%s, %p)\n", debugstr_guid(guidExtension), pChannelHook); + + return RPC_RegisterChannelHook(guidExtension, pChannelHook); +} + /*********************************************************************** * DllMain (OLE32.@) */ @@ -3313,6 +3333,7 @@ BOOL WINAPI DllMain(HINSTANCE hinstDLL, DWORD fdwReason, LPVOID fImpLoad) case DLL_PROCESS_DETACH: if (TRACE_ON(ole)) CoRevokeMallocSpy(); COMPOBJ_UninitProcess(); + RPC_UnregisterAllChannelHooks(); OLE32_hInstance = 0; break; diff --git a/dlls/ole32/compobj_private.h b/dlls/ole32/compobj_private.h index 9052c383fd0..5c8f5778bde 100644 --- a/dlls/ole32/compobj_private.h +++ b/dlls/ole32/compobj_private.h @@ -219,6 +219,8 @@ HRESULT RPC_RegisterInterface(REFIID riid); void RPC_UnregisterInterface(REFIID riid); void RPC_StartLocalServer(REFCLSID clsid, IStream *stream); HRESULT RPC_GetLocalClassObject(REFCLSID rclsid, REFIID iid, LPVOID *ppv); +HRESULT RPC_RegisterChannelHook(REFGUID rguid, IChannelHook *hook); +void RPC_UnregisterAllChannelHooks(void); /* This function initialize the Running Object Table */ HRESULT WINAPI RunningObjectTableImpl_Initialize(void); diff --git a/dlls/ole32/ole32.spec b/dlls/ole32/ole32.spec index 675edb6f281..d775556b486 100644 --- a/dlls/ole32/ole32.spec +++ b/dlls/ole32/ole32.spec @@ -52,7 +52,7 @@ @ stdcall CoQueryClientBlanket(ptr ptr ptr ptr ptr ptr ptr) @ stdcall CoQueryProxyBlanket(ptr ptr ptr ptr ptr ptr ptr ptr) @ stub CoQueryReleaseObject -@ stub CoRegisterChannelHook +@ stdcall CoRegisterChannelHook(ptr ptr) @ stdcall CoRegisterClassObject(ptr ptr long long ptr) @ stdcall CoRegisterMallocSpy (ptr) @ stdcall CoRegisterMessageFilter(ptr ptr) diff --git a/dlls/ole32/rpc.c b/dlls/ole32/rpc.c index 6c7d65ea8ae..7fe984d17f6 100644 --- a/dlls/ole32/rpc.c +++ b/dlls/ole32/rpc.c @@ -68,6 +68,16 @@ static CRITICAL_SECTION_DEBUG csRegIf_debug = }; static CRITICAL_SECTION csRegIf = { &csRegIf_debug, -1, 0, 0, 0, 0 }; +static struct list channel_hooks = LIST_INIT(channel_hooks); /* (CS csChannelHook) */ +static CRITICAL_SECTION csChannelHook; +static CRITICAL_SECTION_DEBUG csChannelHook_debug = +{ + 0, 0, &csChannelHook, + { &csChannelHook_debug.ProcessLocksList, &csChannelHook_debug.ProcessLocksList }, + 0, 0, { (DWORD_PTR)(__FILE__ ": channel hooks") } +}; +static CRITICAL_SECTION csChannelHook = { &csChannelHook_debug, -1, 0, 0, 0, 0 }; + static WCHAR wszRpcTransport[] = {'n','c','a','l','r','p','c',0}; @@ -127,6 +137,145 @@ typedef struct /* [size_is((size+7)&~7)] */ unsigned char data[1]; } WIRE_ORPC_EXTENT; +struct channel_hook_entry +{ + struct list entry; + GUID id; + IChannelHook *hook; +}; + +struct channel_hook_buffer_data +{ + GUID id; + ULONG extension_size; +}; + + +/* Channel Hook Functions */ + +static ULONG ChannelHooks_ClientGetSize(SChannelHookCallInfo *info, + struct channel_hook_buffer_data **data, unsigned int *hook_count, + ULONG *extension_count) +{ + struct channel_hook_entry *entry; + ULONG total_size = 0; + unsigned int hook_index = 0; + + *hook_count = 0; + *extension_count = 0; + + EnterCriticalSection(&csChannelHook); + + LIST_FOR_EACH_ENTRY(entry, &channel_hooks, struct channel_hook_entry, entry) + (*hook_count)++; + + if (hook_count) + *data = HeapAlloc(GetProcessHeap(), 0, *hook_count * sizeof(struct channel_hook_buffer_data)); + else + *data = NULL; + + LIST_FOR_EACH_ENTRY(entry, &channel_hooks, struct channel_hook_entry, entry) + { + ULONG extension_size = 0; + + IChannelHook_ClientGetSize(entry->hook, &entry->id, &info->iid, &extension_size); + + TRACE("%s: extension_size = %u\n", debugstr_guid(&entry->id), extension_size); + + extension_size = (extension_size+7)&~7; + (*data)[hook_index].id = entry->id; + (*data)[hook_index].extension_size = extension_size; + + /* an extension is only put onto the wire if it has data to write */ + if (extension_size) + { + total_size += FIELD_OFFSET(WIRE_ORPC_EXTENT, data[extension_size]); + (*extension_count)++; + } + + hook_index++; + } + + LeaveCriticalSection(&csChannelHook); + + return total_size; +} + +static unsigned char * ChannelHooks_ClientFillBuffer(SChannelHookCallInfo *info, + unsigned char *buffer, struct channel_hook_buffer_data *data, + unsigned int hook_count) +{ + struct channel_hook_entry *entry; + + EnterCriticalSection(&csChannelHook); + + LIST_FOR_EACH_ENTRY(entry, &channel_hooks, struct channel_hook_entry, entry) + { + unsigned int i; + ULONG extension_size = 0; + WIRE_ORPC_EXTENT *wire_orpc_extent = (WIRE_ORPC_EXTENT *)buffer; + + for (i = 0; i < hook_count; i++) + if (IsEqualGUID(&entry->id, &data[i].id)) + extension_size = data[i].extension_size; + + /* an extension is only put onto the wire if it has data to write */ + if (!extension_size) + continue; + + IChannelHook_ClientFillBuffer(entry->hook, &entry->id, &info->iid, + &extension_size, buffer + FIELD_OFFSET(WIRE_ORPC_EXTENT, data[0])); + + TRACE("%s: extension_size = %u\n", debugstr_guid(&entry->id), extension_size); + + /* FIXME: set unused portion of wire_orpc_extent->data to 0? */ + + wire_orpc_extent->conformance = (extension_size+7)&~7; + wire_orpc_extent->size = extension_size; + memcpy(&wire_orpc_extent->id, &entry->id, sizeof(wire_orpc_extent->id)); + buffer += FIELD_OFFSET(WIRE_ORPC_EXTENT, data[wire_orpc_extent->conformance]); + } + + LeaveCriticalSection(&csChannelHook); + + HeapFree(GetProcessHeap(), 0, data); + + return buffer; +} + +HRESULT RPC_RegisterChannelHook(REFGUID rguid, IChannelHook *hook) +{ + struct channel_hook_entry *entry; + + TRACE("(%s, %p)\n", debugstr_guid(rguid), hook); + + entry = HeapAlloc(GetProcessHeap(), 0, sizeof(*entry)); + if (!entry) + return E_OUTOFMEMORY; + + memcpy(&entry->id, rguid, sizeof(entry->id)); + entry->hook = hook; + IChannelHook_AddRef(hook); + + EnterCriticalSection(&csChannelHook); + list_add_tail(&channel_hooks, &entry->entry); + LeaveCriticalSection(&csChannelHook); + + return S_OK; +} + +void RPC_UnregisterAllChannelHooks(void) +{ + struct channel_hook_entry *cursor; + struct channel_hook_entry *cursor2; + + EnterCriticalSection(&csChannelHook); + LIST_FOR_EACH_ENTRY_SAFE(cursor, cursor2, &channel_hooks, struct channel_hook_entry, entry) + HeapFree(GetProcessHeap(), 0, cursor); + LeaveCriticalSection(&csChannelHook); +} + +/* RPC Channel Buffer Functions */ static HRESULT WINAPI RpcChannelBuffer_QueryInterface(LPRPCCHANNELBUFFER iface, REFIID riid, LPVOID *ppv) { @@ -207,6 +356,11 @@ static HRESULT WINAPI ClientRpcChannelBuffer_GetBuffer(LPRPCCHANNELBUFFER iface, RPC_STATUS status; ORPCTHIS *orpcthis; struct message_state *message_state; + ULONG extensions_size; + struct channel_hook_buffer_data *channel_hook_data; + unsigned int channel_hook_count; + ULONG extension_count; + SChannelHookCallInfo channel_hook_info; TRACE("(%p)->(%p,%s)\n", This, olemsg, debugstr_guid(riid)); @@ -230,8 +384,24 @@ static HRESULT WINAPI ClientRpcChannelBuffer_GetBuffer(LPRPCCHANNELBUFFER iface, msg->Handle = This->bind; msg->RpcInterfaceInformation = cif; + channel_hook_info.iid = *riid; + channel_hook_info.cbSize = sizeof(channel_hook_info); + channel_hook_info.uCausality = GUID_NULL; /* FIXME */ + channel_hook_info.dwServerPid = 0; /* FIXME */ + channel_hook_info.iMethod = msg->ProcNum; + channel_hook_info.pObject = NULL; /* only present on server-side */ + + extensions_size = ChannelHooks_ClientGetSize(&channel_hook_info, + &channel_hook_data, &channel_hook_count, &extension_count); + msg->BufferLength += FIELD_OFFSET(ORPCTHIS, extensions) + 4; - + if (extensions_size) + { + msg->BufferLength += FIELD_OFFSET(ORPC_EXTENT_ARRAY, extent) + 2*sizeof(DWORD) + extensions_size; + if (extension_count & 1) + msg->BufferLength += FIELD_OFFSET(WIRE_ORPC_EXTENT, data[0]); + } + status = I_RpcGetBuffer(msg); message_state->prefix_data_len = 0; @@ -245,14 +415,42 @@ static HRESULT WINAPI ClientRpcChannelBuffer_GetBuffer(LPRPCCHANNELBUFFER iface, orpcthis->version.MajorVersion = COM_MAJOR_VERSION; orpcthis->version.MinorVersion = COM_MINOR_VERSION; - orpcthis->flags = ORPCF_NULL; + orpcthis->flags = channel_hook_info.dwServerPid ? ORPCF_LOCAL : ORPCF_NULL; orpcthis->reserved1 = 0; - orpcthis->cid = GUID_NULL; /* FIXME */ + orpcthis->cid = channel_hook_info.uCausality; /* NDR representation of orpcthis->extensions */ - *(DWORD *)msg->Buffer = 0; /* FIXME */ + *(DWORD *)msg->Buffer = extensions_size ? 1 : 0; msg->Buffer = (char *)msg->Buffer + sizeof(DWORD); + if (extensions_size) + { + ORPC_EXTENT_ARRAY *orpc_extent_array = msg->Buffer; + orpc_extent_array->size = extension_count; + orpc_extent_array->reserved = 0; + msg->Buffer = (char *)msg->Buffer + FIELD_OFFSET(ORPC_EXTENT_ARRAY, extent); + /* NDR representation of orpc_extent_array->extent */ + *(DWORD *)msg->Buffer = 1; + msg->Buffer = (char *)msg->Buffer + sizeof(DWORD); + /* NDR representation of [size_is] attribute of orpc_extent_array->extent */ + *(DWORD *)msg->Buffer = (extension_count + 1) & ~1; + msg->Buffer = (char *)msg->Buffer + sizeof(DWORD); + + msg->Buffer = ChannelHooks_ClientFillBuffer(&channel_hook_info, + msg->Buffer, channel_hook_data, channel_hook_count); + + /* we must add a dummy extension if there is an odd extension + * count to meet the contract specified by the size_is attribute */ + if (extension_count & 1) + { + WIRE_ORPC_EXTENT *wire_orpc_extent = msg->Buffer; + wire_orpc_extent->conformance = 0; + memcpy(&wire_orpc_extent->id, &GUID_NULL, sizeof(wire_orpc_extent->id)); + wire_orpc_extent->size = 0; + msg->Buffer = (char *)msg->Buffer + FIELD_OFFSET(WIRE_ORPC_EXTENT, data[0]); + } + } + /* store the prefixed data length so that we can restore the real buffer * pointer in ClientRpcChannelBuffer_SendReceive. */ message_state->prefix_data_len = (char *)msg->Buffer - (char *)orpcthis;