diff --git a/dlls/secur32/ntlm.c b/dlls/secur32/ntlm.c index 423a1674d52..5b95d92b9fa 100644 --- a/dlls/secur32/ntlm.c +++ b/dlls/secur32/ntlm.c @@ -24,11 +24,13 @@ #include "windef.h" #include "winbase.h" #include "winnls.h" +#include "wincred.h" #include "rpc.h" #include "sspi.h" #include "lm.h" #include "secur32_priv.h" #include "hmac_md5.h" +#include "wine/unicode.h" #include "wine/debug.h" WINE_DEFAULT_DEBUG_CHANNEL(ntlm); @@ -93,6 +95,42 @@ static SECURITY_STATUS SEC_ENTRY ntlm_QueryCredentialsAttributesW( return ret; } +static char *ntlm_GetUsernameArg(LPCWSTR userW, INT userW_length) +{ + static const char username_arg[] = "--username="; + char *user; + int unixcp_size; + + unixcp_size = WideCharToMultiByte(CP_UNIXCP, WC_NO_BEST_FIT_CHARS, + userW, userW_length, NULL, 0, NULL, NULL) + sizeof(username_arg); + user = HeapAlloc(GetProcessHeap(), 0, unixcp_size); + if (!user) return NULL; + memcpy(user, username_arg, sizeof(username_arg) - 1); + WideCharToMultiByte(CP_UNIXCP, WC_NO_BEST_FIT_CHARS, userW, userW_length, + user + sizeof(username_arg) - 1, + unixcp_size - sizeof(username_arg) + 1, NULL, NULL); + user[unixcp_size - 1] = '\0'; + return user; +} + +static char *ntlm_GetDomainArg(LPCWSTR domainW, INT domainW_length) +{ + static const char domain_arg[] = "--domain="; + char *domain; + int unixcp_size; + + unixcp_size = WideCharToMultiByte(CP_UNIXCP, WC_NO_BEST_FIT_CHARS, + domainW, domainW_length, NULL, 0, NULL, NULL) + sizeof(domain_arg); + domain = HeapAlloc(GetProcessHeap(), 0, unixcp_size); + if (!domain) return NULL; + memcpy(domain, domain_arg, sizeof(domain_arg) - 1); + WideCharToMultiByte(CP_UNIXCP, WC_NO_BEST_FIT_CHARS, domainW, + domainW_length, domain + sizeof(domain_arg) - 1, + unixcp_size - sizeof(domain) + 1, NULL, NULL); + domain[unixcp_size - 1] = '\0'; + return domain; +} + /*********************************************************************** * AcquireCredentialsHandleW */ @@ -129,10 +167,6 @@ static SECURITY_STATUS SEC_ENTRY ntlm_AcquireCredentialsHandleW( break; case SECPKG_CRED_OUTBOUND: { - static const char username_arg[] = "--username="; - static const char domain_arg[] = "--domain="; - int unixcp_size; - ntlm_cred = HeapAlloc(GetProcessHeap(), 0, sizeof(*ntlm_cred)); if (!ntlm_cred) { @@ -153,24 +187,8 @@ static SECURITY_STATUS SEC_ENTRY ntlm_AcquireCredentialsHandleW( TRACE("Username is %s\n", debugstr_wn(auth_data->User, auth_data->UserLength)); TRACE("Domain name is %s\n", debugstr_wn(auth_data->Domain, auth_data->DomainLength)); - /* Get username and domain from pAuthData */ - unixcp_size = WideCharToMultiByte(CP_UNIXCP, WC_NO_BEST_FIT_CHARS, - auth_data->User, auth_data->UserLength, NULL, 0, NULL, NULL) + sizeof(username_arg); - ntlm_cred->username_arg = HeapAlloc(GetProcessHeap(), 0, unixcp_size); - memcpy(ntlm_cred->username_arg, username_arg, sizeof(username_arg) - 1); - WideCharToMultiByte(CP_UNIXCP, WC_NO_BEST_FIT_CHARS, auth_data->User, auth_data->UserLength, - ntlm_cred->username_arg + sizeof(username_arg) - 1, - unixcp_size - sizeof(username_arg) + 1, NULL, NULL); - ntlm_cred->username_arg[unixcp_size - 1] = '\0'; - - unixcp_size = WideCharToMultiByte(CP_UNIXCP, WC_NO_BEST_FIT_CHARS, - auth_data->Domain, auth_data->DomainLength, NULL, 0, NULL, NULL) + sizeof(domain_arg); - ntlm_cred->domain_arg = HeapAlloc(GetProcessHeap(), 0, unixcp_size); - memcpy(ntlm_cred->domain_arg, domain_arg, sizeof(domain_arg) - 1); - WideCharToMultiByte(CP_UNIXCP, WC_NO_BEST_FIT_CHARS, auth_data->Domain, - auth_data->DomainLength, ntlm_cred->domain_arg + sizeof(domain_arg) - 1, - unixcp_size - sizeof(domain) + 1, NULL, NULL); - ntlm_cred->domain_arg[unixcp_size - 1] = '\0'; + ntlm_cred->username_arg = ntlm_GetUsernameArg(auth_data->User, auth_data->UserLength); + ntlm_cred->domain_arg = ntlm_GetDomainArg(auth_data->Domain, auth_data->DomainLength); if(auth_data->PasswordLength != 0) { @@ -362,6 +380,46 @@ static int ntlm_GetDataBufferIndex(PSecBufferDesc pMessage) return -1; } +static BOOL ntlm_GetCachedCredential(const SEC_WCHAR *pszTargetName, PCREDENTIALW *cred) +{ + LPCWSTR p; + LPCWSTR pszHost; + LPWSTR pszHostOnly; + BOOL ret; + + if (!pszTargetName) + return FALSE; + + /* try to get the start of the hostname from service principal name (SPN) */ + pszHost = strchrW(pszTargetName, '/'); + if (pszHost) + { + /* skip slash character */ + pszHost++; + + /* find end of host by detecting start of instance port or start of referrer */ + p = strchrW(pszHost, ':'); + if (!p) + p = strchrW(pszHost, '/'); + if (!p) + p = pszHost + strlenW(pszHost); + } + else /* otherwise not an SPN, just a host */ + p = pszHost + strlenW(pszHost); + + pszHostOnly = HeapAlloc(GetProcessHeap(), 0, (p - pszHost + 1) * sizeof(WCHAR)); + if (!pszHostOnly) + return FALSE; + + memcpy(pszHostOnly, pszHost, (p - pszHost) * sizeof(WCHAR)); + pszHostOnly[p - pszHost] = '\0'; + + ret = CredReadW(pszHostOnly, CRED_TYPE_DOMAIN_PASSWORD, 0, cred); + + HeapFree(GetProcessHeap(), 0, pszHostOnly); + return ret; +} + /*********************************************************************** * InitializeSecurityContextW */ @@ -380,6 +438,8 @@ static SECURITY_STATUS SEC_ENTRY ntlm_InitializeSecurityContextW( int buffer_len, bin_len, max_len = NTLM_MAX_BUF; int token_idx; SEC_CHAR *username = NULL; + SEC_CHAR *domain = NULL; + SEC_CHAR *password = NULL; TRACE("%p %p %s %d %d %d %p %d %p %p %p %p\n", phCredential, phContext, debugstr_w(pszTargetName), fContextReq, Reserved1, TargetDataRep, pInput, @@ -399,11 +459,6 @@ static SECURITY_STATUS SEC_ENTRY ntlm_InitializeSecurityContextW( */ /* The squid cache size is 2010 chars, and that's what ntlm_auth uses */ - if (pszTargetName) - { - TRACE("According to a MS whitepaper pszTargetName is ignored.\n"); - } - if(TargetDataRep == SECURITY_NETWORK_DREP){ TRACE("Setting SECURITY_NETWORK_DREP\n"); } @@ -416,6 +471,7 @@ static SECURITY_STATUS SEC_ENTRY ntlm_InitializeSecurityContextW( static char helper_protocol[] = "--helper-protocol=ntlmssp-client-1"; static CHAR credentials_argv[] = "--use-cached-creds"; SEC_CHAR *client_argv[5]; + int pwlen = 0; TRACE("First time in ISC()\n"); @@ -442,30 +498,62 @@ static SECURITY_STATUS SEC_ENTRY ntlm_InitializeSecurityContextW( { LPWKSTA_USER_INFO_1 ui = NULL; NET_API_STATUS status; - int unixcp_size; - static const char username_arg[] = "--username="; + PCREDENTIALW cred; - status = NetWkstaUserGetInfo(NULL, 1, (LPBYTE *)&ui); - if (status != NERR_Success || ui == NULL) + if (ntlm_GetCachedCredential(pszTargetName, &cred)) { - ret = SEC_E_NO_CREDENTIALS; - goto isc_end; + LPWSTR p; + p = strchrW(cred->UserName, '\\'); + if (p) + { + domain = ntlm_GetDomainArg(cred->UserName, p - cred->UserName); + p++; + } + else + { + domain = ntlm_GetDomainArg(NULL, 0); + p = cred->UserName; + } + + username = ntlm_GetUsernameArg(p, -1); + + if(cred->CredentialBlobSize != 0) + { + pwlen = WideCharToMultiByte(CP_UNIXCP, + WC_NO_BEST_FIT_CHARS, (LPWSTR)cred->CredentialBlob, + cred->CredentialBlobSize / sizeof(WCHAR), NULL, 0, + NULL, NULL); + + password = HeapAlloc(GetProcessHeap(), 0, pwlen); + + WideCharToMultiByte(CP_UNIXCP, WC_NO_BEST_FIT_CHARS, + (LPWSTR)cred->CredentialBlob, + cred->CredentialBlobSize / sizeof(WCHAR), + password, pwlen, NULL, NULL); + } + + CredFree(cred); + + client_argv[2] = username; + client_argv[3] = domain; + client_argv[4] = NULL; } + else + { + status = NetWkstaUserGetInfo(NULL, 1, (LPBYTE *)&ui); + if (status != NERR_Success || ui == NULL) + { + ret = SEC_E_NO_CREDENTIALS; + goto isc_end; + } + username = ntlm_GetUsernameArg(ui->wkui1_username, -1); - unixcp_size = WideCharToMultiByte(CP_UNIXCP, WC_NO_BEST_FIT_CHARS, - ui->wkui1_username, -1, NULL, 0, NULL, NULL) + sizeof(username_arg); - username = HeapAlloc(GetProcessHeap(), 0, unixcp_size); - memcpy(username, username_arg, sizeof(username_arg) - 1); - WideCharToMultiByte(CP_UNIXCP, WC_NO_BEST_FIT_CHARS, ui->wkui1_username, -1, - username + sizeof(username_arg) - 1, - unixcp_size - sizeof(username_arg) + 1, NULL, NULL); - username[unixcp_size - 1] = '\0'; + TRACE("using cached credentials\n"); - TRACE("using cached credentials\n"); - - client_argv[2] = username; - client_argv[3] = credentials_argv; - client_argv[4] = NULL; + client_argv[2] = username; + client_argv[3] = credentials_argv; + client_argv[4] = NULL; + } } else { @@ -487,19 +575,20 @@ static SECURITY_STATUS SEC_ENTRY ntlm_InitializeSecurityContextW( } /* Generate the dummy session key = MD4(MD4(password))*/ - if(ntlm_cred->password) + if(password || ntlm_cred->password) { SEC_WCHAR *unicode_password; int passwd_lenW; TRACE("Converting password to unicode.\n"); passwd_lenW = MultiByteToWideChar(CP_ACP, 0, - (LPCSTR)ntlm_cred->password, ntlm_cred->pwlen, + password ? (LPCSTR)password : (LPCSTR)ntlm_cred->password, + password ? pwlen : ntlm_cred->pwlen, NULL, 0); unicode_password = HeapAlloc(GetProcessHeap(), 0, passwd_lenW * sizeof(SEC_WCHAR)); - MultiByteToWideChar(CP_ACP, 0, (LPCSTR)ntlm_cred->password, - ntlm_cred->pwlen, unicode_password, passwd_lenW); + MultiByteToWideChar(CP_ACP, 0, password ? (LPCSTR)password : (LPCSTR)ntlm_cred->password, + password ? pwlen : ntlm_cred->pwlen, unicode_password, passwd_lenW); SECUR32_CreateNTLMv1SessionKey((PBYTE)unicode_password, passwd_lenW * sizeof(SEC_WCHAR), helper->session_key); @@ -560,7 +649,7 @@ static SECURITY_STATUS SEC_ENTRY ntlm_InitializeSecurityContextW( /* If no password is given, try to use cached credentials. Fall back to an empty * password if this failed. */ - if(ntlm_cred->password == NULL) + if(!password && !ntlm_cred->password) { lstrcpynA(buffer, "OK", max_len-1); if((ret = run_helper(helper, buffer, max_len, &buffer_len)) != SEC_E_OK) @@ -581,8 +670,8 @@ static SECURITY_STATUS SEC_ENTRY ntlm_InitializeSecurityContextW( else { lstrcpynA(buffer, "PW ", max_len-1); - if((ret = encodeBase64((unsigned char*)ntlm_cred->password, - ntlm_cred->pwlen, buffer+3, + if((ret = encodeBase64(password ? (unsigned char *)password : (unsigned char *)ntlm_cred->password, + password ? pwlen : ntlm_cred->pwlen, buffer+3, max_len-3, &buffer_len)) != SEC_E_OK) { cleanup_helper(helper); @@ -836,6 +925,8 @@ static SECURITY_STATUS SEC_ENTRY ntlm_InitializeSecurityContextW( isc_end: HeapFree(GetProcessHeap(), 0, username); + HeapFree(GetProcessHeap(), 0, domain); + HeapFree(GetProcessHeap(), 0, password); HeapFree(GetProcessHeap(), 0, want_flags); HeapFree(GetProcessHeap(), 0, buffer); HeapFree(GetProcessHeap(), 0, bin);