diff --git a/dlls/crypt32/decode.c b/dlls/crypt32/decode.c index 3a487eadfcd..465400812ef 100644 --- a/dlls/crypt32/decode.c +++ b/dlls/crypt32/decode.c @@ -96,6 +96,7 @@ static BOOL CRYPT_AsnDecodeBool(const BYTE *pbEncoded, DWORD cbEncoded, static BOOL CRYPT_AsnDecodeOctetsInternal(const BYTE *pbEncoded, DWORD cbEncoded, DWORD dwFlags, void *pvStructInfo, DWORD *pcbStructInfo, DWORD *pcbDecoded); +/* Doesn't check the tag, assumes the caller does so */ static BOOL CRYPT_AsnDecodeBitsInternal(const BYTE *pbEncoded, DWORD cbEncoded, DWORD dwFlags, void *pvStructInfo, DWORD *pcbStructInfo, DWORD *pcbDecoded); static BOOL CRYPT_AsnDecodeIntInternal(const BYTE *pbEncoded, DWORD cbEncoded, @@ -3298,64 +3299,54 @@ static BOOL CRYPT_AsnDecodeBitsInternal(const BYTE *pbEncoded, DWORD cbEncoded, DWORD dwFlags, void *pvStructInfo, DWORD *pcbStructInfo, DWORD *pcbDecoded) { BOOL ret; + DWORD bytesNeeded, dataLen; + BYTE lenBytes = GET_LEN_BYTES(pbEncoded[1]); TRACE("(%p, %d, 0x%08x, %p, %d, %p)\n", pbEncoded, cbEncoded, dwFlags, pvStructInfo, *pcbStructInfo, pcbDecoded); - if (pbEncoded[0] == ASN_BITSTRING) + if ((ret = CRYPT_GetLen(pbEncoded, cbEncoded, &dataLen))) { - DWORD bytesNeeded, dataLen; - BYTE lenBytes = GET_LEN_BYTES(pbEncoded[1]); - - if ((ret = CRYPT_GetLen(pbEncoded, cbEncoded, &dataLen))) + if (dwFlags & CRYPT_DECODE_NOCOPY_FLAG) + bytesNeeded = sizeof(CRYPT_BIT_BLOB); + else + bytesNeeded = dataLen - 1 + sizeof(CRYPT_BIT_BLOB); + if (pcbDecoded) + *pcbDecoded = 1 + lenBytes + dataLen; + if (!pvStructInfo) + *pcbStructInfo = bytesNeeded; + else if (*pcbStructInfo < bytesNeeded) { + *pcbStructInfo = bytesNeeded; + SetLastError(ERROR_MORE_DATA); + ret = FALSE; + } + else + { + CRYPT_BIT_BLOB *blob; + + *pcbStructInfo = bytesNeeded; + blob = (CRYPT_BIT_BLOB *)pvStructInfo; + blob->cbData = dataLen - 1; + blob->cUnusedBits = *(pbEncoded + 1 + lenBytes); if (dwFlags & CRYPT_DECODE_NOCOPY_FLAG) - bytesNeeded = sizeof(CRYPT_BIT_BLOB); - else - bytesNeeded = dataLen - 1 + sizeof(CRYPT_BIT_BLOB); - if (pcbDecoded) - *pcbDecoded = 1 + lenBytes + dataLen; - if (!pvStructInfo) - *pcbStructInfo = bytesNeeded; - else if (*pcbStructInfo < bytesNeeded) { - *pcbStructInfo = bytesNeeded; - SetLastError(ERROR_MORE_DATA); - ret = FALSE; + blob->pbData = (BYTE *)pbEncoded + 2 + lenBytes; } else { - CRYPT_BIT_BLOB *blob; - - *pcbStructInfo = bytesNeeded; - blob = (CRYPT_BIT_BLOB *)pvStructInfo; - blob->cbData = dataLen - 1; - blob->cUnusedBits = *(pbEncoded + 1 + lenBytes); - if (dwFlags & CRYPT_DECODE_NOCOPY_FLAG) + assert(blob->pbData); + if (blob->cbData) { - blob->pbData = (BYTE *)pbEncoded + 2 + lenBytes; - } - else - { - assert(blob->pbData); - if (blob->cbData) - { - BYTE mask = 0xff << blob->cUnusedBits; + BYTE mask = 0xff << blob->cUnusedBits; - memcpy(blob->pbData, pbEncoded + 2 + lenBytes, - blob->cbData); - blob->pbData[blob->cbData - 1] &= mask; - } + memcpy(blob->pbData, pbEncoded + 2 + lenBytes, + blob->cbData); + blob->pbData[blob->cbData - 1] &= mask; } } } } - else - { - SetLastError(CRYPT_E_ASN1_BADTAG); - ret = FALSE; - } - TRACE("returning %d (%08x)\n", ret, GetLastError()); return ret; } @@ -3372,7 +3363,17 @@ static BOOL WINAPI CRYPT_AsnDecodeBits(DWORD dwCertEncodingType, { DWORD bytesNeeded; - if ((ret = CRYPT_AsnDecodeBitsInternal(pbEncoded, cbEncoded, + if (!cbEncoded) + { + SetLastError(CRYPT_E_ASN1_CORRUPT); + ret = FALSE; + } + else if (pbEncoded[0] != ASN_BITSTRING) + { + SetLastError(CRYPT_E_ASN1_BADTAG); + ret = FALSE; + } + else if ((ret = CRYPT_AsnDecodeBitsInternal(pbEncoded, cbEncoded, dwFlags & ~CRYPT_DECODE_ALLOC_FLAG, NULL, &bytesNeeded, NULL))) { if (!pvStructInfo) diff --git a/dlls/crypt32/tests/encode.c b/dlls/crypt32/tests/encode.c index c269b9b6aeb..9a0eff5667d 100644 --- a/dlls/crypt32/tests/encode.c +++ b/dlls/crypt32/tests/encode.c @@ -3323,7 +3323,6 @@ static void test_decodeCRLDistPoints(DWORD dwEncoding) ret = pCryptDecodeObjectEx(dwEncoding, X509_CRL_DIST_POINTS, distPointWithReason, distPointWithReason[1] + 2, CRYPT_DECODE_ALLOC_FLAG, NULL, (BYTE *)&buf, &size); - todo_wine ok(ret, "CryptDecodeObjectEx failed: %08x\n", GetLastError()); if (ret) {