rpcrt4: Try a lot harder to resuse existing connections by comparing inside the RpcQualityOfService and RpcAuthInfo objects.

Store a copy of the SEC_WINNT_AUTH_IDENTITY structure passed in to
RpcBindingSetAuthInfo(Ex) to enable us to do this for RpcAuthInfo objects.
This commit is contained in:
Rob Shearman 2007-06-25 14:27:17 +01:00 committed by Alexandre Julliard
parent 0ebcacca39
commit cbafe663b0
3 changed files with 192 additions and 7 deletions

View File

@ -78,6 +78,17 @@ LPWSTR RPCRT4_strdupAtoW(LPCSTR src)
return s; return s;
} }
static LPWSTR RPCRT4_strndupAtoW(LPCSTR src, INT slen)
{
DWORD len;
LPWSTR s;
if (!src) return NULL;
len = MultiByteToWideChar(CP_ACP, 0, src, slen, NULL, 0);
s = HeapAlloc(GetProcessHeap(), 0, len*sizeof(WCHAR));
MultiByteToWideChar(CP_ACP, 0, src, slen, s, len);
return s;
}
LPWSTR RPCRT4_strndupW(LPCWSTR src, INT slen) LPWSTR RPCRT4_strndupW(LPCWSTR src, INT slen)
{ {
DWORD len; DWORD len;
@ -967,9 +978,24 @@ RPC_STATUS WINAPI RpcRevertToSelfEx(RPC_BINDING_HANDLE BindingHandle)
return RPC_S_OK; return RPC_S_OK;
} }
static inline BOOL has_nt_auth_identity(ULONG AuthnLevel)
{
switch (AuthnLevel)
{
case RPC_C_AUTHN_GSS_NEGOTIATE:
case RPC_C_AUTHN_WINNT:
case RPC_C_AUTHN_GSS_KERBEROS:
return TRUE;
default:
return FALSE;
}
}
static RPC_STATUS RpcAuthInfo_Create(ULONG AuthnLevel, ULONG AuthnSvc, static RPC_STATUS RpcAuthInfo_Create(ULONG AuthnLevel, ULONG AuthnSvc,
CredHandle cred, TimeStamp exp, CredHandle cred, TimeStamp exp,
ULONG cbMaxToken, RpcAuthInfo **ret) ULONG cbMaxToken,
RPC_AUTH_IDENTITY_HANDLE identity,
RpcAuthInfo **ret)
{ {
RpcAuthInfo *AuthInfo = HeapAlloc(GetProcessHeap(), 0, sizeof(*AuthInfo)); RpcAuthInfo *AuthInfo = HeapAlloc(GetProcessHeap(), 0, sizeof(*AuthInfo));
if (!AuthInfo) if (!AuthInfo)
@ -981,6 +1007,51 @@ static RPC_STATUS RpcAuthInfo_Create(ULONG AuthnLevel, ULONG AuthnSvc,
AuthInfo->cred = cred; AuthInfo->cred = cred;
AuthInfo->exp = exp; AuthInfo->exp = exp;
AuthInfo->cbMaxToken = cbMaxToken; AuthInfo->cbMaxToken = cbMaxToken;
AuthInfo->identity = identity;
/* duplicate the SEC_WINNT_AUTH_IDENTITY structure, if applicable, to
* enable better matching in RpcAuthInfo_IsEqual */
if (identity && has_nt_auth_identity(AuthnSvc))
{
const SEC_WINNT_AUTH_IDENTITY_W *nt_identity = identity;
AuthInfo->nt_identity = HeapAlloc(GetProcessHeap(), 0, sizeof(*AuthInfo->nt_identity));
if (!AuthInfo->nt_identity)
{
HeapFree(GetProcessHeap(), 0, AuthInfo);
return ERROR_OUTOFMEMORY;
}
AuthInfo->nt_identity->Flags = SEC_WINNT_AUTH_IDENTITY_UNICODE;
if (nt_identity->Flags & SEC_WINNT_AUTH_IDENTITY_UNICODE)
AuthInfo->nt_identity->User = RPCRT4_strndupW(nt_identity->User, nt_identity->UserLength);
else
AuthInfo->nt_identity->User = RPCRT4_strndupAtoW((const char *)nt_identity->User, nt_identity->UserLength);
AuthInfo->nt_identity->UserLength = nt_identity->UserLength;
if (nt_identity->Flags & SEC_WINNT_AUTH_IDENTITY_UNICODE)
AuthInfo->nt_identity->Domain = RPCRT4_strndupW(nt_identity->Domain, nt_identity->DomainLength);
else
AuthInfo->nt_identity->Domain = RPCRT4_strndupAtoW((const char *)nt_identity->Domain, nt_identity->DomainLength);
AuthInfo->nt_identity->DomainLength = nt_identity->DomainLength;
if (nt_identity->Flags & SEC_WINNT_AUTH_IDENTITY_UNICODE)
AuthInfo->nt_identity->Password = RPCRT4_strndupW(nt_identity->Password, nt_identity->PasswordLength);
else
AuthInfo->nt_identity->Password = RPCRT4_strndupAtoW((const char *)nt_identity->Password, nt_identity->PasswordLength);
AuthInfo->nt_identity->PasswordLength = nt_identity->PasswordLength;
if (!AuthInfo->nt_identity->User ||
!AuthInfo->nt_identity->Domain ||
!AuthInfo->nt_identity->Password)
{
HeapFree(GetProcessHeap(), 0, AuthInfo->nt_identity->User);
HeapFree(GetProcessHeap(), 0, AuthInfo->nt_identity->Domain);
HeapFree(GetProcessHeap(), 0, AuthInfo->nt_identity->Password);
HeapFree(GetProcessHeap(), 0, AuthInfo->nt_identity);
HeapFree(GetProcessHeap(), 0, AuthInfo);
return ERROR_OUTOFMEMORY;
}
}
else
AuthInfo->nt_identity = NULL;
*ret = AuthInfo; *ret = AuthInfo;
return RPC_S_OK; return RPC_S_OK;
} }
@ -997,12 +1068,60 @@ ULONG RpcAuthInfo_Release(RpcAuthInfo *AuthInfo)
if (!refs) if (!refs)
{ {
FreeCredentialsHandle(&AuthInfo->cred); FreeCredentialsHandle(&AuthInfo->cred);
if (AuthInfo->nt_identity)
{
HeapFree(GetProcessHeap(), 0, AuthInfo->nt_identity->User);
HeapFree(GetProcessHeap(), 0, AuthInfo->nt_identity->Domain);
HeapFree(GetProcessHeap(), 0, AuthInfo->nt_identity->User);
HeapFree(GetProcessHeap(), 0, AuthInfo->nt_identity);
}
HeapFree(GetProcessHeap(), 0, AuthInfo); HeapFree(GetProcessHeap(), 0, AuthInfo);
} }
return refs; return refs;
} }
BOOL RpcAuthInfo_IsEqual(const RpcAuthInfo *AuthInfo1, const RpcAuthInfo *AuthInfo2)
{
if (AuthInfo1 == AuthInfo2)
return TRUE;
if (!AuthInfo1 || !AuthInfo2)
return FALSE;
if ((AuthInfo1->AuthnLevel != AuthInfo2->AuthnLevel) ||
(AuthInfo1->AuthnSvc != AuthInfo2->AuthnSvc))
return FALSE;
if (AuthInfo1->identity == AuthInfo2->identity)
return TRUE;
if (!AuthInfo1->identity || !AuthInfo2->identity)
return FALSE;
if (has_nt_auth_identity(AuthInfo1->AuthnSvc))
{
const SEC_WINNT_AUTH_IDENTITY_W *identity1 = AuthInfo1->nt_identity;
const SEC_WINNT_AUTH_IDENTITY_W *identity2 = AuthInfo2->nt_identity;
/* compare user names */
if (identity1->UserLength != identity2->UserLength ||
memcmp(identity1->User, identity2->User, identity1->UserLength))
return FALSE;
/* compare domain names */
if (identity1->DomainLength != identity2->DomainLength ||
memcmp(identity1->Domain, identity2->Domain, identity1->DomainLength))
return FALSE;
/* compare passwords */
if (identity1->PasswordLength != identity2->PasswordLength ||
memcmp(identity1->Password, identity2->Password, identity1->PasswordLength))
return FALSE;
}
else
return FALSE;
return TRUE;
}
static RPC_STATUS RpcQualityOfService_Create(const RPC_SECURITY_QOS *qos_src, BOOL unicode, RpcQualityOfService **qos_dst) static RPC_STATUS RpcQualityOfService_Create(const RPC_SECURITY_QOS *qos_src, BOOL unicode, RpcQualityOfService **qos_dst)
{ {
RpcQualityOfService *qos = HeapAlloc(GetProcessHeap(), 0, sizeof(*qos)); RpcQualityOfService *qos = HeapAlloc(GetProcessHeap(), 0, sizeof(*qos));
@ -1143,6 +1262,65 @@ ULONG RpcQualityOfService_Release(RpcQualityOfService *qos)
return refs; return refs;
} }
BOOL RpcQualityOfService_IsEqual(const RpcQualityOfService *qos1, const RpcQualityOfService *qos2)
{
if (qos1 == qos2)
return TRUE;
if (!qos1 || !qos2)
return FALSE;
TRACE("qos1 = { %ld %ld %ld %ld }, qos2 = { %ld %ld %ld %ld }\n",
qos1->qos->Capabilities, qos1->qos->IdentityTracking,
qos1->qos->ImpersonationType, qos1->qos->AdditionalSecurityInfoType,
qos2->qos->Capabilities, qos2->qos->IdentityTracking,
qos2->qos->ImpersonationType, qos2->qos->AdditionalSecurityInfoType);
if ((qos1->qos->Capabilities != qos2->qos->Capabilities) ||
(qos1->qos->IdentityTracking != qos2->qos->IdentityTracking) ||
(qos1->qos->ImpersonationType != qos2->qos->ImpersonationType) ||
(qos1->qos->AdditionalSecurityInfoType != qos2->qos->AdditionalSecurityInfoType))
return FALSE;
if (qos1->qos->AdditionalSecurityInfoType == RPC_C_AUTHN_INFO_TYPE_HTTP)
{
const RPC_HTTP_TRANSPORT_CREDENTIALS_W *http_credentials1 = qos1->qos->u.HttpCredentials;
const RPC_HTTP_TRANSPORT_CREDENTIALS_W *http_credentials2 = qos2->qos->u.HttpCredentials;
if (http_credentials1->Flags != http_credentials2->Flags)
return FALSE;
if (http_credentials1->AuthenticationTarget != http_credentials2->AuthenticationTarget)
return FALSE;
/* authentication schemes and server certificate subject not currently used */
if (http_credentials1->TransportCredentials != http_credentials2->TransportCredentials)
{
const SEC_WINNT_AUTH_IDENTITY_W *identity1 = http_credentials1->TransportCredentials;
const SEC_WINNT_AUTH_IDENTITY_W *identity2 = http_credentials2->TransportCredentials;
if (!identity1 || !identity2)
return FALSE;
/* compare user names */
if (identity1->UserLength != identity2->UserLength ||
memcmp(identity1->User, identity2->User, identity1->UserLength))
return FALSE;
/* compare domain names */
if (identity1->DomainLength != identity2->DomainLength ||
memcmp(identity1->Domain, identity2->Domain, identity1->DomainLength))
return FALSE;
/* compare passwords */
if (identity1->PasswordLength != identity2->PasswordLength ||
memcmp(identity1->Password, identity2->Password, identity1->PasswordLength))
return FALSE;
}
}
return TRUE;
}
/*********************************************************************** /***********************************************************************
* RpcRevertToSelf (RPCRT4.@) * RpcRevertToSelf (RPCRT4.@)
*/ */
@ -1317,7 +1495,7 @@ RpcBindingSetAuthInfoExA( RPC_BINDING_HANDLE Binding, RPC_CSTR ServerPrincName,
if (bind->AuthInfo) RpcAuthInfo_Release(bind->AuthInfo); if (bind->AuthInfo) RpcAuthInfo_Release(bind->AuthInfo);
bind->AuthInfo = NULL; bind->AuthInfo = NULL;
r = RpcAuthInfo_Create(AuthnLevel, AuthnSvc, cred, exp, cbMaxToken, r = RpcAuthInfo_Create(AuthnLevel, AuthnSvc, cred, exp, cbMaxToken,
&bind->AuthInfo); AuthIdentity, &bind->AuthInfo);
if (r != RPC_S_OK) if (r != RPC_S_OK)
FreeCredentialsHandle(&cred); FreeCredentialsHandle(&cred);
return RPC_S_OK; return RPC_S_OK;
@ -1433,7 +1611,7 @@ RpcBindingSetAuthInfoExW( RPC_BINDING_HANDLE Binding, RPC_WSTR ServerPrincName,
if (bind->AuthInfo) RpcAuthInfo_Release(bind->AuthInfo); if (bind->AuthInfo) RpcAuthInfo_Release(bind->AuthInfo);
bind->AuthInfo = NULL; bind->AuthInfo = NULL;
r = RpcAuthInfo_Create(AuthnLevel, AuthnSvc, cred, exp, cbMaxToken, r = RpcAuthInfo_Create(AuthnLevel, AuthnSvc, cred, exp, cbMaxToken,
&bind->AuthInfo); AuthIdentity, &bind->AuthInfo);
if (r != RPC_S_OK) if (r != RPC_S_OK)
FreeCredentialsHandle(&cred); FreeCredentialsHandle(&cred);
return RPC_S_OK; return RPC_S_OK;

View File

@ -35,6 +35,11 @@ typedef struct _RpcAuthInfo
CredHandle cred; CredHandle cred;
TimeStamp exp; TimeStamp exp;
ULONG cbMaxToken; ULONG cbMaxToken;
/* the auth identity pointer that the application passed us (freed by application) */
RPC_AUTH_IDENTITY_HANDLE *identity;
/* our copy of NT auth identity structure, if the authentication service
* takes an NT auth identity */
SEC_WINNT_AUTH_IDENTITY_W *nt_identity;
} RpcAuthInfo; } RpcAuthInfo;
typedef struct _RpcQualityOfService typedef struct _RpcQualityOfService
@ -137,8 +142,10 @@ void RPCRT4_strfree(LPSTR src);
ULONG RpcAuthInfo_AddRef(RpcAuthInfo *AuthInfo); ULONG RpcAuthInfo_AddRef(RpcAuthInfo *AuthInfo);
ULONG RpcAuthInfo_Release(RpcAuthInfo *AuthInfo); ULONG RpcAuthInfo_Release(RpcAuthInfo *AuthInfo);
BOOL RpcAuthInfo_IsEqual(const RpcAuthInfo *AuthInfo1, const RpcAuthInfo *AuthInfo2);
ULONG RpcQualityOfService_AddRef(RpcQualityOfService *qos); ULONG RpcQualityOfService_AddRef(RpcQualityOfService *qos);
ULONG RpcQualityOfService_Release(RpcQualityOfService *qos); ULONG RpcQualityOfService_Release(RpcQualityOfService *qos);
BOOL RpcQualityOfService_IsEqual(const RpcQualityOfService *qos1, const RpcQualityOfService *qos2);
RPC_STATUS RPCRT4_GetAssociation(LPCSTR Protseq, LPCSTR NetworkAddr, LPCSTR Endpoint, LPCWSTR NetworkOptions, RpcAssoc **assoc); RPC_STATUS RPCRT4_GetAssociation(LPCSTR Protseq, LPCSTR NetworkAddr, LPCSTR Endpoint, LPCWSTR NetworkOptions, RpcAssoc **assoc);
RpcConnection *RpcAssoc_GetIdleConnection(RpcAssoc *assoc, const RPC_SYNTAX_IDENTIFIER *InterfaceId, const RPC_SYNTAX_IDENTIFIER *TransferSyntax, const RpcAuthInfo *AuthInfo, const RpcQualityOfService *QOS); RpcConnection *RpcAssoc_GetIdleConnection(RpcAssoc *assoc, const RPC_SYNTAX_IDENTIFIER *InterfaceId, const RPC_SYNTAX_IDENTIFIER *TransferSyntax, const RpcAuthInfo *AuthInfo, const RpcQualityOfService *QOS);

View File

@ -1485,10 +1485,10 @@ RpcConnection *RpcAssoc_GetIdleConnection(RpcAssoc *assoc,
/* try to find a compatible connection from the connection pool */ /* try to find a compatible connection from the connection pool */
EnterCriticalSection(&assoc->cs); EnterCriticalSection(&assoc->cs);
LIST_FOR_EACH_ENTRY(Connection, &assoc->connection_pool, RpcConnection, conn_pool_entry) LIST_FOR_EACH_ENTRY(Connection, &assoc->connection_pool, RpcConnection, conn_pool_entry)
if ((Connection->AuthInfo == AuthInfo) && if (!memcmp(&Connection->ActiveInterface, InterfaceId,
(Connection->QOS == QOS) && sizeof(RPC_SYNTAX_IDENTIFIER)) &&
!memcmp(&Connection->ActiveInterface, InterfaceId, RpcAuthInfo_IsEqual(Connection->AuthInfo, AuthInfo) &&
sizeof(RPC_SYNTAX_IDENTIFIER))) RpcQualityOfService_IsEqual(Connection->QOS, QOS))
{ {
list_remove(&Connection->conn_pool_entry); list_remove(&Connection->conn_pool_entry);
LeaveCriticalSection(&assoc->cs); LeaveCriticalSection(&assoc->cs);