ole32: Implement CoRegisterChannelHook and call channel hook methods on the client side.
This commit is contained in:
parent
1dc5dec6e6
commit
5788ee9f05
|
@ -3296,6 +3296,26 @@ HRESULT WINAPI CoGetObject(LPCWSTR pszName, BIND_OPTS *pBindOptions,
|
||||||
return hr;
|
return hr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/***********************************************************************
|
||||||
|
* CoRegisterChannelHook [OLE32.@]
|
||||||
|
*
|
||||||
|
* Registers a process-wide hook that is called during ORPC calls.
|
||||||
|
*
|
||||||
|
* PARAMS
|
||||||
|
* guidExtension [I] GUID of the channel hook to register.
|
||||||
|
* pChannelHook [I] Channel hook object to register.
|
||||||
|
*
|
||||||
|
* RETURNS
|
||||||
|
* Success: S_OK.
|
||||||
|
* Failure: HRESULT code.
|
||||||
|
*/
|
||||||
|
HRESULT WINAPI CoRegisterChannelHook(REFGUID guidExtension, IChannelHook *pChannelHook)
|
||||||
|
{
|
||||||
|
TRACE("(%s, %p)\n", debugstr_guid(guidExtension), pChannelHook);
|
||||||
|
|
||||||
|
return RPC_RegisterChannelHook(guidExtension, pChannelHook);
|
||||||
|
}
|
||||||
|
|
||||||
/***********************************************************************
|
/***********************************************************************
|
||||||
* DllMain (OLE32.@)
|
* DllMain (OLE32.@)
|
||||||
*/
|
*/
|
||||||
|
@ -3313,6 +3333,7 @@ BOOL WINAPI DllMain(HINSTANCE hinstDLL, DWORD fdwReason, LPVOID fImpLoad)
|
||||||
case DLL_PROCESS_DETACH:
|
case DLL_PROCESS_DETACH:
|
||||||
if (TRACE_ON(ole)) CoRevokeMallocSpy();
|
if (TRACE_ON(ole)) CoRevokeMallocSpy();
|
||||||
COMPOBJ_UninitProcess();
|
COMPOBJ_UninitProcess();
|
||||||
|
RPC_UnregisterAllChannelHooks();
|
||||||
OLE32_hInstance = 0;
|
OLE32_hInstance = 0;
|
||||||
break;
|
break;
|
||||||
|
|
||||||
|
|
|
@ -219,6 +219,8 @@ HRESULT RPC_RegisterInterface(REFIID riid);
|
||||||
void RPC_UnregisterInterface(REFIID riid);
|
void RPC_UnregisterInterface(REFIID riid);
|
||||||
void RPC_StartLocalServer(REFCLSID clsid, IStream *stream);
|
void RPC_StartLocalServer(REFCLSID clsid, IStream *stream);
|
||||||
HRESULT RPC_GetLocalClassObject(REFCLSID rclsid, REFIID iid, LPVOID *ppv);
|
HRESULT RPC_GetLocalClassObject(REFCLSID rclsid, REFIID iid, LPVOID *ppv);
|
||||||
|
HRESULT RPC_RegisterChannelHook(REFGUID rguid, IChannelHook *hook);
|
||||||
|
void RPC_UnregisterAllChannelHooks(void);
|
||||||
|
|
||||||
/* This function initialize the Running Object Table */
|
/* This function initialize the Running Object Table */
|
||||||
HRESULT WINAPI RunningObjectTableImpl_Initialize(void);
|
HRESULT WINAPI RunningObjectTableImpl_Initialize(void);
|
||||||
|
|
|
@ -52,7 +52,7 @@
|
||||||
@ stdcall CoQueryClientBlanket(ptr ptr ptr ptr ptr ptr ptr)
|
@ stdcall CoQueryClientBlanket(ptr ptr ptr ptr ptr ptr ptr)
|
||||||
@ stdcall CoQueryProxyBlanket(ptr ptr ptr ptr ptr ptr ptr ptr)
|
@ stdcall CoQueryProxyBlanket(ptr ptr ptr ptr ptr ptr ptr ptr)
|
||||||
@ stub CoQueryReleaseObject
|
@ stub CoQueryReleaseObject
|
||||||
@ stub CoRegisterChannelHook
|
@ stdcall CoRegisterChannelHook(ptr ptr)
|
||||||
@ stdcall CoRegisterClassObject(ptr ptr long long ptr)
|
@ stdcall CoRegisterClassObject(ptr ptr long long ptr)
|
||||||
@ stdcall CoRegisterMallocSpy (ptr)
|
@ stdcall CoRegisterMallocSpy (ptr)
|
||||||
@ stdcall CoRegisterMessageFilter(ptr ptr)
|
@ stdcall CoRegisterMessageFilter(ptr ptr)
|
||||||
|
|
204
dlls/ole32/rpc.c
204
dlls/ole32/rpc.c
|
@ -68,6 +68,16 @@ static CRITICAL_SECTION_DEBUG csRegIf_debug =
|
||||||
};
|
};
|
||||||
static CRITICAL_SECTION csRegIf = { &csRegIf_debug, -1, 0, 0, 0, 0 };
|
static CRITICAL_SECTION csRegIf = { &csRegIf_debug, -1, 0, 0, 0, 0 };
|
||||||
|
|
||||||
|
static struct list channel_hooks = LIST_INIT(channel_hooks); /* (CS csChannelHook) */
|
||||||
|
static CRITICAL_SECTION csChannelHook;
|
||||||
|
static CRITICAL_SECTION_DEBUG csChannelHook_debug =
|
||||||
|
{
|
||||||
|
0, 0, &csChannelHook,
|
||||||
|
{ &csChannelHook_debug.ProcessLocksList, &csChannelHook_debug.ProcessLocksList },
|
||||||
|
0, 0, { (DWORD_PTR)(__FILE__ ": channel hooks") }
|
||||||
|
};
|
||||||
|
static CRITICAL_SECTION csChannelHook = { &csChannelHook_debug, -1, 0, 0, 0, 0 };
|
||||||
|
|
||||||
static WCHAR wszRpcTransport[] = {'n','c','a','l','r','p','c',0};
|
static WCHAR wszRpcTransport[] = {'n','c','a','l','r','p','c',0};
|
||||||
|
|
||||||
|
|
||||||
|
@ -127,6 +137,145 @@ typedef struct
|
||||||
/* [size_is((size+7)&~7)] */ unsigned char data[1];
|
/* [size_is((size+7)&~7)] */ unsigned char data[1];
|
||||||
} WIRE_ORPC_EXTENT;
|
} WIRE_ORPC_EXTENT;
|
||||||
|
|
||||||
|
struct channel_hook_entry
|
||||||
|
{
|
||||||
|
struct list entry;
|
||||||
|
GUID id;
|
||||||
|
IChannelHook *hook;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct channel_hook_buffer_data
|
||||||
|
{
|
||||||
|
GUID id;
|
||||||
|
ULONG extension_size;
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
/* Channel Hook Functions */
|
||||||
|
|
||||||
|
static ULONG ChannelHooks_ClientGetSize(SChannelHookCallInfo *info,
|
||||||
|
struct channel_hook_buffer_data **data, unsigned int *hook_count,
|
||||||
|
ULONG *extension_count)
|
||||||
|
{
|
||||||
|
struct channel_hook_entry *entry;
|
||||||
|
ULONG total_size = 0;
|
||||||
|
unsigned int hook_index = 0;
|
||||||
|
|
||||||
|
*hook_count = 0;
|
||||||
|
*extension_count = 0;
|
||||||
|
|
||||||
|
EnterCriticalSection(&csChannelHook);
|
||||||
|
|
||||||
|
LIST_FOR_EACH_ENTRY(entry, &channel_hooks, struct channel_hook_entry, entry)
|
||||||
|
(*hook_count)++;
|
||||||
|
|
||||||
|
if (hook_count)
|
||||||
|
*data = HeapAlloc(GetProcessHeap(), 0, *hook_count * sizeof(struct channel_hook_buffer_data));
|
||||||
|
else
|
||||||
|
*data = NULL;
|
||||||
|
|
||||||
|
LIST_FOR_EACH_ENTRY(entry, &channel_hooks, struct channel_hook_entry, entry)
|
||||||
|
{
|
||||||
|
ULONG extension_size = 0;
|
||||||
|
|
||||||
|
IChannelHook_ClientGetSize(entry->hook, &entry->id, &info->iid, &extension_size);
|
||||||
|
|
||||||
|
TRACE("%s: extension_size = %u\n", debugstr_guid(&entry->id), extension_size);
|
||||||
|
|
||||||
|
extension_size = (extension_size+7)&~7;
|
||||||
|
(*data)[hook_index].id = entry->id;
|
||||||
|
(*data)[hook_index].extension_size = extension_size;
|
||||||
|
|
||||||
|
/* an extension is only put onto the wire if it has data to write */
|
||||||
|
if (extension_size)
|
||||||
|
{
|
||||||
|
total_size += FIELD_OFFSET(WIRE_ORPC_EXTENT, data[extension_size]);
|
||||||
|
(*extension_count)++;
|
||||||
|
}
|
||||||
|
|
||||||
|
hook_index++;
|
||||||
|
}
|
||||||
|
|
||||||
|
LeaveCriticalSection(&csChannelHook);
|
||||||
|
|
||||||
|
return total_size;
|
||||||
|
}
|
||||||
|
|
||||||
|
static unsigned char * ChannelHooks_ClientFillBuffer(SChannelHookCallInfo *info,
|
||||||
|
unsigned char *buffer, struct channel_hook_buffer_data *data,
|
||||||
|
unsigned int hook_count)
|
||||||
|
{
|
||||||
|
struct channel_hook_entry *entry;
|
||||||
|
|
||||||
|
EnterCriticalSection(&csChannelHook);
|
||||||
|
|
||||||
|
LIST_FOR_EACH_ENTRY(entry, &channel_hooks, struct channel_hook_entry, entry)
|
||||||
|
{
|
||||||
|
unsigned int i;
|
||||||
|
ULONG extension_size = 0;
|
||||||
|
WIRE_ORPC_EXTENT *wire_orpc_extent = (WIRE_ORPC_EXTENT *)buffer;
|
||||||
|
|
||||||
|
for (i = 0; i < hook_count; i++)
|
||||||
|
if (IsEqualGUID(&entry->id, &data[i].id))
|
||||||
|
extension_size = data[i].extension_size;
|
||||||
|
|
||||||
|
/* an extension is only put onto the wire if it has data to write */
|
||||||
|
if (!extension_size)
|
||||||
|
continue;
|
||||||
|
|
||||||
|
IChannelHook_ClientFillBuffer(entry->hook, &entry->id, &info->iid,
|
||||||
|
&extension_size, buffer + FIELD_OFFSET(WIRE_ORPC_EXTENT, data[0]));
|
||||||
|
|
||||||
|
TRACE("%s: extension_size = %u\n", debugstr_guid(&entry->id), extension_size);
|
||||||
|
|
||||||
|
/* FIXME: set unused portion of wire_orpc_extent->data to 0? */
|
||||||
|
|
||||||
|
wire_orpc_extent->conformance = (extension_size+7)&~7;
|
||||||
|
wire_orpc_extent->size = extension_size;
|
||||||
|
memcpy(&wire_orpc_extent->id, &entry->id, sizeof(wire_orpc_extent->id));
|
||||||
|
buffer += FIELD_OFFSET(WIRE_ORPC_EXTENT, data[wire_orpc_extent->conformance]);
|
||||||
|
}
|
||||||
|
|
||||||
|
LeaveCriticalSection(&csChannelHook);
|
||||||
|
|
||||||
|
HeapFree(GetProcessHeap(), 0, data);
|
||||||
|
|
||||||
|
return buffer;
|
||||||
|
}
|
||||||
|
|
||||||
|
HRESULT RPC_RegisterChannelHook(REFGUID rguid, IChannelHook *hook)
|
||||||
|
{
|
||||||
|
struct channel_hook_entry *entry;
|
||||||
|
|
||||||
|
TRACE("(%s, %p)\n", debugstr_guid(rguid), hook);
|
||||||
|
|
||||||
|
entry = HeapAlloc(GetProcessHeap(), 0, sizeof(*entry));
|
||||||
|
if (!entry)
|
||||||
|
return E_OUTOFMEMORY;
|
||||||
|
|
||||||
|
memcpy(&entry->id, rguid, sizeof(entry->id));
|
||||||
|
entry->hook = hook;
|
||||||
|
IChannelHook_AddRef(hook);
|
||||||
|
|
||||||
|
EnterCriticalSection(&csChannelHook);
|
||||||
|
list_add_tail(&channel_hooks, &entry->entry);
|
||||||
|
LeaveCriticalSection(&csChannelHook);
|
||||||
|
|
||||||
|
return S_OK;
|
||||||
|
}
|
||||||
|
|
||||||
|
void RPC_UnregisterAllChannelHooks(void)
|
||||||
|
{
|
||||||
|
struct channel_hook_entry *cursor;
|
||||||
|
struct channel_hook_entry *cursor2;
|
||||||
|
|
||||||
|
EnterCriticalSection(&csChannelHook);
|
||||||
|
LIST_FOR_EACH_ENTRY_SAFE(cursor, cursor2, &channel_hooks, struct channel_hook_entry, entry)
|
||||||
|
HeapFree(GetProcessHeap(), 0, cursor);
|
||||||
|
LeaveCriticalSection(&csChannelHook);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* RPC Channel Buffer Functions */
|
||||||
|
|
||||||
static HRESULT WINAPI RpcChannelBuffer_QueryInterface(LPRPCCHANNELBUFFER iface, REFIID riid, LPVOID *ppv)
|
static HRESULT WINAPI RpcChannelBuffer_QueryInterface(LPRPCCHANNELBUFFER iface, REFIID riid, LPVOID *ppv)
|
||||||
{
|
{
|
||||||
|
@ -207,6 +356,11 @@ static HRESULT WINAPI ClientRpcChannelBuffer_GetBuffer(LPRPCCHANNELBUFFER iface,
|
||||||
RPC_STATUS status;
|
RPC_STATUS status;
|
||||||
ORPCTHIS *orpcthis;
|
ORPCTHIS *orpcthis;
|
||||||
struct message_state *message_state;
|
struct message_state *message_state;
|
||||||
|
ULONG extensions_size;
|
||||||
|
struct channel_hook_buffer_data *channel_hook_data;
|
||||||
|
unsigned int channel_hook_count;
|
||||||
|
ULONG extension_count;
|
||||||
|
SChannelHookCallInfo channel_hook_info;
|
||||||
|
|
||||||
TRACE("(%p)->(%p,%s)\n", This, olemsg, debugstr_guid(riid));
|
TRACE("(%p)->(%p,%s)\n", This, olemsg, debugstr_guid(riid));
|
||||||
|
|
||||||
|
@ -230,7 +384,23 @@ static HRESULT WINAPI ClientRpcChannelBuffer_GetBuffer(LPRPCCHANNELBUFFER iface,
|
||||||
msg->Handle = This->bind;
|
msg->Handle = This->bind;
|
||||||
msg->RpcInterfaceInformation = cif;
|
msg->RpcInterfaceInformation = cif;
|
||||||
|
|
||||||
|
channel_hook_info.iid = *riid;
|
||||||
|
channel_hook_info.cbSize = sizeof(channel_hook_info);
|
||||||
|
channel_hook_info.uCausality = GUID_NULL; /* FIXME */
|
||||||
|
channel_hook_info.dwServerPid = 0; /* FIXME */
|
||||||
|
channel_hook_info.iMethod = msg->ProcNum;
|
||||||
|
channel_hook_info.pObject = NULL; /* only present on server-side */
|
||||||
|
|
||||||
|
extensions_size = ChannelHooks_ClientGetSize(&channel_hook_info,
|
||||||
|
&channel_hook_data, &channel_hook_count, &extension_count);
|
||||||
|
|
||||||
msg->BufferLength += FIELD_OFFSET(ORPCTHIS, extensions) + 4;
|
msg->BufferLength += FIELD_OFFSET(ORPCTHIS, extensions) + 4;
|
||||||
|
if (extensions_size)
|
||||||
|
{
|
||||||
|
msg->BufferLength += FIELD_OFFSET(ORPC_EXTENT_ARRAY, extent) + 2*sizeof(DWORD) + extensions_size;
|
||||||
|
if (extension_count & 1)
|
||||||
|
msg->BufferLength += FIELD_OFFSET(WIRE_ORPC_EXTENT, data[0]);
|
||||||
|
}
|
||||||
|
|
||||||
status = I_RpcGetBuffer(msg);
|
status = I_RpcGetBuffer(msg);
|
||||||
|
|
||||||
|
@ -245,14 +415,42 @@ static HRESULT WINAPI ClientRpcChannelBuffer_GetBuffer(LPRPCCHANNELBUFFER iface,
|
||||||
|
|
||||||
orpcthis->version.MajorVersion = COM_MAJOR_VERSION;
|
orpcthis->version.MajorVersion = COM_MAJOR_VERSION;
|
||||||
orpcthis->version.MinorVersion = COM_MINOR_VERSION;
|
orpcthis->version.MinorVersion = COM_MINOR_VERSION;
|
||||||
orpcthis->flags = ORPCF_NULL;
|
orpcthis->flags = channel_hook_info.dwServerPid ? ORPCF_LOCAL : ORPCF_NULL;
|
||||||
orpcthis->reserved1 = 0;
|
orpcthis->reserved1 = 0;
|
||||||
orpcthis->cid = GUID_NULL; /* FIXME */
|
orpcthis->cid = channel_hook_info.uCausality;
|
||||||
|
|
||||||
/* NDR representation of orpcthis->extensions */
|
/* NDR representation of orpcthis->extensions */
|
||||||
*(DWORD *)msg->Buffer = 0; /* FIXME */
|
*(DWORD *)msg->Buffer = extensions_size ? 1 : 0;
|
||||||
msg->Buffer = (char *)msg->Buffer + sizeof(DWORD);
|
msg->Buffer = (char *)msg->Buffer + sizeof(DWORD);
|
||||||
|
|
||||||
|
if (extensions_size)
|
||||||
|
{
|
||||||
|
ORPC_EXTENT_ARRAY *orpc_extent_array = msg->Buffer;
|
||||||
|
orpc_extent_array->size = extension_count;
|
||||||
|
orpc_extent_array->reserved = 0;
|
||||||
|
msg->Buffer = (char *)msg->Buffer + FIELD_OFFSET(ORPC_EXTENT_ARRAY, extent);
|
||||||
|
/* NDR representation of orpc_extent_array->extent */
|
||||||
|
*(DWORD *)msg->Buffer = 1;
|
||||||
|
msg->Buffer = (char *)msg->Buffer + sizeof(DWORD);
|
||||||
|
/* NDR representation of [size_is] attribute of orpc_extent_array->extent */
|
||||||
|
*(DWORD *)msg->Buffer = (extension_count + 1) & ~1;
|
||||||
|
msg->Buffer = (char *)msg->Buffer + sizeof(DWORD);
|
||||||
|
|
||||||
|
msg->Buffer = ChannelHooks_ClientFillBuffer(&channel_hook_info,
|
||||||
|
msg->Buffer, channel_hook_data, channel_hook_count);
|
||||||
|
|
||||||
|
/* we must add a dummy extension if there is an odd extension
|
||||||
|
* count to meet the contract specified by the size_is attribute */
|
||||||
|
if (extension_count & 1)
|
||||||
|
{
|
||||||
|
WIRE_ORPC_EXTENT *wire_orpc_extent = msg->Buffer;
|
||||||
|
wire_orpc_extent->conformance = 0;
|
||||||
|
memcpy(&wire_orpc_extent->id, &GUID_NULL, sizeof(wire_orpc_extent->id));
|
||||||
|
wire_orpc_extent->size = 0;
|
||||||
|
msg->Buffer = (char *)msg->Buffer + FIELD_OFFSET(WIRE_ORPC_EXTENT, data[0]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/* store the prefixed data length so that we can restore the real buffer
|
/* store the prefixed data length so that we can restore the real buffer
|
||||||
* pointer in ClientRpcChannelBuffer_SendReceive. */
|
* pointer in ClientRpcChannelBuffer_SendReceive. */
|
||||||
message_state->prefix_data_len = (char *)msg->Buffer - (char *)orpcthis;
|
message_state->prefix_data_len = (char *)msg->Buffer - (char *)orpcthis;
|
||||||
|
|
Loading…
Reference in New Issue