diff --git a/dlls/ntdll/tests/sync.c b/dlls/ntdll/tests/sync.c index 7e389c9762b..f2499ea648a 100644 --- a/dlls/ntdll/tests/sync.c +++ b/dlls/ntdll/tests/sync.c @@ -44,7 +44,12 @@ static NTSTATUS (WINAPI *pNtReleaseSemaphore)( HANDLE, ULONG, ULONG * ); static NTSTATUS (WINAPI *pNtResetEvent)( HANDLE, LONG * ); static NTSTATUS (WINAPI *pNtSetEvent)( HANDLE, LONG * ); static NTSTATUS (WINAPI *pNtWaitForKeyedEvent)( HANDLE, const void *, BOOLEAN, const LARGE_INTEGER * ); +static BOOLEAN (WINAPI *pRtlAcquireResourceExclusive)( RTL_RWLOCK *, BOOLEAN ); +static BOOLEAN (WINAPI *pRtlAcquireResourceShared)( RTL_RWLOCK *, BOOLEAN ); +static void (WINAPI *pRtlDeleteResource)( RTL_RWLOCK * ); +static void (WINAPI *pRtlInitializeResource)( RTL_RWLOCK * ); static void (WINAPI *pRtlInitUnicodeString)( UNICODE_STRING *, const WCHAR * ); +static void (WINAPI *pRtlReleaseResource)( RTL_RWLOCK * ); static NTSTATUS (WINAPI *pRtlWaitOnAddress)( const void *, const void *, SIZE_T, const LARGE_INTEGER * ); static void (WINAPI *pRtlWakeAddressAll)( const void * ); static void (WINAPI *pRtlWakeAddressSingle)( const void * ); @@ -596,6 +601,159 @@ static void test_wait_on_address(void) ok(address == 0, "got %s\n", wine_dbgstr_longlong(address)); } +static HANDLE thread_ready, thread_done; + +static DWORD WINAPI resource_shared_thread(void *arg) +{ + RTL_RWLOCK *resource = arg; + BOOLEAN ret; + + ret = pRtlAcquireResourceShared(resource, TRUE); + ok(ret == TRUE, "got %u\n", ret); + + SetEvent(thread_ready); + ok(!WaitForSingleObject(thread_done, 1000), "wait failed\n"); + pRtlReleaseResource(resource); + return 0; +} + +static DWORD WINAPI resource_exclusive_thread(void *arg) +{ + RTL_RWLOCK *resource = arg; + BOOLEAN ret; + + ret = pRtlAcquireResourceExclusive(resource, TRUE); + ok(ret == TRUE, "got %u\n", ret); + + SetEvent(thread_ready); + ok(!WaitForSingleObject(thread_done, 1000), "wait failed\n"); + pRtlReleaseResource(resource); + return 0; +} + +static void test_resource(void) +{ + HANDLE thread, thread2; + RTL_RWLOCK resource; + BOOLEAN ret; + + pRtlInitializeResource(&resource); + thread_ready = CreateEventA(NULL, FALSE, FALSE, NULL); + thread_done = CreateEventA(NULL, FALSE, FALSE, NULL); + + ret = pRtlAcquireResourceShared(&resource, FALSE); + ok(ret == TRUE, "got %u\n", ret); + ret = pRtlAcquireResourceShared(&resource, FALSE); + ok(ret == TRUE, "got %u\n", ret); + ret = pRtlAcquireResourceExclusive(&resource, FALSE); + ok(ret == FALSE, "got %u\n", ret); + pRtlReleaseResource(&resource); + pRtlReleaseResource(&resource); + + ret = pRtlAcquireResourceExclusive(&resource, FALSE); + ok(ret == TRUE, "got %u\n", ret); + ret = pRtlAcquireResourceExclusive(&resource, FALSE); + ok(ret == TRUE, "got %u\n", ret); + ret = pRtlAcquireResourceShared(&resource, FALSE); + ok(ret == TRUE, "got %u\n", ret); + pRtlReleaseResource(&resource); + pRtlReleaseResource(&resource); + pRtlReleaseResource(&resource); + + /* Do not acquire the resource ourselves, but spawn a shared thread holding it. */ + + thread = CreateThread(NULL, 0, resource_shared_thread, &resource, 0, NULL); + ok(!WaitForSingleObject(thread_ready, 1000), "wait failed\n"); + + ret = pRtlAcquireResourceExclusive(&resource, FALSE); + ok(ret == FALSE, "got %u\n", ret); + ret = pRtlAcquireResourceShared(&resource, FALSE); + ok(ret == TRUE, "got %u\n", ret); + pRtlReleaseResource(&resource); + + SetEvent(thread_done); + ok(!WaitForSingleObject(thread, 1000), "wait failed\n"); + CloseHandle(thread); + + ret = pRtlAcquireResourceExclusive(&resource, FALSE); + ok(ret == TRUE, "got %u\n", ret); + pRtlReleaseResource(&resource); + ret = pRtlAcquireResourceShared(&resource, FALSE); + ok(ret == TRUE, "got %u\n", ret); + pRtlReleaseResource(&resource); + + /* Acquire the resource as exclusive, and then spawn a shared thread. */ + + ret = pRtlAcquireResourceExclusive(&resource, FALSE); + ok(ret == TRUE, "got %u\n", ret); + thread = CreateThread(NULL, 0, resource_shared_thread, &resource, 0, NULL); + ok(WaitForSingleObject(thread_ready, 100) == WAIT_TIMEOUT, "expected timeout\n"); + + ret = pRtlAcquireResourceExclusive(&resource, FALSE); + ok(ret == TRUE, "got %u\n", ret); + pRtlReleaseResource(&resource); + ret = pRtlAcquireResourceShared(&resource, FALSE); + ok(ret == TRUE, "got %u\n", ret); + pRtlReleaseResource(&resource); + + pRtlReleaseResource(&resource); + ok(!WaitForSingleObject(thread_ready, 1000), "wait failed\n"); + SetEvent(thread_done); + ok(!WaitForSingleObject(thread, 1000), "wait failed\n"); + CloseHandle(thread); + + /* Acquire the resource as shared, and then spawn an exclusive thread. */ + + ret = pRtlAcquireResourceShared(&resource, FALSE); + ok(ret == TRUE, "got %u\n", ret); + thread = CreateThread(NULL, 0, resource_exclusive_thread, &resource, 0, NULL); + ok(WaitForSingleObject(thread_ready, 100) == WAIT_TIMEOUT, "expected timeout\n"); + + ret = pRtlAcquireResourceExclusive(&resource, FALSE); + ok(ret == FALSE, "got %u\n", ret); + ret = pRtlAcquireResourceShared(&resource, FALSE); + ok(ret == TRUE, "got %u\n", ret); + pRtlReleaseResource(&resource); + + pRtlReleaseResource(&resource); + ok(!WaitForSingleObject(thread_ready, 1000), "wait failed\n"); + SetEvent(thread_done); + ok(!WaitForSingleObject(thread, 1000), "wait failed\n"); + CloseHandle(thread); + + /* Spawn a shared and then exclusive waiter. */ + thread = CreateThread(NULL, 0, resource_shared_thread, &resource, 0, NULL); + ok(!WaitForSingleObject(thread_ready, 1000), "wait failed\n"); + thread2 = CreateThread(NULL, 0, resource_exclusive_thread, &resource, 0, NULL); + ok(WaitForSingleObject(thread_ready, 100) == WAIT_TIMEOUT, "expected timeout\n"); + + ret = pRtlAcquireResourceExclusive(&resource, FALSE); + ok(ret == FALSE, "got %u\n", ret); + ret = pRtlAcquireResourceShared(&resource, FALSE); + ok(ret == TRUE, "got %u\n", ret); + pRtlReleaseResource(&resource); + + SetEvent(thread_done); + ok(!WaitForSingleObject(thread, 1000), "wait failed\n"); + CloseHandle(thread); + + ok(!WaitForSingleObject(thread_ready, 1000), "wait failed\n"); + SetEvent(thread_done); + ok(!WaitForSingleObject(thread2, 1000), "wait failed\n"); + CloseHandle(thread2); + + ret = pRtlAcquireResourceExclusive(&resource, FALSE); + ok(ret == TRUE, "got %u\n", ret); + pRtlReleaseResource(&resource); + ret = pRtlAcquireResourceShared(&resource, FALSE); + ok(ret == TRUE, "got %u\n", ret); + pRtlReleaseResource(&resource); + + CloseHandle(thread_ready); + CloseHandle(thread_done); + pRtlDeleteResource(&resource); +} + START_TEST(sync) { HMODULE module = GetModuleHandleA("ntdll.dll"); @@ -618,7 +776,12 @@ START_TEST(sync) pNtResetEvent = (void *)GetProcAddress(module, "NtResetEvent"); pNtSetEvent = (void *)GetProcAddress(module, "NtSetEvent"); pNtWaitForKeyedEvent = (void *)GetProcAddress(module, "NtWaitForKeyedEvent"); + pRtlAcquireResourceExclusive = (void *)GetProcAddress(module, "RtlAcquireResourceExclusive"); + pRtlAcquireResourceShared = (void *)GetProcAddress(module, "RtlAcquireResourceShared"); + pRtlDeleteResource = (void *)GetProcAddress(module, "RtlDeleteResource"); + pRtlInitializeResource = (void *)GetProcAddress(module, "RtlInitializeResource"); pRtlInitUnicodeString = (void *)GetProcAddress(module, "RtlInitUnicodeString"); + pRtlReleaseResource = (void *)GetProcAddress(module, "RtlReleaseResource"); pRtlWaitOnAddress = (void *)GetProcAddress(module, "RtlWaitOnAddress"); pRtlWakeAddressAll = (void *)GetProcAddress(module, "RtlWakeAddressAll"); pRtlWakeAddressSingle = (void *)GetProcAddress(module, "RtlWakeAddressSingle"); @@ -628,4 +791,5 @@ START_TEST(sync) test_mutant(); test_semaphore(); test_keyed_events(); + test_resource(); }