diff --git a/dlls/bcrypt/tests/bcrypt.c b/dlls/bcrypt/tests/bcrypt.c index 456727d04a9..fb5ac03b039 100644 --- a/dlls/bcrypt/tests/bcrypt.c +++ b/dlls/bcrypt/tests/bcrypt.c @@ -862,6 +862,64 @@ static void test_BCryptGenerateSymmetricKey(void) ok(ret == STATUS_SUCCESS, "got %08x\n", ret); } +#define RACE_TEST_COUNT 200 +static LONG encrypt_race_start_barrier; + +static DWORD WINAPI encrypt_race_thread(void *parameter) +{ + static UCHAR nonce[] = + {0x11,0x20,0x30,0x40,0x50,0x60,0x10,0x20,0x30,0x40,0x50,0x60}; + static UCHAR auth_data[] = + {0x61,0x50,0x40,0x30,0x20,0x10,0x60,0x50,0x40,0x30,0x20,0x10}; + static UCHAR data2[] = + {0x00,0x01,0x02,0x03,0x04,0x05,0x06,0x07,0x08,0x09,0x0a,0x0b,0x0c,0x0d,0x0e,0x0f, + 0x10,0x10,0x10,0x10,0x10,0x10,0x10,0x10,0x10,0x10,0x10,0x10,0x10,0x10,0x10,0x10}; + static UCHAR expected4[] = + {0xb2,0x27,0x19,0x09,0xc7,0x89,0xdc,0x52,0x24,0x83,0x3a,0x55,0x34,0x76,0x2c,0xbf, + 0x15,0xa1,0xcb,0x40,0x78,0x11,0xba,0xbc,0xa4,0x76,0x69,0x7c,0x75,0x4f,0x11,0xba}; + static UCHAR expected_tag3[] = + {0xef,0xee,0x75,0x99,0xb8,0x12,0xe9,0xf0,0xb4,0xcc,0x65,0x11,0x67,0x60,0x2d,0xe6}; + + BCRYPT_AUTHENTICATED_CIPHER_MODE_INFO auth_info; + BCRYPT_KEY_HANDLE key = parameter; + UCHAR ciphertext[48], tag[16]; + unsigned int i, test; + NTSTATUS ret; + ULONG size; + + memset(&auth_info, 0, sizeof(auth_info)); + auth_info.cbSize = sizeof(auth_info); + auth_info.dwInfoVersion = 1; + auth_info.pbNonce = nonce; + auth_info.cbNonce = sizeof(nonce); + auth_info.pbTag = tag; + auth_info.cbTag = sizeof(tag); + auth_info.pbAuthData = auth_data; + auth_info.cbAuthData = sizeof(auth_data); + + InterlockedIncrement(&encrypt_race_start_barrier); + while (InterlockedCompareExchange(&encrypt_race_start_barrier, 3, 2) != 2) + ; + + for (test = 0; test < RACE_TEST_COUNT; ++test) + { + size = 0; + memset(ciphertext, 0xff, sizeof(ciphertext)); + memset(tag, 0xff, sizeof(tag)); + ret = pBCryptEncrypt(key, data2, 32, &auth_info, NULL, 0, ciphertext, 32, &size, 0); + ok(ret == STATUS_SUCCESS, "got %08x\n", ret); + ok(size == 32, "got %u\n", size); + ok(!memcmp(ciphertext, expected4, sizeof(expected4)), "wrong data\n"); + ok(!memcmp(tag, expected_tag3, sizeof(expected_tag3)), "wrong tag\n"); + for (i = 0; i < 32; i++) + ok(ciphertext[i] == expected4[i], "%u: %02x != %02x\n", i, ciphertext[i], expected4[i]); + for (i = 0; i < 16; i++) + ok(tag[i] == expected_tag3[i], "%u: %02x != %02x\n", i, tag[i], expected_tag3[i]); + } + + return 0; +} + static void test_BCryptEncrypt(void) { static UCHAR nonce[] = @@ -921,9 +979,10 @@ static void test_BCryptEncrypt(void) BCRYPT_AUTHENTICATED_CIPHER_MODE_INFO auth_info; UCHAR *buf, ciphertext[48], ivbuf[16], tag[16]; BCRYPT_AUTH_TAG_LENGTHS_STRUCT tag_length; + ULONG size, len, i, test; BCRYPT_ALG_HANDLE aes; BCRYPT_KEY_HANDLE key; - ULONG size, len, i; + HANDLE hthread; NTSTATUS ret; ret = pBCryptOpenAlgorithmProvider(&aes, BCRYPT_AES_ALGORITHM, NULL, 0); @@ -1215,6 +1274,32 @@ static void test_BCryptEncrypt(void) ret = pBCryptEncrypt(key, data2, 32, &auth_info, ivbuf, 16, ciphertext, 48, &size, BCRYPT_BLOCK_PADDING); ok(ret == STATUS_INVALID_PARAMETER, "got %08x\n", ret); + /* race test */ + + encrypt_race_start_barrier = 0; + hthread = CreateThread(NULL, 0, encrypt_race_thread, key, 0, NULL); + + while (InterlockedCompareExchange(&encrypt_race_start_barrier, 2, 1) != 1) + ; + + for (test = 0; test < RACE_TEST_COUNT; ++test) + { + size = 0; + memset(ciphertext, 0xff, sizeof(ciphertext)); + memset(tag, 0xff, sizeof(tag)); + ret = pBCryptEncrypt(key, data2, 32, &auth_info, NULL, 0, ciphertext, 32, &size, 0); + ok(ret == STATUS_SUCCESS, "got %08x\n", ret); + ok(size == 32, "got %u\n", size); + ok(!memcmp(ciphertext, expected4, sizeof(expected4)), "wrong data\n"); + ok(!memcmp(tag, expected_tag3, sizeof(expected_tag2)), "wrong tag\n"); + for (i = 0; i < 32; i++) + ok(ciphertext[i] == expected4[i], "%u: %02x != %02x\n", i, ciphertext[i], expected4[i]); + for (i = 0; i < 16; i++) + ok(tag[i] == expected_tag3[i], "%u: %02x != %02x\n", i, tag[i], expected_tag3[i]); + } + + WaitForSingleObject(hthread, INFINITE); + ret = pBCryptDestroyKey(key); ok(ret == STATUS_SUCCESS, "got %08x\n", ret); HeapFree(GetProcessHeap(), 0, buf);