ole32: Fix refcounting of IObjContext per-thread instance.

Signed-off-by: Nikolay Sivov <nsivov@codeweavers.com>
Signed-off-by: Huw Davies <huw@codeweavers.com>
Signed-off-by: Alexandre Julliard <julliard@winehq.org>
This commit is contained in:
Nikolay Sivov 2016-03-14 20:30:54 +03:00 committed by Alexandre Julliard
parent 9a28e35064
commit cf218bca42
2 changed files with 55 additions and 52 deletions

View File

@ -4683,10 +4683,15 @@ static ULONG Context_AddRef(Context *This)
static ULONG Context_Release(Context *This) static ULONG Context_Release(Context *This)
{ {
ULONG refs = InterlockedDecrement(&This->refs); /* Context instance is initially created with CoGetContextToken() with refcount set to 0,
if (!refs) releasing context while refcount is at 0 destroys it. */
if (!This->refs)
{
HeapFree(GetProcessHeap(), 0, This); HeapFree(GetProcessHeap(), 0, This);
return refs; return 0;
}
return InterlockedDecrement(&This->refs);
} }
static HRESULT WINAPI Context_CTI_QueryInterface(IComThreadingInfo *iface, REFIID riid, LPVOID *ppv) static HRESULT WINAPI Context_CTI_QueryInterface(IComThreadingInfo *iface, REFIID riid, LPVOID *ppv)
@ -4924,39 +4929,19 @@ static const IObjContextVtbl Context_Object_Vtbl =
*/ */
HRESULT WINAPI CoGetObjectContext(REFIID riid, void **ppv) HRESULT WINAPI CoGetObjectContext(REFIID riid, void **ppv)
{ {
APARTMENT *apt = COM_CurrentApt(); IObjContext *context;
Context *context;
HRESULT hr; HRESULT hr;
TRACE("(%s, %p)\n", debugstr_guid(riid), ppv); TRACE("(%s, %p)\n", debugstr_guid(riid), ppv);
*ppv = NULL; *ppv = NULL;
if (!apt) hr = CoGetContextToken((ULONG_PTR*)&context);
{ if (FAILED(hr))
if (!(apt = apartment_find_multi_threaded())) return hr;
{
ERR("apartment not initialised\n");
return CO_E_NOTINITIALIZED;
}
apartment_release(apt);
}
context = HeapAlloc(GetProcessHeap(), 0, sizeof(*context)); return IObjContext_QueryInterface(context, riid, ppv);
if (!context)
return E_OUTOFMEMORY;
context->IComThreadingInfo_iface.lpVtbl = &Context_Threading_Vtbl;
context->IContextCallback_iface.lpVtbl = &Context_Callback_Vtbl;
context->IObjContext_iface.lpVtbl = &Context_Object_Vtbl;
context->refs = 1;
hr = IComThreadingInfo_QueryInterface(&context->IComThreadingInfo_iface, riid, ppv);
IComThreadingInfo_Release(&context->IComThreadingInfo_iface);
return hr;
} }
/*********************************************************************** /***********************************************************************
* CoGetContextToken [OLE32.@] * CoGetContextToken [OLE32.@]
*/ */
@ -4985,16 +4970,24 @@ HRESULT WINAPI CoGetContextToken( ULONG_PTR *token )
if (!info->context_token) if (!info->context_token)
{ {
HRESULT hr; Context *context;
IObjContext *ctx;
hr = CoGetObjectContext(&IID_IObjContext, (void **)&ctx); context = HeapAlloc(GetProcessHeap(), 0, sizeof(*context));
if (FAILED(hr)) return hr; if (!context)
info->context_token = ctx; return E_OUTOFMEMORY;
context->IComThreadingInfo_iface.lpVtbl = &Context_Threading_Vtbl;
context->IContextCallback_iface.lpVtbl = &Context_Callback_Vtbl;
context->IObjContext_iface.lpVtbl = &Context_Object_Vtbl;
/* Context token does not take a reference, it's always zero until
interface is explicitely requested with CoGetObjectContext(). */
context->refs = 0;
info->context_token = &context->IObjContext_iface;
} }
*token = (ULONG_PTR)info->context_token; *token = (ULONG_PTR)info->context_token;
TRACE("apt->context_token=%p\n", info->context_token); TRACE("context_token=%p\n", info->context_token);
return S_OK; return S_OK;
} }

View File

@ -1774,7 +1774,7 @@ static void test_CoGetObjectContext(void)
{ {
HRESULT hr; HRESULT hr;
ULONG refs; ULONG refs;
IComThreadingInfo *pComThreadingInfo; IComThreadingInfo *pComThreadingInfo, *threadinginfo2;
IContextCallback *pContextCallback; IContextCallback *pContextCallback;
IObjContext *pObjContext; IObjContext *pObjContext;
APTTYPE apttype; APTTYPE apttype;
@ -1786,7 +1786,7 @@ static void test_CoGetObjectContext(void)
if (!pCoGetObjectContext) if (!pCoGetObjectContext)
{ {
skip("CoGetObjectContext not present\n"); win_skip("CoGetObjectContext not present\n");
return; return;
} }
@ -1812,6 +1812,12 @@ static void test_CoGetObjectContext(void)
hr = pCoGetObjectContext(&IID_IComThreadingInfo, (void **)&pComThreadingInfo); hr = pCoGetObjectContext(&IID_IComThreadingInfo, (void **)&pComThreadingInfo);
ok(hr == S_OK, "Expected S_OK, got 0x%08x\n", hr); ok(hr == S_OK, "Expected S_OK, got 0x%08x\n", hr);
threadinginfo2 = NULL;
hr = pCoGetObjectContext(&IID_IComThreadingInfo, (void **)&threadinginfo2);
ok(hr == S_OK, "Expected S_OK, got 0x%08x\n", hr);
ok(pComThreadingInfo == threadinginfo2, "got different instance\n");
IComThreadingInfo_Release(threadinginfo2);
hr = IComThreadingInfo_GetCurrentLogicalThreadId(pComThreadingInfo, NULL); hr = IComThreadingInfo_GetCurrentLogicalThreadId(pComThreadingInfo, NULL);
ok(hr == E_INVALIDARG, "got 0x%08x\n", hr); ok(hr == E_INVALIDARG, "got 0x%08x\n", hr);
@ -1854,11 +1860,8 @@ static void test_CoGetObjectContext(void)
hr = pCoGetObjectContext(&IID_IContextCallback, (void **)&pContextCallback); hr = pCoGetObjectContext(&IID_IContextCallback, (void **)&pContextCallback);
ok_ole_success(hr, "CoGetObjectContext(ContextCallback)"); ok_ole_success(hr, "CoGetObjectContext(ContextCallback)");
if (hr == S_OK) refs = IContextCallback_Release(pContextCallback);
{ ok(refs == 0, "pContextCallback should have 0 refs instead of %d refs\n", refs);
refs = IContextCallback_Release(pContextCallback);
ok(refs == 0, "pContextCallback should have 0 refs instead of %d refs\n", refs);
}
CoUninitialize(); CoUninitialize();
@ -1881,11 +1884,8 @@ static void test_CoGetObjectContext(void)
hr = pCoGetObjectContext(&IID_IContextCallback, (void **)&pContextCallback); hr = pCoGetObjectContext(&IID_IContextCallback, (void **)&pContextCallback);
ok_ole_success(hr, "CoGetObjectContext(ContextCallback)"); ok_ole_success(hr, "CoGetObjectContext(ContextCallback)");
if (hr == S_OK) refs = IContextCallback_Release(pContextCallback);
{ ok(refs == 0, "pContextCallback should have 0 refs instead of %d refs\n", refs);
refs = IContextCallback_Release(pContextCallback);
ok(refs == 0, "pContextCallback should have 0 refs instead of %d refs\n", refs);
}
hr = pCoGetObjectContext(&IID_IObjContext, (void **)&pObjContext); hr = pCoGetObjectContext(&IID_IObjContext, (void **)&pObjContext);
ok_ole_success(hr, "CoGetObjectContext"); ok_ole_success(hr, "CoGetObjectContext");
@ -2007,7 +2007,7 @@ static void test_CoGetContextToken(void)
{ {
HRESULT hr; HRESULT hr;
ULONG refs; ULONG refs;
ULONG_PTR token; ULONG_PTR token, token2;
IObjContext *ctx; IObjContext *ctx;
struct info info; struct info info;
HANDLE thread; HANDLE thread;
@ -2042,6 +2042,11 @@ static void test_CoGetContextToken(void)
hr = pCoGetContextToken(&token); hr = pCoGetContextToken(&token);
ok(hr == S_OK, "Expected S_OK, got 0x%08x\n", hr); ok(hr == S_OK, "Expected S_OK, got 0x%08x\n", hr);
token2 = 0;
hr = pCoGetContextToken(&token2);
ok(hr == S_OK, "Expected S_OK, got 0x%08x\n", hr);
ok(token == token2, "got different token\n");
SetEvent(info.stop); SetEvent(info.stop);
ok( !WaitForSingleObject(thread, 10000), "wait timed out\n" ); ok( !WaitForSingleObject(thread, 10000), "wait timed out\n" );
@ -2063,18 +2068,23 @@ static void test_CoGetContextToken(void)
ok(hr == S_OK, "Expected S_OK, got 0x%08x\n", hr); ok(hr == S_OK, "Expected S_OK, got 0x%08x\n", hr);
ok(token, "Expected token != 0\n"); ok(token, "Expected token != 0\n");
token2 = 0;
hr = pCoGetContextToken(&token2);
ok(hr == S_OK, "Expected S_OK, got 0x%08x\n", hr);
ok(token2 == token, "got different token\n");
refs = IUnknown_AddRef((IUnknown *)token); refs = IUnknown_AddRef((IUnknown *)token);
todo_wine ok(refs == 1, "Expected 1, got %u\n", refs); ok(refs == 1, "Expected 1, got %u\n", refs);
hr = pCoGetObjectContext(&IID_IObjContext, (void **)&ctx); hr = pCoGetObjectContext(&IID_IObjContext, (void **)&ctx);
ok(hr == S_OK, "Expected S_OK, got 0x%08x\n", hr); ok(hr == S_OK, "Expected S_OK, got 0x%08x\n", hr);
todo_wine ok(ctx == (IObjContext *)token, "Expected interface pointers to be the same\n"); ok(ctx == (IObjContext *)token, "Expected interface pointers to be the same\n");
refs = IObjContext_AddRef(ctx); refs = IObjContext_AddRef(ctx);
todo_wine ok(refs == 3, "Expected 3, got %u\n", refs); ok(refs == 3, "Expected 3, got %u\n", refs);
refs = IObjContext_Release(ctx); refs = IObjContext_Release(ctx);
todo_wine ok(refs == 2, "Expected 2, got %u\n", refs); ok(refs == 2, "Expected 2, got %u\n", refs);
refs = IUnknown_Release((IUnknown *)token); refs = IUnknown_Release((IUnknown *)token);
ok(refs == 1, "Expected 1, got %u\n", refs); ok(refs == 1, "Expected 1, got %u\n", refs);
@ -2084,7 +2094,7 @@ static void test_CoGetContextToken(void)
hr = pCoGetContextToken(&token); hr = pCoGetContextToken(&token);
ok(hr == S_OK, "Expected S_OK, got 0x%08x\n", hr); ok(hr == S_OK, "Expected S_OK, got 0x%08x\n", hr);
ok(token, "Expected token != 0\n"); ok(token, "Expected token != 0\n");
todo_wine ok(ctx == (IObjContext *)token, "Expected interface pointers to be the same\n"); ok(ctx == (IObjContext *)token, "Expected interface pointers to be the same\n");
refs = IObjContext_AddRef(ctx); refs = IObjContext_AddRef(ctx);
ok(refs == 2, "Expected 1, got %u\n", refs); ok(refs == 2, "Expected 1, got %u\n", refs);