diff --git a/dlls/secur32/ntlm.c b/dlls/secur32/ntlm.c index 9a1fa3b9067..5a00c7ad29a 100644 --- a/dlls/secur32/ntlm.c +++ b/dlls/secur32/ntlm.c @@ -370,6 +370,26 @@ static SECURITY_STATUS SEC_ENTRY ntlm_AcquireCredentialsHandleA( return ret; } +/************************************************************************* + * ntlm_GetTokenBufferIndex + * Calculates the index of the secbuffer with BufferType == SECBUFFER_TOKEN + * Returns index if found or -1 if not found. + */ +static int ntlm_GetTokenBufferIndex(PSecBufferDesc pMessage) +{ + UINT i; + + TRACE("%p\n", pMessage); + + for( i = 0; i < pMessage->cBuffers; ++i ) + { + if(pMessage->pBuffers[i].BufferType == SECBUFFER_TOKEN) + return i; + } + + return -1; +} + /*********************************************************************** * InitializeSecurityContextW */ @@ -385,6 +405,7 @@ static SECURITY_STATUS SEC_ENTRY ntlm_InitializeSecurityContextW( char* buffer, *want_flags = NULL; PBYTE bin; int buffer_len, bin_len, max_len = NTLM_MAX_BUF; + int token_idx; TRACE("%p %p %s %d %d %d %p %d %p %p %p %p\n", phCredential, phContext, debugstr_w(pszTargetName), fContextReq, Reserved1, TargetDataRep, pInput, @@ -556,11 +577,13 @@ static SECURITY_STATUS SEC_ENTRY ntlm_InitializeSecurityContextW( } else { + int input_token_idx; + /* handle second call here */ /* encode server data to base64 */ - if (!pInput || !pInput->cBuffers) + if (!pInput || ((input_token_idx = ntlm_GetTokenBufferIndex(pInput)) == -1)) { - ret = SEC_E_INCOMPLETE_MESSAGE; + ret = SEC_E_INVALID_TOKEN; goto isc_end; } @@ -577,23 +600,23 @@ static SECURITY_STATUS SEC_ENTRY ntlm_InitializeSecurityContextW( return SEC_E_INVALID_HANDLE; } - if (!pInput->pBuffers[0].pvBuffer) + if (!pInput->pBuffers[input_token_idx].pvBuffer) { ret = SEC_E_INTERNAL_ERROR; goto isc_end; } - if(pInput->pBuffers[0].cbBuffer > max_len) + if(pInput->pBuffers[input_token_idx].cbBuffer > max_len) { TRACE("pInput->pBuffers[0].cbBuffer is: %ld\n", - pInput->pBuffers[0].cbBuffer); + pInput->pBuffers[input_token_idx].cbBuffer); ret = SEC_E_INVALID_TOKEN; goto isc_end; } else - bin_len = pInput->pBuffers[0].cbBuffer; + bin_len = pInput->pBuffers[input_token_idx].cbBuffer; - memcpy(bin, pInput->pBuffers[0].pvBuffer, bin_len); + memcpy(bin, pInput->pBuffers[input_token_idx].pvBuffer, bin_len); lstrcpynA(buffer, "TT ", max_len-1); @@ -632,32 +655,34 @@ static SECURITY_STATUS SEC_ENTRY ntlm_InitializeSecurityContextW( /* put the decoded client blob into the out buffer */ - if (fContextReq & ISC_REQ_ALLOCATE_MEMORY) + if (!pOutput || ((token_idx = ntlm_GetTokenBufferIndex(pOutput)) == -1)) { - if (pOutput) - { - pOutput->cBuffers = 1; - pOutput->pBuffers[0].pvBuffer = SECUR32_ALLOC(bin_len); - pOutput->pBuffers[0].cbBuffer = bin_len; - } + WARN("no SECBUFFER_TOKEN buffer could be found\n"); + ret = SEC_E_BUFFER_TOO_SMALL; + goto isc_end; } - if (!pOutput || !pOutput->cBuffers || pOutput->pBuffers[0].cbBuffer < bin_len) + if (fContextReq & ISC_REQ_ALLOCATE_MEMORY) + { + pOutput->pBuffers[token_idx].pvBuffer = SECUR32_ALLOC(bin_len); + pOutput->pBuffers[token_idx].cbBuffer = bin_len; + } + else if (pOutput->pBuffers[token_idx].cbBuffer < bin_len) { TRACE("out buffer is NULL or has not enough space\n"); ret = SEC_E_BUFFER_TOO_SMALL; goto isc_end; } - if (!pOutput->pBuffers[0].pvBuffer) + if (!pOutput->pBuffers[token_idx].pvBuffer) { TRACE("out buffer is NULL\n"); ret = SEC_E_INTERNAL_ERROR; goto isc_end; } - pOutput->pBuffers[0].cbBuffer = bin_len; - memcpy(pOutput->pBuffers[0].pvBuffer, bin, bin_len); + pOutput->pBuffers[token_idx].cbBuffer = bin_len; + memcpy(pOutput->pBuffers[token_idx].pvBuffer, bin, bin_len); if(ret == SEC_E_OK) { @@ -1237,26 +1262,6 @@ static SECURITY_STATUS SEC_ENTRY ntlm_RevertSecurityContext(PCtxtHandle phContex return ret; } -/************************************************************************* - * ntlm_GetTokenBufferIndex - * Calculates the index of the secbuffer with BufferType == SECBUFFER_TOKEN - * Returns index if found or -1 if not found. - */ -static int ntlm_GetTokenBufferIndex(PSecBufferDesc pMessage) -{ - UINT i; - - TRACE("%p\n", pMessage); - - for( i = 0; i < pMessage->cBuffers; ++i ) - { - if(pMessage->pBuffers[i].BufferType == SECBUFFER_TOKEN) - return i; - } - - return -1; -} - /*********************************************************************** * ntlm_CreateSignature * As both MakeSignature and VerifySignature need this, but different keys diff --git a/dlls/secur32/tests/ntlm.c b/dlls/secur32/tests/ntlm.c index e25749fd8b6..de6bd7f882f 100644 --- a/dlls/secur32/tests/ntlm.c +++ b/dlls/secur32/tests/ntlm.c @@ -448,6 +448,16 @@ static SECURITY_STATUS runClient(SspiData *sspi_data, BOOL first, ULONG data_rep ok(out_buf->pBuffers[0].cbBuffer == 0, "InitializeSecurityContext set buffer size to %lu\n", out_buf->pBuffers[0].cbBuffer); + + out_buf->pBuffers[0].cbBuffer = sspi_data->max_token; + out_buf->pBuffers[0].BufferType = SECBUFFER_DATA; + + ret = pInitializeSecurityContextA(sspi_data->cred, NULL, NULL, req_attr, + 0, data_rep, NULL, 0, sspi_data->ctxt, out_buf, + &ctxt_attr, &ttl); + + ok(ret == SEC_E_BUFFER_TOO_SMALL, "expected SEC_E_BUFFER_TOO_SMALL, got %s\n", getSecError(ret)); + out_buf->pBuffers[0].BufferType = SECBUFFER_TOKEN; } out_buf->pBuffers[0].cbBuffer = sspi_data->max_token;