diff --git a/dlls/rpcrt4/rpc_binding.c b/dlls/rpcrt4/rpc_binding.c index eb6fb14e915..b554e7f4d6e 100644 --- a/dlls/rpcrt4/rpc_binding.c +++ b/dlls/rpcrt4/rpc_binding.c @@ -139,6 +139,8 @@ static RPC_STATUS RPCRT4_CreateBindingW(RpcBinding** Binding, BOOL server, LPWST static RPC_STATUS RPCRT4_CompleteBindingA(RpcBinding* Binding, LPSTR NetworkAddr, LPSTR Endpoint, LPSTR NetworkOptions) { + RPC_STATUS status; + TRACE("(RpcBinding == ^%p, NetworkAddr == %s, EndPoint == %s, NetworkOptions == %s)\n", Binding, debugstr_a(NetworkAddr), debugstr_a(Endpoint), debugstr_a(NetworkOptions)); @@ -154,12 +156,20 @@ static RPC_STATUS RPCRT4_CompleteBindingA(RpcBinding* Binding, LPSTR NetworkAddr Binding->NetworkOptions = RPCRT4_strdupAtoW(NetworkOptions); if (!Binding->Endpoint) ERR("out of memory?\n"); + status = RPCRT4_GetAssociation(Binding->Protseq, Binding->NetworkAddr, + Binding->Endpoint, Binding->NetworkOptions, + &Binding->Assoc); + if (status != RPC_S_OK) + return status; + return RPC_S_OK; } static RPC_STATUS RPCRT4_CompleteBindingW(RpcBinding* Binding, LPWSTR NetworkAddr, LPWSTR Endpoint, LPWSTR NetworkOptions) { + RPC_STATUS status; + TRACE("(RpcBinding == ^%p, NetworkAddr == %s, EndPoint == %s, NetworkOptions == %s)\n", Binding, debugstr_w(NetworkAddr), debugstr_w(Endpoint), debugstr_w(NetworkOptions)); @@ -175,16 +185,32 @@ static RPC_STATUS RPCRT4_CompleteBindingW(RpcBinding* Binding, LPWSTR NetworkAdd HeapFree(GetProcessHeap(), 0, Binding->NetworkOptions); Binding->NetworkOptions = RPCRT4_strdupW(NetworkOptions); + status = RPCRT4_GetAssociation(Binding->Protseq, Binding->NetworkAddr, + Binding->Endpoint, Binding->NetworkOptions, + &Binding->Assoc); + if (status != RPC_S_OK) + return status; + return RPC_S_OK; } RPC_STATUS RPCRT4_ResolveBinding(RpcBinding* Binding, LPSTR Endpoint) { + RPC_STATUS status; + TRACE("(RpcBinding == ^%p, EndPoint == \"%s\"\n", Binding, Endpoint); RPCRT4_strfree(Binding->Endpoint); Binding->Endpoint = RPCRT4_strdupA(Endpoint); + RpcAssoc_Release(Binding->Assoc); + Binding->Assoc = NULL; + status = RPCRT4_GetAssociation(Binding->Protseq, Binding->NetworkAddr, + Binding->Endpoint, Binding->NetworkOptions, + &Binding->Assoc); + if (status != RPC_S_OK) + return status; + return RPC_S_OK; } @@ -226,7 +252,7 @@ RPC_STATUS RPCRT4_DestroyBinding(RpcBinding* Binding) return RPC_S_OK; TRACE("binding: %p\n", Binding); - /* FIXME: release connections */ + if (Binding->Assoc) RpcAssoc_Release(Binding->Assoc); RPCRT4_strfree(Binding->Endpoint); RPCRT4_strfree(Binding->NetworkAddr); RPCRT4_strfree(Binding->Protseq); @@ -248,9 +274,8 @@ RPC_STATUS RPCRT4_OpenBinding(RpcBinding* Binding, RpcConnection** Connection, if (!Binding->server) { /* try to find a compatible connection from the connection pool */ - NewConnection = RPCRT4_GetIdleConnection(InterfaceId, TransferSyntax, - Binding->Protseq, Binding->NetworkAddr, Binding->Endpoint, - Binding->AuthInfo, Binding->QOS); + NewConnection = RpcAssoc_GetIdleConnection(Binding->Assoc, InterfaceId, + TransferSyntax, Binding->AuthInfo, Binding->QOS); if (NewConnection) { *Connection = NewConnection; return RPC_S_OK; @@ -340,7 +365,7 @@ RPC_STATUS RPCRT4_CloseBinding(RpcBinding* Binding, RpcConnection* Connection) return RPCRT4_DestroyConnection(Connection); } else { - RPCRT4_ReleaseIdleConnection(Connection); + RpcAssoc_ReleaseIdleConnection(Binding->Assoc, Connection); return RPC_S_OK; } } @@ -881,6 +906,8 @@ RPC_STATUS RPC_ENTRY RpcBindingCopy( DestBinding->NetworkAddr = RPCRT4_strndupA(SrcBinding->NetworkAddr, -1); DestBinding->Endpoint = RPCRT4_strndupA(SrcBinding->Endpoint, -1); DestBinding->NetworkOptions = RPCRT4_strdupW(SrcBinding->NetworkOptions); + if (SrcBinding->Assoc) SrcBinding->Assoc->refs++; + DestBinding->Assoc = SrcBinding->Assoc; if (SrcBinding->AuthInfo) RpcAuthInfo_AddRef(SrcBinding->AuthInfo); DestBinding->AuthInfo = SrcBinding->AuthInfo; diff --git a/dlls/rpcrt4/rpc_binding.h b/dlls/rpcrt4/rpc_binding.h index 651d79ef4c5..2ad90f81313 100644 --- a/dlls/rpcrt4/rpc_binding.h +++ b/dlls/rpcrt4/rpc_binding.h @@ -43,6 +43,21 @@ typedef struct _RpcQualityOfService RPC_SECURITY_QOS_V2_W *qos; } RpcQualityOfService; +typedef struct _RpcAssoc +{ + struct list entry; /* entry in the global list of associations */ + LONG refs; + + LPSTR Protseq; + LPSTR NetworkAddr; + LPSTR Endpoint; + LPWSTR NetworkOptions; + RpcAuthInfo *AuthInfo; + + CRITICAL_SECTION cs; + struct list connection_pool; +} RpcAssoc; + struct connection_ops; typedef struct _RpcConnection @@ -97,6 +112,7 @@ typedef struct _RpcBinding RPC_BLOCKING_FN BlockingFn; ULONG ServerTid; RpcConnection* FromConn; + RpcAssoc *Assoc; /* authentication */ RpcAuthInfo *AuthInfo; @@ -117,8 +133,11 @@ ULONG RpcAuthInfo_Release(RpcAuthInfo *AuthInfo); ULONG RpcQualityOfService_AddRef(RpcQualityOfService *qos); ULONG RpcQualityOfService_Release(RpcQualityOfService *qos); -RpcConnection *RPCRT4_GetIdleConnection(const RPC_SYNTAX_IDENTIFIER *InterfaceId, const RPC_SYNTAX_IDENTIFIER *TransferSyntax, LPCSTR Protseq, LPCSTR NetworkAddr, LPCSTR Endpoint, const RpcAuthInfo* AuthInfo, const RpcQualityOfService *QOS); -void RPCRT4_ReleaseIdleConnection(RpcConnection *Connection); +RPC_STATUS RPCRT4_GetAssociation(LPCSTR Protseq, LPCSTR NetworkAddr, LPCSTR Endpoint, LPCWSTR NetworkOptions, RpcAssoc **assoc); +RpcConnection *RpcAssoc_GetIdleConnection(RpcAssoc *assoc, const RPC_SYNTAX_IDENTIFIER *InterfaceId, const RPC_SYNTAX_IDENTIFIER *TransferSyntax, const RpcAuthInfo *AuthInfo, const RpcQualityOfService *QOS); +void RpcAssoc_ReleaseIdleConnection(RpcAssoc *assoc, RpcConnection *Connection); +ULONG RpcAssoc_Release(RpcAssoc *assoc); + RPC_STATUS RPCRT4_CreateConnection(RpcConnection** Connection, BOOL server, LPCSTR Protseq, LPCSTR NetworkAddr, LPCSTR Endpoint, LPCWSTR NetworkOptions, RpcAuthInfo* AuthInfo, RpcQualityOfService *QOS, RpcBinding* Binding); RPC_STATUS RPCRT4_DestroyConnection(RpcConnection* Connection); RPC_STATUS RPCRT4_OpenClientConnection(RpcConnection* Connection); diff --git a/dlls/rpcrt4/rpc_transport.c b/dlls/rpcrt4/rpc_transport.c index 5f3966fc536..e1d29bbec0c 100644 --- a/dlls/rpcrt4/rpc_transport.c +++ b/dlls/rpcrt4/rpc_transport.c @@ -80,16 +80,16 @@ WINE_DEFAULT_DEBUG_CHANNEL(rpc); -static CRITICAL_SECTION connection_pool_cs; -static CRITICAL_SECTION_DEBUG connection_pool_cs_debug = +static CRITICAL_SECTION assoc_list_cs; +static CRITICAL_SECTION_DEBUG assoc_list_cs_debug = { - 0, 0, &connection_pool_cs, - { &connection_pool_cs_debug.ProcessLocksList, &connection_pool_cs_debug.ProcessLocksList }, - 0, 0, { (DWORD_PTR)(__FILE__ ": connection_pool") } + 0, 0, &assoc_list_cs, + { &assoc_list_cs_debug.ProcessLocksList, &assoc_list_cs_debug.ProcessLocksList }, + 0, 0, { (DWORD_PTR)(__FILE__ ": assoc_list_cs") } }; -static CRITICAL_SECTION connection_pool_cs = { &connection_pool_cs_debug, -1, 0, 0, 0, 0 }; +static CRITICAL_SECTION assoc_list_cs = { &assoc_list_cs_debug, -1, 0, 0, 0, 0 }; -static struct list connection_pool = LIST_INIT(connection_pool); +static struct list assoc_list = LIST_INIT(assoc_list); /**** ncacn_np support ****/ @@ -1393,38 +1393,114 @@ RPC_STATUS RPCRT4_CreateConnection(RpcConnection** Connection, BOOL server, return RPC_S_OK; } -RpcConnection *RPCRT4_GetIdleConnection(const RPC_SYNTAX_IDENTIFIER *InterfaceId, - const RPC_SYNTAX_IDENTIFIER *TransferSyntax, LPCSTR Protseq, LPCSTR NetworkAddr, - LPCSTR Endpoint, const RpcAuthInfo* AuthInfo, const RpcQualityOfService *QOS) +RPC_STATUS RPCRT4_GetAssociation(LPCSTR Protseq, LPCSTR NetworkAddr, + LPCSTR Endpoint, LPCWSTR NetworkOptions, + RpcAssoc **assoc_out) +{ + RpcAssoc *assoc; + + EnterCriticalSection(&assoc_list_cs); + LIST_FOR_EACH_ENTRY(assoc, &assoc_list, RpcAssoc, entry) + { + if (!strcmp(Protseq, assoc->Protseq) && + !strcmp(NetworkAddr, assoc->NetworkAddr) && + !strcmp(Endpoint, assoc->Endpoint) && + ((!assoc->NetworkOptions && !NetworkOptions) || !strcmpW(NetworkOptions, assoc->NetworkOptions))) + { + assoc->refs++; + *assoc_out = assoc; + LeaveCriticalSection(&assoc_list_cs); + TRACE("using existing assoc %p\n", assoc); + return RPC_S_OK; + } + } + + assoc = HeapAlloc(GetProcessHeap(), 0, sizeof(*assoc)); + if (!assoc) + { + LeaveCriticalSection(&assoc_list_cs); + return RPC_S_OUT_OF_RESOURCES; + } + assoc->refs = 1; + list_init(&assoc->connection_pool); + InitializeCriticalSection(&assoc->cs); + assoc->Protseq = RPCRT4_strdupA(Protseq); + assoc->NetworkAddr = RPCRT4_strdupA(NetworkAddr); + assoc->Endpoint = RPCRT4_strdupA(Endpoint); + assoc->NetworkOptions = NetworkOptions ? RPCRT4_strdupW(NetworkOptions) : NULL; + list_add_head(&assoc_list, &assoc->entry); + *assoc_out = assoc; + + LeaveCriticalSection(&assoc_list_cs); + + TRACE("new assoc %p\n", assoc); + + return RPC_S_OK; +} + +ULONG RpcAssoc_Release(RpcAssoc *assoc) +{ + ULONG refs; + + EnterCriticalSection(&assoc_list_cs); + refs = --assoc->refs; + if (!refs) + list_remove(&assoc->entry); + LeaveCriticalSection(&assoc_list_cs); + + if (!refs) + { + RpcConnection *Connection, *cursor2; + + TRACE("destroying assoc %p\n", assoc); + + LIST_FOR_EACH_ENTRY_SAFE(Connection, cursor2, &assoc->connection_pool, RpcConnection, conn_pool_entry) + { + list_remove(&Connection->conn_pool_entry); + RPCRT4_DestroyConnection(Connection); + } + + HeapFree(GetProcessHeap(), 0, assoc->NetworkOptions); + HeapFree(GetProcessHeap(), 0, assoc->Endpoint); + HeapFree(GetProcessHeap(), 0, assoc->NetworkAddr); + HeapFree(GetProcessHeap(), 0, assoc->Protseq); + + HeapFree(GetProcessHeap(), 0, assoc); + } + + return refs; +} + +RpcConnection *RpcAssoc_GetIdleConnection(RpcAssoc *assoc, + const RPC_SYNTAX_IDENTIFIER *InterfaceId, + const RPC_SYNTAX_IDENTIFIER *TransferSyntax, const RpcAuthInfo *AuthInfo, + const RpcQualityOfService *QOS) { RpcConnection *Connection; /* try to find a compatible connection from the connection pool */ - EnterCriticalSection(&connection_pool_cs); - LIST_FOR_EACH_ENTRY(Connection, &connection_pool, RpcConnection, conn_pool_entry) + EnterCriticalSection(&assoc->cs); + LIST_FOR_EACH_ENTRY(Connection, &assoc->connection_pool, RpcConnection, conn_pool_entry) if ((Connection->AuthInfo == AuthInfo) && (Connection->QOS == QOS) && !memcmp(&Connection->ActiveInterface, InterfaceId, - sizeof(RPC_SYNTAX_IDENTIFIER)) && - !strcmp(rpcrt4_conn_get_name(Connection), Protseq) && - !strcmp(Connection->NetworkAddr, NetworkAddr) && - !strcmp(Connection->Endpoint, Endpoint)) + sizeof(RPC_SYNTAX_IDENTIFIER))) { list_remove(&Connection->conn_pool_entry); - LeaveCriticalSection(&connection_pool_cs); + LeaveCriticalSection(&assoc->cs); TRACE("got connection from pool %p\n", Connection); return Connection; } - LeaveCriticalSection(&connection_pool_cs); + LeaveCriticalSection(&assoc->cs); return NULL; } -void RPCRT4_ReleaseIdleConnection(RpcConnection *Connection) +void RpcAssoc_ReleaseIdleConnection(RpcAssoc *assoc, RpcConnection *Connection) { assert(!Connection->server); - EnterCriticalSection(&connection_pool_cs); - list_add_head(&connection_pool, &Connection->conn_pool_entry); - LeaveCriticalSection(&connection_pool_cs); + EnterCriticalSection(&assoc->cs); + list_add_head(&assoc->connection_pool, &Connection->conn_pool_entry); + LeaveCriticalSection(&assoc->cs); }