diff --git a/dlls/itss/itss.c b/dlls/itss/itss.c index f47a0e96e7c..1d236a59604 100644 --- a/dlls/itss/itss.c +++ b/dlls/itss/itss.c @@ -105,20 +105,31 @@ static ULONG WINAPI ITSSCF_Release(LPCLASSFACTORY iface) } -static HRESULT WINAPI ITSSCF_CreateInstance(LPCLASSFACTORY iface, LPUNKNOWN pOuter, - REFIID riid, LPVOID *ppobj) +static HRESULT WINAPI ITSSCF_CreateInstance(IClassFactory *iface, IUnknown *outer, + REFIID riid, void **ppv) { IClassFactoryImpl *This = impl_from_IClassFactory(iface); + IUnknown *unk; HRESULT hres; - LPUNKNOWN punk; - TRACE("(%p)->(%p,%s,%p)\n", This, pOuter, debugstr_guid(riid), ppobj); + TRACE("(%p)->(%p %s %p)\n", This, outer, debugstr_guid(riid), ppv); - *ppobj = NULL; - hres = This->pfnCreateInstance(pOuter, (LPVOID *) &punk); - if (SUCCEEDED(hres)) { - hres = IUnknown_QueryInterface(punk, riid, ppobj); - IUnknown_Release(punk); + if(outer && !IsEqualGUID(riid, &IID_IUnknown)) { + *ppv = NULL; + return CLASS_E_NOAGGREGATION; + } + + hres = This->pfnCreateInstance(outer, (void**)&unk); + if(FAILED(hres)) { + *ppv = NULL; + return hres; + } + + if(!IsEqualGUID(riid, &IID_IUnknown)) { + hres = IUnknown_QueryInterface(unk, riid, ppv); + IUnknown_Release(unk); + }else { + *ppv = unk; } return hres; } diff --git a/dlls/itss/protocol.c b/dlls/itss/protocol.c index 1463518c2cf..1cdb3650027 100644 --- a/dlls/itss/protocol.c +++ b/dlls/itss/protocol.c @@ -36,16 +36,23 @@ WINE_DEFAULT_DEBUG_CHANNEL(itss); typedef struct { + IUnknown IUnknown_inner; IInternetProtocol IInternetProtocol_iface; IInternetProtocolInfo IInternetProtocolInfo_iface; LONG ref; + IUnknown *outer; ULONG offset; struct chmFile *chm_file; struct chmUnitInfo chm_object; } ITSProtocol; +static inline ITSProtocol *impl_from_IUnknown(IUnknown *iface) +{ + return CONTAINING_RECORD(iface, ITSProtocol, IUnknown_inner); +} + static inline ITSProtocol *impl_from_IInternetProtocol(IInternetProtocol *iface) { return CONTAINING_RECORD(iface, ITSProtocol, IInternetProtocol_iface); @@ -65,14 +72,13 @@ static void release_chm(ITSProtocol *This) This->offset = 0; } -static HRESULT WINAPI ITSProtocol_QueryInterface(IInternetProtocol *iface, REFIID riid, void **ppv) +static HRESULT WINAPI ITSProtocol_QueryInterface(IUnknown *iface, REFIID riid, void **ppv) { - ITSProtocol *This = impl_from_IInternetProtocol(iface); + ITSProtocol *This = impl_from_IUnknown(iface); - *ppv = NULL; if(IsEqualGUID(&IID_IUnknown, riid)) { TRACE("(%p)->(IID_IUnknown %p)\n", This, ppv); - *ppv = &This->IInternetProtocol_iface; + *ppv = &This->IUnknown_inner; }else if(IsEqualGUID(&IID_IInternetProtocolRoot, riid)) { TRACE("(%p)->(IID_IInternetProtocolRoot %p)\n", This, ppv); *ppv = &This->IInternetProtocol_iface; @@ -82,28 +88,27 @@ static HRESULT WINAPI ITSProtocol_QueryInterface(IInternetProtocol *iface, REFII }else if(IsEqualGUID(&IID_IInternetProtocolInfo, riid)) { TRACE("(%p)->(IID_IInternetProtocolInfo %p)\n", This, ppv); *ppv = &This->IInternetProtocolInfo_iface; + }else { + *ppv = NULL; + WARN("not supported interface %s\n", debugstr_guid(riid)); + return E_NOINTERFACE; } - if(*ppv) { - IInternetProtocol_AddRef(iface); - return S_OK; - } - - WARN("not supported interface %s\n", debugstr_guid(riid)); - return E_NOINTERFACE; + IUnknown_AddRef((IUnknown*)*ppv); + return S_OK; } -static ULONG WINAPI ITSProtocol_AddRef(IInternetProtocol *iface) +static ULONG WINAPI ITSProtocol_AddRef(IUnknown *iface) { - ITSProtocol *This = impl_from_IInternetProtocol(iface); + ITSProtocol *This = impl_from_IUnknown(iface); LONG ref = InterlockedIncrement(&This->ref); TRACE("(%p) ref=%d\n", This, ref); return ref; } -static ULONG WINAPI ITSProtocol_Release(IInternetProtocol *iface) +static ULONG WINAPI ITSProtocol_Release(IUnknown *iface) { - ITSProtocol *This = impl_from_IInternetProtocol(iface); + ITSProtocol *This = impl_from_IUnknown(iface); LONG ref = InterlockedDecrement(&This->ref); TRACE("(%p) ref=%d\n", This, ref); @@ -118,6 +123,30 @@ static ULONG WINAPI ITSProtocol_Release(IInternetProtocol *iface) return ref; } +static const IUnknownVtbl ITSProtocolUnkVtbl = { + ITSProtocol_QueryInterface, + ITSProtocol_AddRef, + ITSProtocol_Release +}; + +static HRESULT WINAPI ITSInternetProtocol_QueryInterface(IInternetProtocol *iface, REFIID riid, void **ppv) +{ + ITSProtocol *This = impl_from_IInternetProtocol(iface); + return IUnknown_QueryInterface(This->outer, riid, ppv); +} + +static ULONG WINAPI ITSInternetProtocol_AddRef(IInternetProtocol *iface) +{ + ITSProtocol *This = impl_from_IInternetProtocol(iface); + return IUnknown_AddRef(This->outer); +} + +static ULONG WINAPI ITSInternetProtocol_Release(IInternetProtocol *iface) +{ + ITSProtocol *This = impl_from_IInternetProtocol(iface); + return IUnknown_Release(This->outer); +} + static LPCWSTR skip_schema(LPCWSTR url) { static const WCHAR its_schema[] = {'i','t','s',':'}; @@ -387,9 +416,9 @@ static HRESULT WINAPI ITSProtocol_UnlockRequest(IInternetProtocol *iface) } static const IInternetProtocolVtbl ITSProtocolVtbl = { - ITSProtocol_QueryInterface, - ITSProtocol_AddRef, - ITSProtocol_Release, + ITSInternetProtocol_QueryInterface, + ITSInternetProtocol_AddRef, + ITSInternetProtocol_Release, ITSProtocol_Start, ITSProtocol_Continue, ITSProtocol_Abort, @@ -520,21 +549,24 @@ static const IInternetProtocolInfoVtbl ITSProtocolInfoVtbl = { ITSProtocolInfo_QueryInfo }; -HRESULT ITSProtocol_create(IUnknown *pUnkOuter, LPVOID *ppobj) +HRESULT ITSProtocol_create(IUnknown *outer, void **ppv) { ITSProtocol *ret; - TRACE("(%p %p)\n", pUnkOuter, ppobj); + TRACE("(%p %p)\n", outer, ppv); ITSS_LockModule(); ret = HeapAlloc(GetProcessHeap(), HEAP_ZERO_MEMORY, sizeof(ITSProtocol)); + if(!ret) + return E_OUTOFMEMORY; + ret->IUnknown_inner.lpVtbl = &ITSProtocolUnkVtbl; ret->IInternetProtocol_iface.lpVtbl = &ITSProtocolVtbl; ret->IInternetProtocolInfo_iface.lpVtbl = &ITSProtocolInfoVtbl; ret->ref = 1; + ret->outer = outer ? outer : &ret->IUnknown_inner; - *ppobj = &ret->IInternetProtocol_iface; - + *ppv = &ret->IUnknown_inner; return S_OK; } diff --git a/dlls/itss/tests/protocol.c b/dlls/itss/tests/protocol.c index 663411f2e92..3028da43374 100644 --- a/dlls/itss/tests/protocol.c +++ b/dlls/itss/tests/protocol.c @@ -60,6 +60,7 @@ DEFINE_EXPECT(ReportProgress_CACHEFILENAMEAVAILABLE); DEFINE_EXPECT(ReportProgress_DIRECTBIND); DEFINE_EXPECT(ReportData); DEFINE_EXPECT(ReportResult); +DEFINE_EXPECT(outer_QI_test); static HRESULT expect_hrResult; static IInternetProtocol *read_protocol = NULL; @@ -660,6 +661,68 @@ static void delete_chm(void) ok(ret, "DeleteFileA failed: %d\n", GetLastError()); } +static const IID outer_test_iid = {0xabcabc00,0,0,{0,0,0,0,0,0,0,0x66}}; + +static HRESULT WINAPI outer_QueryInterface(IUnknown *iface, REFIID riid, void **ppv) +{ + if(IsEqualGUID(riid, &outer_test_iid)) { + CHECK_EXPECT(outer_QI_test); + *ppv = (IUnknown*)0xdeadbeef; + return S_OK; + } + ok(0, "unexpected call %s\n", wine_dbgstr_guid(riid)); + return E_NOINTERFACE; +} + +static ULONG WINAPI outer_AddRef(IUnknown *iface) +{ + return 2; +} + +static ULONG WINAPI outer_Release(IUnknown *iface) +{ + return 1; +} + +static const IUnknownVtbl outer_vtbl = { + outer_QueryInterface, + outer_AddRef, + outer_Release +}; + +static void test_com_aggregation(const CLSID *clsid) +{ + IUnknown outer = { &outer_vtbl }; + IClassFactory *class_factory; + IUnknown *unk, *unk2, *unk3; + HRESULT hres; + + hres = CoGetClassObject(clsid, CLSCTX_INPROC_SERVER, NULL, &IID_IClassFactory, (void**)&class_factory); + ok(hres == S_OK, "CoGetClassObject failed: %08x\n", hres); + + hres = IClassFactory_CreateInstance(class_factory, &outer, &IID_IUnknown, (void**)&unk); + ok(hres == S_OK, "CreateInstance returned: %08x\n", hres); + + hres = IUnknown_QueryInterface(unk, &IID_IInternetProtocol, (void**)&unk2); + ok(hres == S_OK, "Could not get IInternetProtocol iface: %08x\n", hres); + + SET_EXPECT(outer_QI_test); + hres = IUnknown_QueryInterface(unk2, &outer_test_iid, (void**)&unk3); + CHECK_CALLED(outer_QI_test); + ok(hres == S_OK, "Could not get IInternetProtocol iface: %08x\n", hres); + ok(unk3 == (IUnknown*)0xdeadbeef, "unexpected unk2\n"); + + IUnknown_Release(unk2); + IUnknown_Release(unk); + + unk = (void*)0xdeadbeef; + hres = IClassFactory_CreateInstance(class_factory, &outer, &IID_IInternetProtocol, (void**)&unk); + ok(hres == CLASS_E_NOAGGREGATION, "CreateInstance returned: %08x\n", hres); + ok(!unk, "unk = %p\n", unk); + + IClassFactory_Release(class_factory); +} + START_TEST(protocol) { OleInitialize(NULL); @@ -669,6 +732,7 @@ START_TEST(protocol) test_its_protocol(); test_mk_protocol(); + test_com_aggregation(&CLSID_ITSProtocol); delete_chm(); OleUninitialize();