ole32: Marshal the ORPCTHAT structure prefixed to the server data.

Unmarshal the data on the client side (during
ClientChannelBuffer_SendReceive) and call ClientNotify.
This commit is contained in:
Rob Shearman 2006-12-27 13:04:07 +00:00 committed by Alexandre Julliard
parent c7e00c9f49
commit ee99b6d743
2 changed files with 317 additions and 52 deletions

View File

@ -36,7 +36,6 @@
*
* - Make all ole interface marshaling use NDR to be wire compatible with
* native DCOM
* - Use & interpret ORPCTHIS & ORPCTHAT.
*
*/

View File

@ -154,6 +154,9 @@ struct channel_hook_buffer_data
};
static HRESULT unmarshal_ORPCTHAT(RPC_MESSAGE *msg, ORPCTHAT *orpcthat,
ORPC_EXTENT_ARRAY *orpc_ext_array, WIRE_ORPC_EXTENT **first_wire_orpc_extent);
/* Channel Hook Functions */
static ULONG ChannelHooks_ClientGetSize(SChannelHookCallInfo *info,
@ -172,7 +175,7 @@ static ULONG ChannelHooks_ClientGetSize(SChannelHookCallInfo *info,
LIST_FOR_EACH_ENTRY(entry, &channel_hooks, struct channel_hook_entry, entry)
(*hook_count)++;
if (hook_count)
if (*hook_count)
*data = HeapAlloc(GetProcessHeap(), 0, *hook_count * sizeof(struct channel_hook_buffer_data));
else
*data = NULL;
@ -276,6 +279,128 @@ static void ChannelHooks_ServerNotify(SChannelHookCallInfo *info,
LeaveCriticalSection(&csChannelHook);
}
static ULONG ChannelHooks_ServerGetSize(SChannelHookCallInfo *info,
struct channel_hook_buffer_data **data, unsigned int *hook_count,
ULONG *extension_count)
{
struct channel_hook_entry *entry;
ULONG total_size = 0;
unsigned int hook_index = 0;
*hook_count = 0;
*extension_count = 0;
EnterCriticalSection(&csChannelHook);
LIST_FOR_EACH_ENTRY(entry, &channel_hooks, struct channel_hook_entry, entry)
(*hook_count)++;
if (*hook_count)
*data = HeapAlloc(GetProcessHeap(), 0, *hook_count * sizeof(struct channel_hook_buffer_data));
else
*data = NULL;
LIST_FOR_EACH_ENTRY(entry, &channel_hooks, struct channel_hook_entry, entry)
{
ULONG extension_size = 0;
IChannelHook_ServerGetSize(entry->hook, &entry->id, &info->iid, S_OK,
&extension_size);
TRACE("%s: extension_size = %u\n", debugstr_guid(&entry->id), extension_size);
extension_size = (extension_size+7)&~7;
(*data)[hook_index].id = entry->id;
(*data)[hook_index].extension_size = extension_size;
/* an extension is only put onto the wire if it has data to write */
if (extension_size)
{
total_size += FIELD_OFFSET(WIRE_ORPC_EXTENT, data[extension_size]);
(*extension_count)++;
}
hook_index++;
}
LeaveCriticalSection(&csChannelHook);
return total_size;
}
static unsigned char * ChannelHooks_ServerFillBuffer(SChannelHookCallInfo *info,
unsigned char *buffer, struct channel_hook_buffer_data *data,
unsigned int hook_count)
{
struct channel_hook_entry *entry;
EnterCriticalSection(&csChannelHook);
LIST_FOR_EACH_ENTRY(entry, &channel_hooks, struct channel_hook_entry, entry)
{
unsigned int i;
ULONG extension_size = 0;
WIRE_ORPC_EXTENT *wire_orpc_extent = (WIRE_ORPC_EXTENT *)buffer;
for (i = 0; i < hook_count; i++)
if (IsEqualGUID(&entry->id, &data[i].id))
extension_size = data[i].extension_size;
/* an extension is only put onto the wire if it has data to write */
if (!extension_size)
continue;
IChannelHook_ServerFillBuffer(entry->hook, &entry->id, &info->iid,
&extension_size, buffer + FIELD_OFFSET(WIRE_ORPC_EXTENT, data[0]),
S_OK);
TRACE("%s: extension_size = %u\n", debugstr_guid(&entry->id), extension_size);
/* FIXME: set unused portion of wire_orpc_extent->data to 0? */
wire_orpc_extent->conformance = (extension_size+7)&~7;
wire_orpc_extent->size = extension_size;
memcpy(&wire_orpc_extent->id, &entry->id, sizeof(wire_orpc_extent->id));
buffer += FIELD_OFFSET(WIRE_ORPC_EXTENT, data[wire_orpc_extent->conformance]);
}
LeaveCriticalSection(&csChannelHook);
HeapFree(GetProcessHeap(), 0, data);
return buffer;
}
static void ChannelHooks_ClientNotify(SChannelHookCallInfo *info,
DWORD lDataRep, WIRE_ORPC_EXTENT *first_wire_orpc_extent,
ULONG extension_count, HRESULT hrFault)
{
struct channel_hook_entry *entry;
ULONG i;
EnterCriticalSection(&csChannelHook);
LIST_FOR_EACH_ENTRY(entry, &channel_hooks, struct channel_hook_entry, entry)
{
WIRE_ORPC_EXTENT *wire_orpc_extent;
for (i = 0, wire_orpc_extent = first_wire_orpc_extent;
i < extension_count;
i++, wire_orpc_extent = (WIRE_ORPC_EXTENT *)&wire_orpc_extent->data[wire_orpc_extent->conformance])
{
if (IsEqualGUID(&entry->id, &wire_orpc_extent->id))
break;
}
if (i == extension_count) wire_orpc_extent = NULL;
IChannelHook_ClientNotify(entry->hook, &entry->id, &info->iid,
wire_orpc_extent ? wire_orpc_extent->size : 0,
wire_orpc_extent ? wire_orpc_extent->data : NULL,
lDataRep, hrFault);
}
LeaveCriticalSection(&csChannelHook);
}
HRESULT RPC_RegisterChannelHook(REFGUID rguid, IChannelHook *hook)
{
struct channel_hook_entry *entry;
@ -361,7 +486,12 @@ static HRESULT WINAPI ServerRpcChannelBuffer_GetBuffer(LPRPCCHANNELBUFFER iface,
RpcChannelBuffer *This = (RpcChannelBuffer *)iface;
RPC_MESSAGE *msg = (RPC_MESSAGE *)olemsg;
RPC_STATUS status;
ORPCTHAT *orpcthat;
struct message_state *message_state;
ULONG extensions_size;
struct channel_hook_buffer_data *channel_hook_data;
unsigned int channel_hook_count;
ULONG extension_count;
TRACE("(%p)->(%p,%s)\n", This, olemsg, debugstr_guid(riid));
@ -370,11 +500,62 @@ static HRESULT WINAPI ServerRpcChannelBuffer_GetBuffer(LPRPCCHANNELBUFFER iface,
msg->Handle = message_state->binding_handle;
msg->Buffer = (char *)msg->Buffer - message_state->prefix_data_len;
extensions_size = ChannelHooks_ServerGetSize(&message_state->channel_hook_info,
&channel_hook_data, &channel_hook_count, &extension_count);
msg->BufferLength += FIELD_OFFSET(ORPCTHAT, extensions) + 4;
if (extensions_size)
{
msg->BufferLength += FIELD_OFFSET(ORPC_EXTENT_ARRAY, extent) + 2*sizeof(DWORD) + extensions_size;
if (extension_count & 1)
msg->BufferLength += FIELD_OFFSET(WIRE_ORPC_EXTENT, data[0]);
}
status = I_RpcGetBuffer(msg);
orpcthat = (ORPCTHAT *)msg->Buffer;
msg->Buffer = (char *)msg->Buffer + FIELD_OFFSET(ORPCTHAT, extensions);
orpcthat->flags = ORPCF_NULL /* FIXME? */;
/* NDR representation of orpcthat->extensions */
*(DWORD *)msg->Buffer = extensions_size ? 1 : 0;
msg->Buffer = (char *)msg->Buffer + sizeof(DWORD);
if (extensions_size)
{
ORPC_EXTENT_ARRAY *orpc_extent_array = msg->Buffer;
orpc_extent_array->size = extension_count;
orpc_extent_array->reserved = 0;
msg->Buffer = (char *)msg->Buffer + FIELD_OFFSET(ORPC_EXTENT_ARRAY, extent);
/* NDR representation of orpc_extent_array->extent */
*(DWORD *)msg->Buffer = 1;
msg->Buffer = (char *)msg->Buffer + sizeof(DWORD);
/* NDR representation of [size_is] attribute of orpc_extent_array->extent */
*(DWORD *)msg->Buffer = (extension_count + 1) & ~1;
msg->Buffer = (char *)msg->Buffer + sizeof(DWORD);
msg->Buffer = ChannelHooks_ServerFillBuffer(&message_state->channel_hook_info,
msg->Buffer, channel_hook_data, channel_hook_count);
/* we must add a dummy extension if there is an odd extension
* count to meet the contract specified by the size_is attribute */
if (extension_count & 1)
{
WIRE_ORPC_EXTENT *wire_orpc_extent = msg->Buffer;
wire_orpc_extent->conformance = 0;
memcpy(&wire_orpc_extent->id, &GUID_NULL, sizeof(wire_orpc_extent->id));
wire_orpc_extent->size = 0;
msg->Buffer = (char *)msg->Buffer + FIELD_OFFSET(WIRE_ORPC_EXTENT, data[0]);
}
}
/* store the prefixed data length so that we can restore the real buffer
* later */
message_state->prefix_data_len = (char *)msg->Buffer - (char *)orpcthat;
msg->BufferLength -= message_state->prefix_data_len;
/* save away the message state again */
msg->Handle = message_state;
message_state->prefix_data_len = 0;
TRACE("-- %ld\n", status);
@ -556,6 +737,9 @@ static HRESULT WINAPI ClientRpcChannelBuffer_SendReceive(LPRPCCHANNELBUFFER ifac
APARTMENT *apt = NULL;
IPID ipid;
struct message_state *message_state;
ORPCTHAT orpcthat;
ORPC_EXTENT_ARRAY orpc_ext_array;
WIRE_ORPC_EXTENT *first_wire_orpc_extent = NULL;
TRACE("(%p) iMethod=%d\n", olemsg, olemsg->iMethod);
@ -652,18 +836,43 @@ static HRESULT WINAPI ClientRpcChannelBuffer_SendReceive(LPRPCCHANNELBUFFER ifac
}
ClientRpcChannelBuffer_ReleaseEventHandle(This, params->handle);
/* save away the message state again */
msg->Handle = message_state;
message_state->prefix_data_len = 0;
if (hr == S_OK) hr = params->hr;
status = params->status;
HeapFree(GetProcessHeap(), 0, params);
params = NULL;
orpcthat.flags = ORPCF_NULL;
orpcthat.extensions = NULL;
if (status == RPC_S_OK && msg->BufferLength > FIELD_OFFSET(ORPCTHAT, extensions) + 4)
{
HRESULT hr2;
char *original_buffer = msg->Buffer;
/* handle ORPCTHAT and client extensions */
hr2 = unmarshal_ORPCTHAT(msg, &orpcthat, &orpc_ext_array, &first_wire_orpc_extent);
if (FAILED(hr2))
hr = hr2;
message_state->prefix_data_len = original_buffer - (char *)msg->Buffer;
msg->BufferLength -= message_state->prefix_data_len;
}
else
message_state->prefix_data_len = 0;
ChannelHooks_ClientNotify(&message_state->channel_hook_info,
msg->DataRepresentation,
first_wire_orpc_extent,
orpcthat.extensions && first_wire_orpc_extent ? orpcthat.extensions->size : 0,
status == RPC_S_CALL_FAILED && msg->BufferLength >= sizeof(HRESULT) ? *(HRESULT *)msg->Buffer : S_OK);
/* save away the message state again */
msg->Handle = message_state;
if (hr) return hr;
if (pstatus) *pstatus = status;
TRACE("RPC call status: 0x%lx\n", status);
@ -856,6 +1065,60 @@ HRESULT RPC_CreateServerChannel(IRpcChannelBuffer **chan)
return S_OK;
}
/* unmarshals ORPC_EXTENT_ARRAY according to NDR rules, but doesn't allocate
* any memory */
static HRESULT unmarshal_ORPC_EXTENT_ARRAY(RPC_MESSAGE *msg, const char *end,
ORPC_EXTENT_ARRAY *extensions,
WIRE_ORPC_EXTENT **first_wire_orpc_extent)
{
DWORD pointer_id;
DWORD i;
memcpy(extensions, msg->Buffer, FIELD_OFFSET(ORPC_EXTENT_ARRAY, extent));
msg->Buffer = (char *)msg->Buffer + FIELD_OFFSET(ORPC_EXTENT_ARRAY, extent);
if ((const char *)msg->Buffer + 2 * sizeof(DWORD) > end)
return RPC_E_INVALID_HEADER;
pointer_id = *(DWORD *)msg->Buffer;
msg->Buffer = (char *)msg->Buffer + sizeof(DWORD);
extensions->extent = NULL;
if (pointer_id)
{
WIRE_ORPC_EXTENT *wire_orpc_extent;
/* conformance */
if (*(DWORD *)msg->Buffer != ((extensions->size+1)&~1))
return RPC_S_INVALID_BOUND;
msg->Buffer = (char *)msg->Buffer + sizeof(DWORD);
/* arbritary limit for security (don't know what native does) */
if (extensions->size > 256)
{
ERR("too many extensions: %ld\n", extensions->size);
return RPC_S_INVALID_BOUND;
}
*first_wire_orpc_extent = wire_orpc_extent = (WIRE_ORPC_EXTENT *)msg->Buffer;
for (i = 0; i < ((extensions->size+1)&~1); i++)
{
if ((const char *)&wire_orpc_extent->data[0] > end)
return RPC_S_INVALID_BOUND;
if (wire_orpc_extent->conformance != ((wire_orpc_extent->size+7)&~7))
return RPC_S_INVALID_BOUND;
if ((const char *)&wire_orpc_extent->data[wire_orpc_extent->conformance] > end)
return RPC_S_INVALID_BOUND;
TRACE("size %u, guid %s\n", wire_orpc_extent->size, debugstr_guid(&wire_orpc_extent->id));
wire_orpc_extent = (WIRE_ORPC_EXTENT *)&wire_orpc_extent->data[wire_orpc_extent->conformance];
}
msg->Buffer = wire_orpc_extent;
}
return S_OK;
}
/* unmarshals ORPCTHIS according to NDR rules, but doesn't allocate any memory */
static HRESULT unmarshal_ORPCTHIS(RPC_MESSAGE *msg, ORPCTHIS *orpcthis,
ORPC_EXTENT_ARRAY *orpc_ext_array, WIRE_ORPC_EXTENT **first_wire_orpc_extent)
@ -885,50 +1148,10 @@ static HRESULT unmarshal_ORPCTHIS(RPC_MESSAGE *msg, ORPCTHIS *orpcthis,
if (orpcthis->extensions)
{
DWORD pointer_id;
DWORD i;
memcpy(orpcthis->extensions, msg->Buffer, FIELD_OFFSET(ORPC_EXTENT_ARRAY, extent));
msg->Buffer = (char *)msg->Buffer + FIELD_OFFSET(ORPC_EXTENT_ARRAY, extent);
if ((const char *)msg->Buffer + 2 * sizeof(DWORD) > end)
return RPC_E_INVALID_HEADER;
pointer_id = *(DWORD *)msg->Buffer;
msg->Buffer = (char *)msg->Buffer + sizeof(DWORD);
orpcthis->extensions->extent = NULL;
if (pointer_id)
{
WIRE_ORPC_EXTENT *wire_orpc_extent;
/* conformance */
if (*(DWORD *)msg->Buffer != ((orpcthis->extensions->size+1)&~1))
return RPC_S_INVALID_BOUND;
msg->Buffer = (char *)msg->Buffer + sizeof(DWORD);
/* arbritary limit for security (don't know what native does) */
if (orpcthis->extensions->size > 256)
{
ERR("too many extensions: %ld\n", orpcthis->extensions->size);
return RPC_S_INVALID_BOUND;
}
*first_wire_orpc_extent = wire_orpc_extent = (WIRE_ORPC_EXTENT *)msg->Buffer;
for (i = 0; i < ((orpcthis->extensions->size+1)&~1); i++)
{
if ((const char *)&wire_orpc_extent->data[0] > end)
return RPC_S_INVALID_BOUND;
if (wire_orpc_extent->conformance != ((wire_orpc_extent->size+7)&~7))
return RPC_S_INVALID_BOUND;
if ((const char *)&wire_orpc_extent->data[wire_orpc_extent->conformance] > end)
return RPC_S_INVALID_BOUND;
TRACE("size %u, guid %s\n", wire_orpc_extent->size, debugstr_guid(&wire_orpc_extent->id));
wire_orpc_extent = (WIRE_ORPC_EXTENT *)&wire_orpc_extent->data[wire_orpc_extent->conformance];
}
msg->Buffer = wire_orpc_extent;
}
HRESULT hr = unmarshal_ORPC_EXTENT_ARRAY(msg, end, orpc_ext_array,
first_wire_orpc_extent);
if (FAILED(hr))
return hr;
}
if ((orpcthis->version.MajorVersion != COM_MAJOR_VERSION) ||
@ -948,6 +1171,49 @@ static HRESULT unmarshal_ORPCTHIS(RPC_MESSAGE *msg, ORPCTHIS *orpcthis,
return S_OK;
}
static HRESULT unmarshal_ORPCTHAT(RPC_MESSAGE *msg, ORPCTHAT *orpcthat,
ORPC_EXTENT_ARRAY *orpc_ext_array, WIRE_ORPC_EXTENT **first_wire_orpc_extent)
{
const char *end = (char *)msg->Buffer + msg->BufferLength;
*first_wire_orpc_extent = NULL;
if (msg->BufferLength < FIELD_OFFSET(ORPCTHAT, extensions) + 4)
{
ERR("invalid buffer length\n");
return RPC_E_INVALID_HEADER;
}
memcpy(orpcthat, msg->Buffer, FIELD_OFFSET(ORPCTHAT, extensions));
msg->Buffer = (char *)msg->Buffer + FIELD_OFFSET(ORPCTHAT, extensions);
if ((const char *)msg->Buffer + sizeof(DWORD) > end)
return RPC_E_INVALID_HEADER;
if (*(DWORD *)msg->Buffer)
orpcthat->extensions = orpc_ext_array;
else
orpcthat->extensions = NULL;
msg->Buffer = (char *)msg->Buffer + sizeof(DWORD);
if (orpcthat->extensions)
{
HRESULT hr = unmarshal_ORPC_EXTENT_ARRAY(msg, end, orpc_ext_array,
first_wire_orpc_extent);
if (FAILED(hr))
return hr;
}
if (orpcthat->flags & ~(ORPCF_LOCAL|ORPCF_RESERVED1|ORPCF_RESERVED2|ORPCF_RESERVED3|ORPCF_RESERVED4))
{
ERR("invalid flags 0x%lx\n", orpcthat->flags & ~(ORPCF_LOCAL|ORPCF_RESERVED1|ORPCF_RESERVED2|ORPCF_RESERVED3|ORPCF_RESERVED4));
return RPC_E_INVALID_HEADER;
}
return S_OK;
}
void RPC_ExecuteCall(struct dispatch_params *params)
{
struct message_state *message_state = NULL;