diff --git a/dlls/inetcomm/inetcomm_main.c b/dlls/inetcomm/inetcomm_main.c index 4cb55226dc0..75436286631 100644 --- a/dlls/inetcomm/inetcomm_main.c +++ b/dlls/inetcomm/inetcomm_main.c @@ -117,13 +117,20 @@ static HRESULT WINAPI cf_CreateInstance( IClassFactory *iface, LPUNKNOWN pOuter, *ppobj = NULL; + if (pOuter && !IsEqualGUID(&IID_IUnknown, riid)) + return CLASS_E_NOAGGREGATION; + r = This->create_object( pOuter, (LPVOID*) &punk ); if (FAILED(r)) return r; + if (IsEqualGUID(&IID_IUnknown, riid)) { + *ppobj = punk; + return S_OK; + } + r = IUnknown_QueryInterface( punk, riid, ppobj ); IUnknown_Release( punk ); - return r; } diff --git a/dlls/inetcomm/protocol.c b/dlls/inetcomm/protocol.c index d560b90b092..f6600b9774b 100644 --- a/dlls/inetcomm/protocol.c +++ b/dlls/inetcomm/protocol.c @@ -28,31 +28,34 @@ WINE_DEFAULT_DEBUG_CHANNEL(inetcomm); typedef struct { + IUnknown IUnknown_inner; IInternetProtocol IInternetProtocol_iface; IInternetProtocolInfo IInternetProtocolInfo_iface; + LONG ref; + IUnknown *outer_unk; } MimeHtmlProtocol; -static inline MimeHtmlProtocol *impl_from_IInternetProtocol(IInternetProtocol *iface) +static inline MimeHtmlProtocol *impl_from_IUnknown(IUnknown *iface) { - return CONTAINING_RECORD(iface, MimeHtmlProtocol, IInternetProtocol_iface); + return CONTAINING_RECORD(iface, MimeHtmlProtocol, IUnknown_inner); } -static HRESULT WINAPI MimeHtmlProtocol_QueryInterface(IInternetProtocol *iface, REFIID riid, void **ppv) +static HRESULT WINAPI MimeHtmlProtocol_QueryInterface(IUnknown *iface, REFIID riid, void **ppv) { - MimeHtmlProtocol *This = impl_from_IInternetProtocol(iface); + MimeHtmlProtocol *This = impl_from_IUnknown(iface); if(IsEqualGUID(&IID_IUnknown, riid)) { - TRACE("(%p)->(IID_IUnknown %p)\n", iface, ppv); + TRACE("(%p)->(IID_IUnknown %p)\n", This, ppv); *ppv = &This->IInternetProtocol_iface; }else if(IsEqualGUID(&IID_IInternetProtocolRoot, riid)) { - TRACE("(%p)->(IID_IInternetProtocolRoot %p)\n", iface, ppv); + TRACE("(%p)->(IID_IInternetProtocolRoot %p)\n", This, ppv); *ppv = &This->IInternetProtocol_iface; }else if(IsEqualGUID(&IID_IInternetProtocol, riid)) { - TRACE("(%p)->(IID_IInternetProtocol %p)\n", iface, ppv); + TRACE("(%p)->(IID_IInternetProtocol %p)\n", This, ppv); *ppv = &This->IInternetProtocol_iface; }else if(IsEqualGUID(&IID_IInternetProtocolInfo, riid)) { - TRACE("(%p)->(IID_IInternetProtocolInfo %p)\n", iface, ppv); + TRACE("(%p)->(IID_IInternetProtocolInfo %p)\n", This, ppv); *ppv = &This->IInternetProtocolInfo_iface; }else { FIXME("unknown interface %s\n", debugstr_guid(riid)); @@ -64,9 +67,9 @@ static HRESULT WINAPI MimeHtmlProtocol_QueryInterface(IInternetProtocol *iface, return S_OK; } -static ULONG WINAPI MimeHtmlProtocol_AddRef(IInternetProtocol *iface) +static ULONG WINAPI MimeHtmlProtocol_AddRef(IUnknown *iface) { - MimeHtmlProtocol *This = impl_from_IInternetProtocol(iface); + MimeHtmlProtocol *This = impl_from_IUnknown(iface); ULONG ref = InterlockedIncrement(&This->ref); TRACE("(%p) ref=%d\n", This, ref); @@ -74,9 +77,9 @@ static ULONG WINAPI MimeHtmlProtocol_AddRef(IInternetProtocol *iface) return ref; } -static ULONG WINAPI MimeHtmlProtocol_Release(IInternetProtocol *iface) +static ULONG WINAPI MimeHtmlProtocol_Release(IUnknown *iface) { - MimeHtmlProtocol *This = impl_from_IInternetProtocol(iface); + MimeHtmlProtocol *This = impl_from_IUnknown(iface); ULONG ref = InterlockedDecrement(&This->ref); TRACE("(%p) ref=%x\n", This, ref); @@ -87,6 +90,35 @@ static ULONG WINAPI MimeHtmlProtocol_Release(IInternetProtocol *iface) return ref; } +static const IUnknownVtbl MimeHtmlProtocolInnerVtbl = { + MimeHtmlProtocol_QueryInterface, + MimeHtmlProtocol_AddRef, + MimeHtmlProtocol_Release +}; + +static inline MimeHtmlProtocol *impl_from_IInternetProtocol(IInternetProtocol *iface) +{ + return CONTAINING_RECORD(iface, MimeHtmlProtocol, IInternetProtocol_iface); +} + +static HRESULT WINAPI InternetProtocol_QueryInterface(IInternetProtocol *iface, REFIID riid, void **ppv) +{ + MimeHtmlProtocol *This = impl_from_IInternetProtocol(iface); + return IUnknown_QueryInterface(This->outer_unk, riid, ppv); +} + +static ULONG WINAPI InternetProtocol_AddRef(IInternetProtocol *iface) +{ + MimeHtmlProtocol *This = impl_from_IInternetProtocol(iface); + return IUnknown_AddRef(This->outer_unk); +} + +static ULONG WINAPI InternetProtocol_Release(IInternetProtocol *iface) +{ + MimeHtmlProtocol *This = impl_from_IInternetProtocol(iface); + return IUnknown_Release(This->outer_unk); +} + static HRESULT WINAPI MimeHtmlProtocol_Start(IInternetProtocol *iface, const WCHAR *szUrl, IInternetProtocolSink* pOIProtSink, IInternetBindInfo* pOIBindInfo, DWORD grfPI, HANDLE_PTR dwReserved) @@ -162,9 +194,9 @@ static HRESULT WINAPI MimeHtmlProtocol_UnlockRequest(IInternetProtocol *iface) } static const IInternetProtocolVtbl MimeHtmlProtocolVtbl = { - MimeHtmlProtocol_QueryInterface, - MimeHtmlProtocol_AddRef, - MimeHtmlProtocol_Release, + InternetProtocol_QueryInterface, + InternetProtocol_AddRef, + InternetProtocol_Release, MimeHtmlProtocol_Start, MimeHtmlProtocol_Continue, MimeHtmlProtocol_Abort, @@ -185,19 +217,19 @@ static inline MimeHtmlProtocol *impl_from_IInternetProtocolInfo(IInternetProtoco static HRESULT WINAPI MimeHtmlProtocolInfo_QueryInterface(IInternetProtocolInfo *iface, REFIID riid, void **ppv) { MimeHtmlProtocol *This = impl_from_IInternetProtocolInfo(iface); - return IInternetProtocol_QueryInterface(&This->IInternetProtocol_iface, riid, ppv); + return IUnknown_QueryInterface(This->outer_unk, riid, ppv); } static ULONG WINAPI MimeHtmlProtocolInfo_AddRef(IInternetProtocolInfo *iface) { MimeHtmlProtocol *This = impl_from_IInternetProtocolInfo(iface); - return IInternetProtocol_AddRef(&This->IInternetProtocol_iface); + return IUnknown_AddRef(This->outer_unk); } static ULONG WINAPI MimeHtmlProtocolInfo_Release(IInternetProtocolInfo *iface) { MimeHtmlProtocol *This = impl_from_IInternetProtocolInfo(iface); - return IInternetProtocol_Release(&This->IInternetProtocol_iface); + return IUnknown_Release(This->outer_unk); } static HRESULT WINAPI MimeHtmlProtocolInfo_ParseUrl(IInternetProtocolInfo *iface, LPCWSTR pwzUrl, @@ -253,17 +285,16 @@ HRESULT MimeHtmlProtocol_create(IUnknown *outer, void **obj) { MimeHtmlProtocol *protocol; - if(outer) - FIXME("outer not supported\n"); - protocol = heap_alloc(sizeof(*protocol)); if(!protocol) return E_OUTOFMEMORY; + protocol->IUnknown_inner.lpVtbl = &MimeHtmlProtocolInnerVtbl; protocol->IInternetProtocol_iface.lpVtbl = &MimeHtmlProtocolVtbl; protocol->IInternetProtocolInfo_iface.lpVtbl = &MimeHtmlProtocolInfoVtbl; protocol->ref = 1; + protocol->outer_unk = outer ? outer : &protocol->IUnknown_inner; - *obj = &protocol->IInternetProtocol_iface; + *obj = &protocol->IUnknown_inner; return S_OK; } diff --git a/dlls/inetcomm/tests/mimeole.c b/dlls/inetcomm/tests/mimeole.c index 7d68a253d8d..924e0e32847 100644 --- a/dlls/inetcomm/tests/mimeole.c +++ b/dlls/inetcomm/tests/mimeole.c @@ -772,8 +772,32 @@ static void test_mhtml_protocol_info(void) IInternetProtocolInfo_Release(protocol_info); } +static HRESULT WINAPI outer_QueryInterface(IUnknown *iface, REFIID riid, void **ppv) +{ + ok(0, "unexpected call\n"); + return E_NOINTERFACE; +} + +static ULONG WINAPI outer_AddRef(IUnknown *iface) +{ + return 2; +} + +static ULONG WINAPI outer_Release(IUnknown *iface) +{ + return 1; +} + +static const IUnknownVtbl outer_vtbl = { + outer_QueryInterface, + outer_AddRef, + outer_Release +}; + static void test_mhtml_protocol(void) { + IUnknown outer = { &outer_vtbl }; + IClassFactory *class_factory; IUnknown *unk, *unk2; HRESULT hres; @@ -781,15 +805,25 @@ static void test_mhtml_protocol(void) hres = CoGetClassObject(&CLSID_IMimeHtmlProtocol, CLSCTX_INPROC_SERVER, NULL, &IID_IUnknown, (void**)&unk); ok(hres == S_OK, "CoGetClassObject failed: %08x\n", hres); - hres = IUnknown_QueryInterface(unk, &IID_IClassFactory, (void**)&unk2); - ok(hres == S_OK, "Could not get IClassFactory iface: %08x\n", hres); - IUnknown_Release(unk2); - hres = IUnknown_QueryInterface(unk, &IID_IInternetProtocolInfo, (void**)&unk2); ok(hres == E_NOINTERFACE, "IInternetProtocolInfo supported\n"); + hres = IUnknown_QueryInterface(unk, &IID_IClassFactory, (void**)&class_factory); + ok(hres == S_OK, "Could not get IClassFactory iface: %08x\n", hres); IUnknown_Release(unk); + hres = IClassFactory_CreateInstance(class_factory, &outer, &IID_IUnknown, (void**)&unk); + ok(hres == S_OK, "CreateInstance returned: %08x\n", hres); + hres = IUnknown_QueryInterface(unk, &IID_IInternetProtocol, (void**)&unk2); + ok(hres == S_OK, "Coult not get IInternetProtocol iface: %08x\n", hres); + IUnknown_Release(unk2); + IUnknown_Release(unk); + + hres = IClassFactory_CreateInstance(class_factory, (IUnknown*)0xdeadbeef, &IID_IInternetProtocol, (void**)&unk2); + ok(hres == CLASS_E_NOAGGREGATION, "CreateInstance returned: %08x\n", hres); + + IClassFactory_Release(class_factory); + test_mhtml_protocol_info(); }