diff --git a/dlls/crypt32/msg.c b/dlls/crypt32/msg.c index 4c495de44b0..001e1282fd9 100644 --- a/dlls/crypt32/msg.c +++ b/dlls/crypt32/msg.c @@ -38,13 +38,18 @@ typedef BOOL (*CryptMsgGetParamFunc)(HCRYPTMSG hCryptMsg, DWORD dwParamType, typedef BOOL (*CryptMsgUpdateFunc)(HCRYPTMSG hCryptMsg, const BYTE *pbData, DWORD cbData, BOOL fFinal); +typedef enum _CryptMsgState { + MsgStateInit, + MsgStateFinalized +} CryptMsgState; + typedef struct _CryptMsgBase { LONG ref; DWORD open_flags; BOOL streamed; CMSG_STREAM_INFO stream_info; - BOOL finalized; + CryptMsgState state; CryptMsgCloseFunc close; CryptMsgUpdateFunc update; CryptMsgGetParamFunc get_param; @@ -69,7 +74,7 @@ static inline void CryptMsgBase_Init(CryptMsgBase *msg, DWORD dwFlags, msg->close = close; msg->get_param = get_param; msg->update = update; - msg->finalized = FALSE; + msg->state = MsgStateInit; } typedef struct _CDataEncodeMsg @@ -161,12 +166,8 @@ static BOOL CDataEncodeMsg_Update(HCRYPTMSG hCryptMsg, const BYTE *pbData, CDataEncodeMsg *msg = (CDataEncodeMsg *)hCryptMsg; BOOL ret = FALSE; - if (msg->base.finalized) - SetLastError(CRYPT_E_MSG_ERROR); - else if (msg->base.streamed) + if (msg->base.streamed) { - if (fFinal) - msg->base.finalized = TRUE; __TRY { if (!msg->begun) @@ -223,7 +224,6 @@ static BOOL CDataEncodeMsg_Update(HCRYPTMSG hCryptMsg, const BYTE *pbData, } else { - msg->base.finalized = TRUE; if (!cbData) SetLastError(E_INVALIDARG); else @@ -380,7 +380,7 @@ static BOOL CRYPT_EncodePKCSDigestedData(CHashEncodeMsg *msg, void *pvData, items[cItem].encodeFunc = CRYPT_AsnEncodePKCSContentInfoInternal; cItem++; - if (msg->base.finalized) + if (msg->base.state == MsgStateFinalized) { size = sizeof(DWORD); ret = CryptGetHashParam(msg->hash, HP_HASHSIZE, @@ -454,7 +454,7 @@ static BOOL CHashEncodeMsg_GetParam(HCRYPTMSG hCryptMsg, DWORD dwParamType, 0); break; case CMSG_VERSION_PARAM: - if (!msg->base.finalized) + if (msg->base.state != MsgStateFinalized) SetLastError(CRYPT_E_MSG_ERROR); else { @@ -479,39 +479,31 @@ static BOOL CHashEncodeMsg_Update(HCRYPTMSG hCryptMsg, const BYTE *pbData, TRACE("(%p, %p, %d, %d)\n", hCryptMsg, pbData, cbData, fFinal); - if (msg->base.finalized) - SetLastError(CRYPT_E_MSG_ERROR); + msg->begun = TRUE; + if (msg->base.streamed || (msg->base.open_flags & CMSG_DETACHED_FLAG)) + { + /* Doesn't do much, as stream output is never called, and you + * can't get the content. + */ + ret = CryptHashData(msg->hash, pbData, cbData, 0); + } else { - msg->begun = TRUE; - if (fFinal) - msg->base.finalized = TRUE; - if (msg->base.streamed || (msg->base.open_flags & CMSG_DETACHED_FLAG)) - { - /* Doesn't do much, as stream output is never called, and you - * can't get the content. - */ - ret = CryptHashData(msg->hash, pbData, cbData, 0); - } + if (!fFinal) + SetLastError(CRYPT_E_MSG_ERROR); else { - if (!fFinal) - SetLastError(CRYPT_E_MSG_ERROR); - else + ret = CryptHashData(msg->hash, pbData, cbData, 0); + if (ret) { - ret = CryptHashData(msg->hash, pbData, cbData, 0); - if (ret) + msg->data.pbData = CryptMemAlloc(cbData); + if (msg->data.pbData) { - msg->data.pbData = CryptMemAlloc(cbData); - if (msg->data.pbData) - { - memcpy(msg->data.pbData + msg->data.cbData, pbData, - cbData); - msg->data.cbData += cbData; - } - else - ret = FALSE; + memcpy(msg->data.pbData + msg->data.cbData, pbData, cbData); + msg->data.cbData += cbData; } + else + ret = FALSE; } } } @@ -717,9 +709,19 @@ BOOL WINAPI CryptMsgUpdate(HCRYPTMSG hCryptMsg, const BYTE *pbData, DWORD cbData, BOOL fFinal) { CryptMsgBase *msg = (CryptMsgBase *)hCryptMsg; + BOOL ret = FALSE; TRACE("(%p, %p, %d, %d)\n", hCryptMsg, pbData, cbData, fFinal); - return msg->update(hCryptMsg, pbData, cbData, fFinal); + + if (msg->state == MsgStateFinalized) + SetLastError(CRYPT_E_MSG_ERROR); + else + { + ret = msg->update(hCryptMsg, pbData, cbData, fFinal); + if (fFinal) + msg->state = MsgStateFinalized; + } + return ret; } BOOL WINAPI CryptMsgGetParam(HCRYPTMSG hCryptMsg, DWORD dwParamType, diff --git a/dlls/crypt32/tests/msg.c b/dlls/crypt32/tests/msg.c index 688ae831dd0..a41a31076e9 100644 --- a/dlls/crypt32/tests/msg.c +++ b/dlls/crypt32/tests/msg.c @@ -994,7 +994,6 @@ static void test_decode_msg_update(void) /* Can't update after a final update */ SetLastError(0xdeadbeef); ret = CryptMsgUpdate(msg, dataEmptyContent, sizeof(dataEmptyContent), TRUE); - todo_wine ok(!ret && GetLastError() == CRYPT_E_MSG_ERROR, "Expected CRYPT_E_MSG_ERROR, got %x\n", GetLastError()); CryptMsgClose(msg);