ole32: Implement CoRegisterChannelHook and call channel hook methods on the client side.
This commit is contained in:
parent
1dc5dec6e6
commit
5788ee9f05
|
@ -3296,6 +3296,26 @@ HRESULT WINAPI CoGetObject(LPCWSTR pszName, BIND_OPTS *pBindOptions,
|
|||
return hr;
|
||||
}
|
||||
|
||||
/***********************************************************************
|
||||
* CoRegisterChannelHook [OLE32.@]
|
||||
*
|
||||
* Registers a process-wide hook that is called during ORPC calls.
|
||||
*
|
||||
* PARAMS
|
||||
* guidExtension [I] GUID of the channel hook to register.
|
||||
* pChannelHook [I] Channel hook object to register.
|
||||
*
|
||||
* RETURNS
|
||||
* Success: S_OK.
|
||||
* Failure: HRESULT code.
|
||||
*/
|
||||
HRESULT WINAPI CoRegisterChannelHook(REFGUID guidExtension, IChannelHook *pChannelHook)
|
||||
{
|
||||
TRACE("(%s, %p)\n", debugstr_guid(guidExtension), pChannelHook);
|
||||
|
||||
return RPC_RegisterChannelHook(guidExtension, pChannelHook);
|
||||
}
|
||||
|
||||
/***********************************************************************
|
||||
* DllMain (OLE32.@)
|
||||
*/
|
||||
|
@ -3313,6 +3333,7 @@ BOOL WINAPI DllMain(HINSTANCE hinstDLL, DWORD fdwReason, LPVOID fImpLoad)
|
|||
case DLL_PROCESS_DETACH:
|
||||
if (TRACE_ON(ole)) CoRevokeMallocSpy();
|
||||
COMPOBJ_UninitProcess();
|
||||
RPC_UnregisterAllChannelHooks();
|
||||
OLE32_hInstance = 0;
|
||||
break;
|
||||
|
||||
|
|
|
@ -219,6 +219,8 @@ HRESULT RPC_RegisterInterface(REFIID riid);
|
|||
void RPC_UnregisterInterface(REFIID riid);
|
||||
void RPC_StartLocalServer(REFCLSID clsid, IStream *stream);
|
||||
HRESULT RPC_GetLocalClassObject(REFCLSID rclsid, REFIID iid, LPVOID *ppv);
|
||||
HRESULT RPC_RegisterChannelHook(REFGUID rguid, IChannelHook *hook);
|
||||
void RPC_UnregisterAllChannelHooks(void);
|
||||
|
||||
/* This function initialize the Running Object Table */
|
||||
HRESULT WINAPI RunningObjectTableImpl_Initialize(void);
|
||||
|
|
|
@ -52,7 +52,7 @@
|
|||
@ stdcall CoQueryClientBlanket(ptr ptr ptr ptr ptr ptr ptr)
|
||||
@ stdcall CoQueryProxyBlanket(ptr ptr ptr ptr ptr ptr ptr ptr)
|
||||
@ stub CoQueryReleaseObject
|
||||
@ stub CoRegisterChannelHook
|
||||
@ stdcall CoRegisterChannelHook(ptr ptr)
|
||||
@ stdcall CoRegisterClassObject(ptr ptr long long ptr)
|
||||
@ stdcall CoRegisterMallocSpy (ptr)
|
||||
@ stdcall CoRegisterMessageFilter(ptr ptr)
|
||||
|
|
206
dlls/ole32/rpc.c
206
dlls/ole32/rpc.c
|
@ -68,6 +68,16 @@ static CRITICAL_SECTION_DEBUG csRegIf_debug =
|
|||
};
|
||||
static CRITICAL_SECTION csRegIf = { &csRegIf_debug, -1, 0, 0, 0, 0 };
|
||||
|
||||
static struct list channel_hooks = LIST_INIT(channel_hooks); /* (CS csChannelHook) */
|
||||
static CRITICAL_SECTION csChannelHook;
|
||||
static CRITICAL_SECTION_DEBUG csChannelHook_debug =
|
||||
{
|
||||
0, 0, &csChannelHook,
|
||||
{ &csChannelHook_debug.ProcessLocksList, &csChannelHook_debug.ProcessLocksList },
|
||||
0, 0, { (DWORD_PTR)(__FILE__ ": channel hooks") }
|
||||
};
|
||||
static CRITICAL_SECTION csChannelHook = { &csChannelHook_debug, -1, 0, 0, 0, 0 };
|
||||
|
||||
static WCHAR wszRpcTransport[] = {'n','c','a','l','r','p','c',0};
|
||||
|
||||
|
||||
|
@ -127,6 +137,145 @@ typedef struct
|
|||
/* [size_is((size+7)&~7)] */ unsigned char data[1];
|
||||
} WIRE_ORPC_EXTENT;
|
||||
|
||||
struct channel_hook_entry
|
||||
{
|
||||
struct list entry;
|
||||
GUID id;
|
||||
IChannelHook *hook;
|
||||
};
|
||||
|
||||
struct channel_hook_buffer_data
|
||||
{
|
||||
GUID id;
|
||||
ULONG extension_size;
|
||||
};
|
||||
|
||||
|
||||
/* Channel Hook Functions */
|
||||
|
||||
static ULONG ChannelHooks_ClientGetSize(SChannelHookCallInfo *info,
|
||||
struct channel_hook_buffer_data **data, unsigned int *hook_count,
|
||||
ULONG *extension_count)
|
||||
{
|
||||
struct channel_hook_entry *entry;
|
||||
ULONG total_size = 0;
|
||||
unsigned int hook_index = 0;
|
||||
|
||||
*hook_count = 0;
|
||||
*extension_count = 0;
|
||||
|
||||
EnterCriticalSection(&csChannelHook);
|
||||
|
||||
LIST_FOR_EACH_ENTRY(entry, &channel_hooks, struct channel_hook_entry, entry)
|
||||
(*hook_count)++;
|
||||
|
||||
if (hook_count)
|
||||
*data = HeapAlloc(GetProcessHeap(), 0, *hook_count * sizeof(struct channel_hook_buffer_data));
|
||||
else
|
||||
*data = NULL;
|
||||
|
||||
LIST_FOR_EACH_ENTRY(entry, &channel_hooks, struct channel_hook_entry, entry)
|
||||
{
|
||||
ULONG extension_size = 0;
|
||||
|
||||
IChannelHook_ClientGetSize(entry->hook, &entry->id, &info->iid, &extension_size);
|
||||
|
||||
TRACE("%s: extension_size = %u\n", debugstr_guid(&entry->id), extension_size);
|
||||
|
||||
extension_size = (extension_size+7)&~7;
|
||||
(*data)[hook_index].id = entry->id;
|
||||
(*data)[hook_index].extension_size = extension_size;
|
||||
|
||||
/* an extension is only put onto the wire if it has data to write */
|
||||
if (extension_size)
|
||||
{
|
||||
total_size += FIELD_OFFSET(WIRE_ORPC_EXTENT, data[extension_size]);
|
||||
(*extension_count)++;
|
||||
}
|
||||
|
||||
hook_index++;
|
||||
}
|
||||
|
||||
LeaveCriticalSection(&csChannelHook);
|
||||
|
||||
return total_size;
|
||||
}
|
||||
|
||||
static unsigned char * ChannelHooks_ClientFillBuffer(SChannelHookCallInfo *info,
|
||||
unsigned char *buffer, struct channel_hook_buffer_data *data,
|
||||
unsigned int hook_count)
|
||||
{
|
||||
struct channel_hook_entry *entry;
|
||||
|
||||
EnterCriticalSection(&csChannelHook);
|
||||
|
||||
LIST_FOR_EACH_ENTRY(entry, &channel_hooks, struct channel_hook_entry, entry)
|
||||
{
|
||||
unsigned int i;
|
||||
ULONG extension_size = 0;
|
||||
WIRE_ORPC_EXTENT *wire_orpc_extent = (WIRE_ORPC_EXTENT *)buffer;
|
||||
|
||||
for (i = 0; i < hook_count; i++)
|
||||
if (IsEqualGUID(&entry->id, &data[i].id))
|
||||
extension_size = data[i].extension_size;
|
||||
|
||||
/* an extension is only put onto the wire if it has data to write */
|
||||
if (!extension_size)
|
||||
continue;
|
||||
|
||||
IChannelHook_ClientFillBuffer(entry->hook, &entry->id, &info->iid,
|
||||
&extension_size, buffer + FIELD_OFFSET(WIRE_ORPC_EXTENT, data[0]));
|
||||
|
||||
TRACE("%s: extension_size = %u\n", debugstr_guid(&entry->id), extension_size);
|
||||
|
||||
/* FIXME: set unused portion of wire_orpc_extent->data to 0? */
|
||||
|
||||
wire_orpc_extent->conformance = (extension_size+7)&~7;
|
||||
wire_orpc_extent->size = extension_size;
|
||||
memcpy(&wire_orpc_extent->id, &entry->id, sizeof(wire_orpc_extent->id));
|
||||
buffer += FIELD_OFFSET(WIRE_ORPC_EXTENT, data[wire_orpc_extent->conformance]);
|
||||
}
|
||||
|
||||
LeaveCriticalSection(&csChannelHook);
|
||||
|
||||
HeapFree(GetProcessHeap(), 0, data);
|
||||
|
||||
return buffer;
|
||||
}
|
||||
|
||||
HRESULT RPC_RegisterChannelHook(REFGUID rguid, IChannelHook *hook)
|
||||
{
|
||||
struct channel_hook_entry *entry;
|
||||
|
||||
TRACE("(%s, %p)\n", debugstr_guid(rguid), hook);
|
||||
|
||||
entry = HeapAlloc(GetProcessHeap(), 0, sizeof(*entry));
|
||||
if (!entry)
|
||||
return E_OUTOFMEMORY;
|
||||
|
||||
memcpy(&entry->id, rguid, sizeof(entry->id));
|
||||
entry->hook = hook;
|
||||
IChannelHook_AddRef(hook);
|
||||
|
||||
EnterCriticalSection(&csChannelHook);
|
||||
list_add_tail(&channel_hooks, &entry->entry);
|
||||
LeaveCriticalSection(&csChannelHook);
|
||||
|
||||
return S_OK;
|
||||
}
|
||||
|
||||
void RPC_UnregisterAllChannelHooks(void)
|
||||
{
|
||||
struct channel_hook_entry *cursor;
|
||||
struct channel_hook_entry *cursor2;
|
||||
|
||||
EnterCriticalSection(&csChannelHook);
|
||||
LIST_FOR_EACH_ENTRY_SAFE(cursor, cursor2, &channel_hooks, struct channel_hook_entry, entry)
|
||||
HeapFree(GetProcessHeap(), 0, cursor);
|
||||
LeaveCriticalSection(&csChannelHook);
|
||||
}
|
||||
|
||||
/* RPC Channel Buffer Functions */
|
||||
|
||||
static HRESULT WINAPI RpcChannelBuffer_QueryInterface(LPRPCCHANNELBUFFER iface, REFIID riid, LPVOID *ppv)
|
||||
{
|
||||
|
@ -207,6 +356,11 @@ static HRESULT WINAPI ClientRpcChannelBuffer_GetBuffer(LPRPCCHANNELBUFFER iface,
|
|||
RPC_STATUS status;
|
||||
ORPCTHIS *orpcthis;
|
||||
struct message_state *message_state;
|
||||
ULONG extensions_size;
|
||||
struct channel_hook_buffer_data *channel_hook_data;
|
||||
unsigned int channel_hook_count;
|
||||
ULONG extension_count;
|
||||
SChannelHookCallInfo channel_hook_info;
|
||||
|
||||
TRACE("(%p)->(%p,%s)\n", This, olemsg, debugstr_guid(riid));
|
||||
|
||||
|
@ -230,8 +384,24 @@ static HRESULT WINAPI ClientRpcChannelBuffer_GetBuffer(LPRPCCHANNELBUFFER iface,
|
|||
msg->Handle = This->bind;
|
||||
msg->RpcInterfaceInformation = cif;
|
||||
|
||||
channel_hook_info.iid = *riid;
|
||||
channel_hook_info.cbSize = sizeof(channel_hook_info);
|
||||
channel_hook_info.uCausality = GUID_NULL; /* FIXME */
|
||||
channel_hook_info.dwServerPid = 0; /* FIXME */
|
||||
channel_hook_info.iMethod = msg->ProcNum;
|
||||
channel_hook_info.pObject = NULL; /* only present on server-side */
|
||||
|
||||
extensions_size = ChannelHooks_ClientGetSize(&channel_hook_info,
|
||||
&channel_hook_data, &channel_hook_count, &extension_count);
|
||||
|
||||
msg->BufferLength += FIELD_OFFSET(ORPCTHIS, extensions) + 4;
|
||||
|
||||
if (extensions_size)
|
||||
{
|
||||
msg->BufferLength += FIELD_OFFSET(ORPC_EXTENT_ARRAY, extent) + 2*sizeof(DWORD) + extensions_size;
|
||||
if (extension_count & 1)
|
||||
msg->BufferLength += FIELD_OFFSET(WIRE_ORPC_EXTENT, data[0]);
|
||||
}
|
||||
|
||||
status = I_RpcGetBuffer(msg);
|
||||
|
||||
message_state->prefix_data_len = 0;
|
||||
|
@ -245,14 +415,42 @@ static HRESULT WINAPI ClientRpcChannelBuffer_GetBuffer(LPRPCCHANNELBUFFER iface,
|
|||
|
||||
orpcthis->version.MajorVersion = COM_MAJOR_VERSION;
|
||||
orpcthis->version.MinorVersion = COM_MINOR_VERSION;
|
||||
orpcthis->flags = ORPCF_NULL;
|
||||
orpcthis->flags = channel_hook_info.dwServerPid ? ORPCF_LOCAL : ORPCF_NULL;
|
||||
orpcthis->reserved1 = 0;
|
||||
orpcthis->cid = GUID_NULL; /* FIXME */
|
||||
orpcthis->cid = channel_hook_info.uCausality;
|
||||
|
||||
/* NDR representation of orpcthis->extensions */
|
||||
*(DWORD *)msg->Buffer = 0; /* FIXME */
|
||||
*(DWORD *)msg->Buffer = extensions_size ? 1 : 0;
|
||||
msg->Buffer = (char *)msg->Buffer + sizeof(DWORD);
|
||||
|
||||
if (extensions_size)
|
||||
{
|
||||
ORPC_EXTENT_ARRAY *orpc_extent_array = msg->Buffer;
|
||||
orpc_extent_array->size = extension_count;
|
||||
orpc_extent_array->reserved = 0;
|
||||
msg->Buffer = (char *)msg->Buffer + FIELD_OFFSET(ORPC_EXTENT_ARRAY, extent);
|
||||
/* NDR representation of orpc_extent_array->extent */
|
||||
*(DWORD *)msg->Buffer = 1;
|
||||
msg->Buffer = (char *)msg->Buffer + sizeof(DWORD);
|
||||
/* NDR representation of [size_is] attribute of orpc_extent_array->extent */
|
||||
*(DWORD *)msg->Buffer = (extension_count + 1) & ~1;
|
||||
msg->Buffer = (char *)msg->Buffer + sizeof(DWORD);
|
||||
|
||||
msg->Buffer = ChannelHooks_ClientFillBuffer(&channel_hook_info,
|
||||
msg->Buffer, channel_hook_data, channel_hook_count);
|
||||
|
||||
/* we must add a dummy extension if there is an odd extension
|
||||
* count to meet the contract specified by the size_is attribute */
|
||||
if (extension_count & 1)
|
||||
{
|
||||
WIRE_ORPC_EXTENT *wire_orpc_extent = msg->Buffer;
|
||||
wire_orpc_extent->conformance = 0;
|
||||
memcpy(&wire_orpc_extent->id, &GUID_NULL, sizeof(wire_orpc_extent->id));
|
||||
wire_orpc_extent->size = 0;
|
||||
msg->Buffer = (char *)msg->Buffer + FIELD_OFFSET(WIRE_ORPC_EXTENT, data[0]);
|
||||
}
|
||||
}
|
||||
|
||||
/* store the prefixed data length so that we can restore the real buffer
|
||||
* pointer in ClientRpcChannelBuffer_SendReceive. */
|
||||
message_state->prefix_data_len = (char *)msg->Buffer - (char *)orpcthis;
|
||||
|
|
Loading…
Reference in New Issue