From fc5089349de7ec440a31a12cc1bf4199e76f4650 Mon Sep 17 00:00:00 2001 From: Aric Stewart Date: Mon, 12 Oct 2009 14:24:18 -0500 Subject: [PATCH] wininet: Cache basic authentication values based on realm and host. --- dlls/wininet/http.c | 221 +++++++++++++++++++++++++++++++++++--------- 1 file changed, 177 insertions(+), 44 deletions(-) diff --git a/dlls/wininet/http.c b/dlls/wininet/http.c index f8749206f9f..d8db8af0d3d 100644 --- a/dlls/wininet/http.c +++ b/dlls/wininet/http.c @@ -177,6 +177,27 @@ struct gzip_stream_t { BOOL end_of_data; }; +typedef struct _authorizationData +{ + struct list entry; + + LPWSTR lpszwHost; + LPWSTR lpszwRealm; + LPSTR lpszAuthorization; + UINT AuthorizationLen; +} authorizationData; + +static struct list basicAuthorizationCache = LIST_INIT(basicAuthorizationCache); + +static CRITICAL_SECTION authcache_cs; +static CRITICAL_SECTION_DEBUG critsect_debug = +{ + 0, 0, &authcache_cs, + { &critsect_debug.ProcessLocksList, &critsect_debug.ProcessLocksList }, + 0, 0, { (DWORD_PTR)(__FILE__ ": authcache_cs") } +}; +static CRITICAL_SECTION authcache_cs = { &critsect_debug, -1, 0, 0, 0, 0 }; + static BOOL HTTP_OpenConnection(http_request_t *req); static BOOL HTTP_GetResponseHeaders(http_request_t *req, BOOL clear); static BOOL HTTP_ProcessHeader(http_request_t *req, LPCWSTR field, LPCWSTR value, DWORD dwModifier); @@ -496,11 +517,59 @@ static void HTTP_ProcessCookies( http_request_t *lpwhr ) } } -static inline BOOL is_basic_auth_value( LPCWSTR pszAuthValue ) +static void strip_spaces(LPWSTR start) +{ + LPWSTR str = start; + LPWSTR end; + + while (*str == ' ' && *str != '\0') + str++; + + if (str != start) + memmove(start, str, sizeof(WCHAR) * (strlenW(str) + 1)); + + end = start + strlenW(start) - 1; + while (end >= start && *end == ' ') + { + *end = '\0'; + end--; + } +} + +static inline BOOL is_basic_auth_value( LPCWSTR pszAuthValue, LPWSTR *pszRealm ) { static const WCHAR szBasic[] = {'B','a','s','i','c'}; /* Note: not nul-terminated */ - return !strncmpiW(pszAuthValue, szBasic, ARRAYSIZE(szBasic)) && + static const WCHAR szRealm[] = {'r','e','a','l','m'}; /* Note: not nul-terminated */ + BOOL is_basic; + is_basic = !strncmpiW(pszAuthValue, szBasic, ARRAYSIZE(szBasic)) && ((pszAuthValue[ARRAYSIZE(szBasic)] == ' ') || !pszAuthValue[ARRAYSIZE(szBasic)]); + if (is_basic && pszRealm) + { + LPCWSTR token; + LPCWSTR ptr = &pszAuthValue[ARRAYSIZE(szBasic)]; + LPCWSTR realm; + ptr++; + *pszRealm=NULL; + token = strchrW(ptr,'='); + if (!token) + return TRUE; + realm = ptr; + while (*realm == ' ' && *realm != '\0') + realm++; + if(!strncmpiW(realm, szRealm, ARRAYSIZE(szRealm)) && + (realm[ARRAYSIZE(szRealm)] == ' ' || realm[ARRAYSIZE(szRealm)] == '=')) + { + token++; + while (*token == ' ' && *token != '\0') + token++; + if (*token == '\0') + return TRUE; + *pszRealm = heap_strdupW(token); + strip_spaces(*pszRealm); + } + } + + return is_basic; } static void destroy_authinfo( struct HttpAuthInfo *authinfo ) @@ -517,13 +586,78 @@ static void destroy_authinfo( struct HttpAuthInfo *authinfo ) HeapFree(GetProcessHeap(), 0, authinfo); } +static UINT retrieve_cached_basic_authorization(LPWSTR host, LPWSTR realm, LPSTR *auth_data) +{ + authorizationData *ad; + UINT rc = 0; + + TRACE("Looking for authorization for %s:%s\n",debugstr_w(host),debugstr_w(realm)); + + EnterCriticalSection(&authcache_cs); + LIST_FOR_EACH_ENTRY(ad, &basicAuthorizationCache, authorizationData, entry) + { + if (!strcmpiW(host,ad->lpszwHost) && !strcmpW(realm,ad->lpszwRealm)) + { + TRACE("Authorization found in cache\n"); + *auth_data = HeapAlloc(GetProcessHeap(),0,ad->AuthorizationLen); + memcpy(*auth_data,ad->lpszAuthorization,ad->AuthorizationLen); + rc = ad->AuthorizationLen; + break; + } + } + LeaveCriticalSection(&authcache_cs); + return rc; +} + +static void cache_basic_authorization(LPWSTR host, LPWSTR realm, LPSTR auth_data, UINT auth_data_len) +{ + struct list *cursor; + authorizationData* ad = NULL; + + TRACE("caching authorization for %s:%s = %s\n",debugstr_w(host),debugstr_w(realm),debugstr_an(auth_data,auth_data_len)); + + EnterCriticalSection(&authcache_cs); + LIST_FOR_EACH(cursor, &basicAuthorizationCache) + { + authorizationData *check = LIST_ENTRY(cursor,authorizationData,entry); + if (!strcmpiW(host,check->lpszwHost) && !strcmpW(realm,check->lpszwRealm)) + { + ad = check; + break; + } + } + + if (ad) + { + TRACE("Found match in cache, replacing\n"); + HeapFree(GetProcessHeap(),0,ad->lpszAuthorization); + ad->lpszAuthorization = HeapAlloc(GetProcessHeap(),0,auth_data_len); + memcpy(ad->lpszAuthorization, auth_data, auth_data_len); + ad->AuthorizationLen = auth_data_len; + } + else + { + ad = HeapAlloc(GetProcessHeap(),0,sizeof(authorizationData)); + ad->lpszwHost = heap_strdupW(host); + ad->lpszwRealm = heap_strdupW(realm); + ad->lpszAuthorization = HeapAlloc(GetProcessHeap(),0,auth_data_len); + memcpy(ad->lpszAuthorization, auth_data, auth_data_len); + ad->AuthorizationLen = auth_data_len; + list_add_head(&basicAuthorizationCache,&ad->entry); + TRACE("authorization cached\n"); + } + LeaveCriticalSection(&authcache_cs); +} + static BOOL HTTP_DoAuthorization( http_request_t *lpwhr, LPCWSTR pszAuthValue, struct HttpAuthInfo **ppAuthInfo, - LPWSTR domain_and_username, LPWSTR password ) + LPWSTR domain_and_username, LPWSTR password, + LPWSTR host ) { SECURITY_STATUS sec_status; struct HttpAuthInfo *pAuthInfo = *ppAuthInfo; BOOL first = FALSE; + LPWSTR szRealm = NULL; TRACE("%s\n", debugstr_w(pszAuthValue)); @@ -544,7 +678,7 @@ static BOOL HTTP_DoAuthorization( http_request_t *lpwhr, LPCWSTR pszAuthValue, pAuthInfo->auth_data_len = 0; pAuthInfo->finished = FALSE; - if (is_basic_auth_value(pszAuthValue)) + if (is_basic_auth_value(pszAuthValue,NULL)) { static const WCHAR szBasic[] = {'B','a','s','i','c',0}; pAuthInfo->scheme = heap_strdupW(szBasic); @@ -631,33 +765,50 @@ static BOOL HTTP_DoAuthorization( http_request_t *lpwhr, LPCWSTR pszAuthValue, return FALSE; } - if (is_basic_auth_value(pszAuthValue)) + if (is_basic_auth_value(pszAuthValue,&szRealm)) { int userlen; int passlen; - char *auth_data; + char *auth_data = NULL; + UINT auth_data_len = 0; - TRACE("basic authentication\n"); + TRACE("basic authentication realm %s\n",debugstr_w(szRealm)); - /* we don't cache credentials for basic authentication, so we can't - * retrieve them if the application didn't pass us any credentials */ - if (!domain_and_username) return FALSE; + if (!domain_and_username) + { + if (host && szRealm) + auth_data_len = retrieve_cached_basic_authorization(host, szRealm,&auth_data); + if (auth_data_len == 0) + { + HeapFree(GetProcessHeap(),0,szRealm); + return FALSE; + } + } + else + { + userlen = WideCharToMultiByte(CP_UTF8, 0, domain_and_username, lstrlenW(domain_and_username), NULL, 0, NULL, NULL); + passlen = WideCharToMultiByte(CP_UTF8, 0, password, lstrlenW(password), NULL, 0, NULL, NULL); - userlen = WideCharToMultiByte(CP_UTF8, 0, domain_and_username, lstrlenW(domain_and_username), NULL, 0, NULL, NULL); - passlen = WideCharToMultiByte(CP_UTF8, 0, password, lstrlenW(password), NULL, 0, NULL, NULL); + /* length includes a nul terminator, which will be re-used for the ':' */ + auth_data = HeapAlloc(GetProcessHeap(), 0, userlen + 1 + passlen); + if (!auth_data) + { + HeapFree(GetProcessHeap(),0,szRealm); + return FALSE; + } - /* length includes a nul terminator, which will be re-used for the ':' */ - auth_data = HeapAlloc(GetProcessHeap(), 0, userlen + 1 + passlen); - if (!auth_data) - return FALSE; - - WideCharToMultiByte(CP_UTF8, 0, domain_and_username, -1, auth_data, userlen, NULL, NULL); - auth_data[userlen] = ':'; - WideCharToMultiByte(CP_UTF8, 0, password, -1, &auth_data[userlen+1], passlen, NULL, NULL); + WideCharToMultiByte(CP_UTF8, 0, domain_and_username, -1, auth_data, userlen, NULL, NULL); + auth_data[userlen] = ':'; + WideCharToMultiByte(CP_UTF8, 0, password, -1, &auth_data[userlen+1], passlen, NULL, NULL); + auth_data_len = userlen + 1 + passlen; + if (host && szRealm) + cache_basic_authorization(host, szRealm, auth_data, auth_data_len); + } pAuthInfo->auth_data = auth_data; - pAuthInfo->auth_data_len = userlen + 1 + passlen; + pAuthInfo->auth_data_len = auth_data_len; pAuthInfo->finished = TRUE; + HeapFree(GetProcessHeap(),0,szRealm); return TRUE; } @@ -3861,13 +4012,15 @@ BOOL WINAPI HTTP_HttpSendRequestW(http_request_t *lpwhr, LPCWSTR lpszHeaders, dwBufferSize=2048; if (dwStatusCode == HTTP_STATUS_DENIED) { + LPHTTPHEADERW Host = HTTP_GetHeader(lpwhr, hostW); DWORD dwIndex = 0; while (HTTP_HttpQueryInfoW(lpwhr,HTTP_QUERY_WWW_AUTHENTICATE,szAuthValue,&dwBufferSize,&dwIndex)) { if (HTTP_DoAuthorization(lpwhr, szAuthValue, &lpwhr->pAuthInfo, lpwhr->lpHttpSession->lpszUserName, - lpwhr->lpHttpSession->lpszPassword)) + lpwhr->lpHttpSession->lpszPassword, + Host->lpszValue)) { loop_next = TRUE; break; @@ -3882,7 +4035,8 @@ BOOL WINAPI HTTP_HttpSendRequestW(http_request_t *lpwhr, LPCWSTR lpszHeaders, if (HTTP_DoAuthorization(lpwhr, szAuthValue, &lpwhr->pProxyAuthInfo, lpwhr->lpHttpSession->lpAppInfo->lpszProxyUsername, - lpwhr->lpHttpSession->lpAppInfo->lpszProxyPassword)) + lpwhr->lpHttpSession->lpAppInfo->lpszProxyPassword, + NULL)) { loop_next = TRUE; break; @@ -4420,27 +4574,6 @@ lend: } } - -static void strip_spaces(LPWSTR start) -{ - LPWSTR str = start; - LPWSTR end; - - while (*str == ' ' && *str != '\0') - str++; - - if (str != start) - memmove(start, str, sizeof(WCHAR) * (strlenW(str) + 1)); - - end = start + strlenW(start) - 1; - while (end >= start && *end == ' ') - { - *end = '\0'; - end--; - } -} - - /*********************************************************************** * HTTP_InterpretHttpHeader (internal) *