diff --git a/dlls/d3drm/d3drm.c b/dlls/d3drm/d3drm.c index 7b74cfbc354..9be960713b4 100644 --- a/dlls/d3drm/d3drm.c +++ b/dlls/d3drm/d3drm.c @@ -55,7 +55,7 @@ struct d3drm IDirect3DRM IDirect3DRM_iface; IDirect3DRM2 IDirect3DRM2_iface; IDirect3DRM3 IDirect3DRM3_iface; - LONG ref; + LONG ref1, ref2, ref3, iface_count; }; static inline struct d3drm *impl_from_IDirect3DRM(IDirect3DRM *iface) @@ -73,6 +73,12 @@ static inline struct d3drm *impl_from_IDirect3DRM3(IDirect3DRM3 *iface) return CONTAINING_RECORD(iface, struct d3drm, IDirect3DRM3_iface); } +static void d3drm_destroy(struct d3drm *d3drm) +{ + HeapFree(GetProcessHeap(), 0, d3drm); + TRACE("d3drm object %p is being destroyed.\n", d3drm); +} + static HRESULT WINAPI d3drm1_QueryInterface(IDirect3DRM *iface, REFIID riid, void **out) { struct d3drm *d3drm = impl_from_IDirect3DRM(iface); @@ -106,22 +112,25 @@ static HRESULT WINAPI d3drm1_QueryInterface(IDirect3DRM *iface, REFIID riid, voi static ULONG WINAPI d3drm1_AddRef(IDirect3DRM *iface) { struct d3drm *d3drm = impl_from_IDirect3DRM(iface); - ULONG refcount = InterlockedIncrement(&d3drm->ref); + ULONG refcount = InterlockedIncrement(&d3drm->ref1); TRACE("%p increasing refcount to %u.\n", iface, refcount); + if (refcount == 1) + InterlockedIncrement(&d3drm->iface_count); + return refcount; } static ULONG WINAPI d3drm1_Release(IDirect3DRM *iface) { struct d3drm *d3drm = impl_from_IDirect3DRM(iface); - ULONG refcount = InterlockedDecrement(&d3drm->ref); + ULONG refcount = InterlockedDecrement(&d3drm->ref1); TRACE("%p decreasing refcount to %u.\n", iface, refcount); - if (!refcount) - HeapFree(GetProcessHeap(), 0, d3drm); + if (!refcount && !InterlockedDecrement(&d3drm->iface_count)) + d3drm_destroy(d3drm); return refcount; } @@ -455,15 +464,27 @@ static HRESULT WINAPI d3drm2_QueryInterface(IDirect3DRM2 *iface, REFIID riid, vo static ULONG WINAPI d3drm2_AddRef(IDirect3DRM2 *iface) { struct d3drm *d3drm = impl_from_IDirect3DRM2(iface); + ULONG refcount = InterlockedIncrement(&d3drm->ref2); - return d3drm1_AddRef(&d3drm->IDirect3DRM_iface); + TRACE("%p increasing refcount to %u.\n", iface, refcount); + + if (refcount == 1) + InterlockedIncrement(&d3drm->iface_count); + + return refcount; } static ULONG WINAPI d3drm2_Release(IDirect3DRM2 *iface) { struct d3drm *d3drm = impl_from_IDirect3DRM2(iface); + ULONG refcount = InterlockedDecrement(&d3drm->ref2); - return d3drm1_Release(&d3drm->IDirect3DRM_iface); + TRACE("%p decreasing refcount to %u.\n", iface, refcount); + + if (!refcount && !InterlockedDecrement(&d3drm->iface_count)) + d3drm_destroy(d3drm); + + return refcount; } static HRESULT WINAPI d3drm2_CreateObject(IDirect3DRM2 *iface, @@ -804,15 +825,27 @@ static HRESULT WINAPI d3drm3_QueryInterface(IDirect3DRM3 *iface, REFIID riid, vo static ULONG WINAPI d3drm3_AddRef(IDirect3DRM3 *iface) { struct d3drm *d3drm = impl_from_IDirect3DRM3(iface); + ULONG refcount = InterlockedIncrement(&d3drm->ref3); - return d3drm1_AddRef(&d3drm->IDirect3DRM_iface); + TRACE("%p increasing refcount to %u.\n", iface, refcount); + + if (refcount == 1) + InterlockedIncrement(&d3drm->iface_count); + + return refcount; } static ULONG WINAPI d3drm3_Release(IDirect3DRM3 *iface) { struct d3drm *d3drm = impl_from_IDirect3DRM3(iface); + ULONG refcount = InterlockedDecrement(&d3drm->ref3); - return d3drm1_Release(&d3drm->IDirect3DRM_iface); + TRACE("%p decreasing refcount to %u.\n", iface, refcount); + + if (!refcount && !InterlockedDecrement(&d3drm->iface_count)) + d3drm_destroy(d3drm); + + return refcount; } static HRESULT WINAPI d3drm3_CreateObject(IDirect3DRM3 *iface, @@ -1488,7 +1521,8 @@ HRESULT WINAPI Direct3DRMCreate(IDirect3DRM **d3drm) object->IDirect3DRM_iface.lpVtbl = &d3drm1_vtbl; object->IDirect3DRM2_iface.lpVtbl = &d3drm2_vtbl; object->IDirect3DRM3_iface.lpVtbl = &d3drm3_vtbl; - object->ref = 1; + object->ref1 = 1; + object->iface_count = 1; *d3drm = &object->IDirect3DRM_iface; diff --git a/dlls/d3drm/tests/d3drm.c b/dlls/d3drm/tests/d3drm.c index 977c9c0eb12..61f2cab18ea 100644 --- a/dlls/d3drm/tests/d3drm.c +++ b/dlls/d3drm/tests/d3drm.c @@ -1761,8 +1761,8 @@ static void test_d3drm_qi(void) { static const struct qi_test tests[] = { - { &IID_IDirect3DRM3, &IID_IDirect3DRM3, S_OK, TRUE }, - { &IID_IDirect3DRM2, &IID_IDirect3DRM2, S_OK, TRUE }, + { &IID_IDirect3DRM3, &IID_IDirect3DRM3, S_OK, FALSE }, + { &IID_IDirect3DRM2, &IID_IDirect3DRM2, S_OK, FALSE }, { &IID_IDirect3DRM, &IID_IDirect3DRM, S_OK, FALSE }, { &IID_IDirect3DRMDevice, NULL, CLASS_E_CLASSNOTAVAILABLE, FALSE }, { &IID_IDirect3DRMObject, NULL, CLASS_E_CLASSNOTAVAILABLE, FALSE },