diff --git a/dlls/rpcrt4/ndr_contexthandle.c b/dlls/rpcrt4/ndr_contexthandle.c index 21efaeeb07e..0b3eeb1fcd4 100644 --- a/dlls/rpcrt4/ndr_contexthandle.c +++ b/dlls/rpcrt4/ndr_contexthandle.c @@ -20,6 +20,7 @@ */ #include "ndr_misc.h" +#include "rpc_assoc.h" #include "rpcndr.h" #include "wine/rpcfc.h" @@ -33,7 +34,7 @@ WINE_DEFAULT_DEBUG_CHANNEL(ole); typedef struct ndr_context_handle { - DWORD attributes; + ULONG attributes; GUID uuid; } ndr_context_handle; @@ -212,35 +213,66 @@ void WINAPI NDRCContextUnmarshall(NDR_CCONTEXT *CContext, /*********************************************************************** * NDRSContextMarshall [RPCRT4.@] */ -void WINAPI NDRSContextMarshall(NDR_SCONTEXT CContext, +void WINAPI NDRSContextMarshall(NDR_SCONTEXT SContext, void *pBuff, NDR_RUNDOWN userRunDownIn) { - FIXME("(%p %p %p): stub\n", CContext, pBuff, userRunDownIn); + TRACE("(%p %p %p)\n", SContext, pBuff, userRunDownIn); + NDRSContextMarshall2(I_RpcGetCurrentCallHandle(), SContext, pBuff, userRunDownIn, NULL, 0); } /*********************************************************************** * NDRSContextMarshallEx [RPCRT4.@] */ void WINAPI NDRSContextMarshallEx(RPC_BINDING_HANDLE hBinding, - NDR_SCONTEXT CContext, + NDR_SCONTEXT SContext, void *pBuff, NDR_RUNDOWN userRunDownIn) { - FIXME("(%p %p %p %p): stub\n", hBinding, CContext, pBuff, userRunDownIn); + TRACE("(%p %p %p %p)\n", hBinding, SContext, pBuff, userRunDownIn); + NDRSContextMarshall2(hBinding, SContext, pBuff, userRunDownIn, NULL, 0); } /*********************************************************************** * NDRSContextMarshall2 [RPCRT4.@] */ void WINAPI NDRSContextMarshall2(RPC_BINDING_HANDLE hBinding, - NDR_SCONTEXT CContext, + NDR_SCONTEXT SContext, void *pBuff, NDR_RUNDOWN userRunDownIn, void *CtxGuard, ULONG Flags) { - FIXME("(%p %p %p %p %p %u): stub\n", - hBinding, CContext, pBuff, userRunDownIn, CtxGuard, Flags); + RpcBinding *binding = hBinding; + RPC_STATUS status; + ndr_context_handle *ndr = pBuff; + + TRACE("(%p %p %p %p %p %u)\n", + hBinding, SContext, pBuff, userRunDownIn, CtxGuard, Flags); + + if (!binding->server || !binding->Assoc) + RpcRaiseException(ERROR_INVALID_HANDLE); + + if (SContext->userContext) + { + status = RpcServerAssoc_UpdateContextHandle(binding->Assoc, SContext, CtxGuard, userRunDownIn); + if (status != RPC_S_OK) + RpcRaiseException(status); + ndr->attributes = 0; + RpcContextHandle_GetUuid(SContext, &ndr->uuid); + } + else + { + if (!RpcContextHandle_IsGuardCorrect(SContext, CtxGuard)) + RpcRaiseException(ERROR_INVALID_HANDLE); + memset(ndr, 0, sizeof(*ndr)); + /* Note: release the context handle twice in this case to release + * one ref being kept around for the data and one ref for the + * unmarshall/marshall sequence */ + if (!RpcServerAssoc_ReleaseContextHandle(binding->Assoc, SContext, FALSE)) + return; /* this is to cope with the case of the data not being valid + * before and so not having a further reference */ + } + RpcServerAssoc_ReleaseContextHandle(binding->Assoc, SContext, TRUE); } /*********************************************************************** @@ -249,8 +281,8 @@ void WINAPI NDRSContextMarshall2(RPC_BINDING_HANDLE hBinding, NDR_SCONTEXT WINAPI NDRSContextUnmarshall(void *pBuff, ULONG DataRepresentation) { - FIXME("(%p %08x): stub\n", pBuff, DataRepresentation); - return NULL; + TRACE("(%p %08x)\n", pBuff, DataRepresentation); + return NDRSContextUnmarshall2(I_RpcGetCurrentCallHandle(), pBuff, DataRepresentation, NULL, 0); } /*********************************************************************** @@ -260,8 +292,8 @@ NDR_SCONTEXT WINAPI NDRSContextUnmarshallEx(RPC_BINDING_HANDLE hBinding, void *pBuff, ULONG DataRepresentation) { - FIXME("(%p %p %08x): stub\n", hBinding, pBuff, DataRepresentation); - return NULL; + TRACE("(%p %p %08x)\n", hBinding, pBuff, DataRepresentation); + return NDRSContextUnmarshall2(hBinding, pBuff, DataRepresentation, NULL, 0); } /*********************************************************************** @@ -272,7 +304,36 @@ NDR_SCONTEXT WINAPI NDRSContextUnmarshall2(RPC_BINDING_HANDLE hBinding, ULONG DataRepresentation, void *CtxGuard, ULONG Flags) { - FIXME("(%p %p %08x %p %u): stub\n", + RpcBinding *binding = hBinding; + NDR_SCONTEXT SContext; + RPC_STATUS status; + + TRACE("(%p %p %08x %p %u)\n", hBinding, pBuff, DataRepresentation, CtxGuard, Flags); - return NULL; + + if (!binding->server || !binding->Assoc) + RpcRaiseException(ERROR_INVALID_HANDLE); + + if (!pBuff) + status = RpcServerAssoc_AllocateContextHandle(binding->Assoc, CtxGuard, + &SContext); + else + { + const ndr_context_handle *context_ndr = pBuff; + if (context_ndr->attributes) + { + ERR("non-null attributes 0x%x\n", context_ndr->attributes); + status = ERROR_INVALID_HANDLE; + } + else + status = RpcServerAssoc_FindContextHandle(binding->Assoc, + &context_ndr->uuid, + CtxGuard, Flags, + &SContext); + } + + if (status != RPC_S_OK) + RpcRaiseException(status); + + return SContext; } diff --git a/dlls/rpcrt4/rpc_assoc.c b/dlls/rpcrt4/rpc_assoc.c index 7793c48e435..1bffb8564ed 100644 --- a/dlls/rpcrt4/rpc_assoc.c +++ b/dlls/rpcrt4/rpc_assoc.c @@ -24,6 +24,7 @@ #include "rpc.h" #include "rpcndr.h" +#include "winternl.h" #include "wine/unicode.h" #include "wine/debug.h" @@ -48,6 +49,19 @@ static struct list server_assoc_list = LIST_INIT(server_assoc_list); static LONG last_assoc_group_id; +typedef struct _RpcContextHandle +{ + struct list entry; + void *user_context; + NDR_RUNDOWN rundown_routine; + void *ctx_guard; + UUID uuid; + RTL_RWLOCK rw_lock; + unsigned int refs; +} RpcContextHandle; + +static void RpcContextHandle_Destroy(RpcContextHandle *context_handle); + static RPC_STATUS RpcAssoc_Alloc(LPCSTR Protseq, LPCSTR NetworkAddr, LPCSTR Endpoint, LPCWSTR NetworkOptions, RpcAssoc **assoc_out) @@ -58,6 +72,7 @@ static RPC_STATUS RpcAssoc_Alloc(LPCSTR Protseq, LPCSTR NetworkAddr, return RPC_S_OUT_OF_RESOURCES; assoc->refs = 1; list_init(&assoc->free_connection_pool); + list_init(&assoc->context_handle_list); InitializeCriticalSection(&assoc->cs); assoc->Protseq = RPCRT4_strdupA(Protseq); assoc->NetworkAddr = RPCRT4_strdupA(NetworkAddr); @@ -171,6 +186,7 @@ ULONG RpcAssoc_Release(RpcAssoc *assoc) if (!refs) { RpcConnection *Connection, *cursor2; + RpcContextHandle *context_handle, *context_handle_cursor; TRACE("destroying assoc %p\n", assoc); @@ -180,6 +196,9 @@ ULONG RpcAssoc_Release(RpcAssoc *assoc) RPCRT4_DestroyConnection(Connection); } + LIST_FOR_EACH_ENTRY_SAFE(context_handle, context_handle_cursor, &assoc->context_handle_list, RpcContextHandle, entry) + RpcContextHandle_Destroy(context_handle); + HeapFree(GetProcessHeap(), 0, assoc->NetworkOptions); HeapFree(GetProcessHeap(), 0, assoc->Endpoint); HeapFree(GetProcessHeap(), 0, assoc->NetworkAddr); @@ -391,3 +410,130 @@ void RpcAssoc_ReleaseIdleConnection(RpcAssoc *assoc, RpcConnection *Connection) list_add_head(&assoc->free_connection_pool, &Connection->conn_pool_entry); LeaveCriticalSection(&assoc->cs); } + +RPC_STATUS RpcServerAssoc_AllocateContextHandle(RpcAssoc *assoc, void *CtxGuard, + NDR_SCONTEXT *SContext) +{ + RpcContextHandle *context_handle; + + context_handle = HeapAlloc(GetProcessHeap(), HEAP_ZERO_MEMORY, sizeof(*context_handle)); + if (!context_handle) + return ERROR_OUTOFMEMORY; + + context_handle->ctx_guard = CtxGuard; + RtlInitializeResource(&context_handle->rw_lock); + context_handle->refs = 1; + + /* lock here to mirror unmarshall, so we don't need to special-case the + * freeing of a non-marshalled context handle */ + RtlAcquireResourceExclusive(&context_handle->rw_lock, TRUE); + + EnterCriticalSection(&assoc->cs); + list_add_tail(&assoc->context_handle_list, &context_handle->entry); + LeaveCriticalSection(&assoc->cs); + + *SContext = (NDR_SCONTEXT)context_handle; + return RPC_S_OK; +} + +BOOL RpcContextHandle_IsGuardCorrect(NDR_SCONTEXT SContext, void *CtxGuard) +{ + RpcContextHandle *context_handle = (RpcContextHandle *)SContext; + return context_handle->ctx_guard == CtxGuard; +} + +RPC_STATUS RpcServerAssoc_FindContextHandle(RpcAssoc *assoc, const UUID *uuid, + void *CtxGuard, ULONG Flags, NDR_SCONTEXT *SContext) +{ + RpcContextHandle *context_handle; + + EnterCriticalSection(&assoc->cs); + LIST_FOR_EACH_ENTRY(context_handle, &assoc->context_handle_list, RpcContextHandle, entry) + { + if (RpcContextHandle_IsGuardCorrect((NDR_SCONTEXT)context_handle, CtxGuard) && + !memcmp(&context_handle->uuid, uuid, sizeof(*uuid))) + { + *SContext = (NDR_SCONTEXT)context_handle; + if (context_handle->refs++) + { + LeaveCriticalSection(&assoc->cs); + TRACE("found %p\n", context_handle); + RtlAcquireResourceExclusive(&context_handle->rw_lock, TRUE); + return RPC_S_OK; + } + } + } + LeaveCriticalSection(&assoc->cs); + + ERR("no context handle found for uuid %s, guard %p\n", + debugstr_guid(uuid), CtxGuard); + return ERROR_INVALID_HANDLE; +} + +RPC_STATUS RpcServerAssoc_UpdateContextHandle(RpcAssoc *assoc, + NDR_SCONTEXT SContext, + void *CtxGuard, + NDR_RUNDOWN rundown_routine) +{ + RpcContextHandle *context_handle = (RpcContextHandle *)SContext; + RPC_STATUS status; + + if (!RpcContextHandle_IsGuardCorrect((NDR_SCONTEXT)context_handle, CtxGuard)) + return ERROR_INVALID_HANDLE; + + EnterCriticalSection(&assoc->cs); + if (UuidIsNil(&context_handle->uuid, &status)) + { + /* add a ref for the data being valid */ + context_handle->refs++; + UuidCreate(&context_handle->uuid); + context_handle->rundown_routine = rundown_routine; + TRACE("allocated uuid %s for context handle %p\n", + debugstr_guid(&context_handle->uuid), context_handle); + } + LeaveCriticalSection(&assoc->cs); + + return RPC_S_OK; +} + +void RpcContextHandle_GetUuid(NDR_SCONTEXT SContext, UUID *uuid) +{ + RpcContextHandle *context_handle = (RpcContextHandle *)SContext; + *uuid = context_handle->uuid; +} + +static void RpcContextHandle_Destroy(RpcContextHandle *context_handle) +{ + TRACE("freeing %p\n", context_handle); + + if (context_handle->user_context && context_handle->rundown_routine) + { + TRACE("calling rundown routine %p with user context %p\n", + context_handle->rundown_routine, context_handle->user_context); + context_handle->rundown_routine(context_handle->user_context); + } + + RtlDeleteResource(&context_handle->rw_lock); + + HeapFree(GetProcessHeap(), 0, context_handle); +} + +unsigned int RpcServerAssoc_ReleaseContextHandle(RpcAssoc *assoc, NDR_SCONTEXT SContext, BOOL release_lock) +{ + RpcContextHandle *context_handle = (RpcContextHandle *)SContext; + unsigned int refs; + + if (release_lock) + RtlReleaseResource(&context_handle->rw_lock); + + EnterCriticalSection(&assoc->cs); + refs = --context_handle->refs; + if (!refs) + list_remove(&context_handle->entry); + LeaveCriticalSection(&assoc->cs); + + if (!refs) + RpcContextHandle_Destroy(context_handle); + + return refs; +} diff --git a/dlls/rpcrt4/rpc_assoc.h b/dlls/rpcrt4/rpc_assoc.h index a943324902e..1ce1f135638 100644 --- a/dlls/rpcrt4/rpc_assoc.h +++ b/dlls/rpcrt4/rpc_assoc.h @@ -19,6 +19,7 @@ * */ +#include "rpc_binding.h" #include "wine/list.h" typedef struct _RpcAssoc @@ -35,8 +36,13 @@ typedef struct _RpcAssoc ULONG assoc_group_id; CRITICAL_SECTION cs; - /* connections available to be used */ + + /* client-only */ + /* connections available to be used (protected by cs) */ struct list free_connection_pool; + + /* server-only */ + struct list context_handle_list; /* protected by cs */ } RpcAssoc; RPC_STATUS RPCRT4_GetAssociation(LPCSTR Protseq, LPCSTR NetworkAddr, LPCSTR Endpoint, LPCWSTR NetworkOptions, RpcAssoc **assoc); @@ -44,3 +50,9 @@ RPC_STATUS RpcAssoc_GetClientConnection(RpcAssoc *assoc, const RPC_SYNTAX_IDENTI void RpcAssoc_ReleaseIdleConnection(RpcAssoc *assoc, RpcConnection *Connection); ULONG RpcAssoc_Release(RpcAssoc *assoc); RPC_STATUS RpcServerAssoc_GetAssociation(LPCSTR Protseq, LPCSTR NetworkAddr, LPCSTR Endpoint, LPCWSTR NetworkOptions, unsigned long assoc_gid, RpcAssoc **assoc_out); +RPC_STATUS RpcServerAssoc_AllocateContextHandle(RpcAssoc *assoc, void *CtxGuard, NDR_SCONTEXT *SContext); +RPC_STATUS RpcServerAssoc_FindContextHandle(RpcAssoc *assoc, const UUID *uuid, void *CtxGuard, ULONG Flags, NDR_SCONTEXT *SContext); +RPC_STATUS RpcServerAssoc_UpdateContextHandle(RpcAssoc *assoc, NDR_SCONTEXT SContext, void *CtxGuard, NDR_RUNDOWN rundown_routine); +unsigned int RpcServerAssoc_ReleaseContextHandle(RpcAssoc *assoc, NDR_SCONTEXT SContext, BOOL release_lock); +void RpcContextHandle_GetUuid(NDR_SCONTEXT SContext, UUID *uuid); +BOOL RpcContextHandle_IsGuardCorrect(NDR_SCONTEXT SContext, void *CtxGuard);