From 74989b6e4166a4f171672375c2e19092e8068764 Mon Sep 17 00:00:00 2001 From: Jacek Caban Date: Tue, 8 May 2018 18:18:27 +0200 Subject: [PATCH] urlmon: Added support for COM aggregation of file protocol handler. Signed-off-by: Jacek Caban Signed-off-by: Alexandre Julliard --- dlls/urlmon/file.c | 59 ++++++++++++++++++++++++++------ dlls/urlmon/tests/protocol.c | 65 ++++++++++++++++++++++++++++++++++++ dlls/urlmon/urlmon_main.c | 28 +++++++++++----- 3 files changed, 133 insertions(+), 19 deletions(-) diff --git a/dlls/urlmon/file.c b/dlls/urlmon/file.c index bba7be1079d..9e7f3cf184e 100644 --- a/dlls/urlmon/file.c +++ b/dlls/urlmon/file.c @@ -25,9 +25,12 @@ WINE_DEFAULT_DEBUG_CHANNEL(urlmon); typedef struct { + IUnknown IUnknown_outer; IInternetProtocolEx IInternetProtocolEx_iface; IInternetPriority IInternetPriority_iface; + IUnknown *outer; + HANDLE file; ULONG size; LONG priority; @@ -35,6 +38,11 @@ typedef struct { LONG ref; } FileProtocol; +static inline FileProtocol *impl_from_IUnknown(IUnknown *iface) +{ + return CONTAINING_RECORD(iface, FileProtocol, IUnknown_outer); +} + static inline FileProtocol *impl_from_IInternetProtocolEx(IInternetProtocolEx *iface) { return CONTAINING_RECORD(iface, FileProtocol, IInternetProtocolEx_iface); @@ -45,14 +53,14 @@ static inline FileProtocol *impl_from_IInternetPriority(IInternetPriority *iface return CONTAINING_RECORD(iface, FileProtocol, IInternetPriority_iface); } -static HRESULT WINAPI FileProtocol_QueryInterface(IInternetProtocolEx *iface, REFIID riid, void **ppv) +static HRESULT WINAPI FileProtocolUnk_QueryInterface(IUnknown *iface, REFIID riid, void **ppv) { - FileProtocol *This = impl_from_IInternetProtocolEx(iface); + FileProtocol *This = impl_from_IUnknown(iface); *ppv = NULL; if(IsEqualGUID(&IID_IUnknown, riid)) { TRACE("(%p)->(IID_IUnknown %p)\n", This, ppv); - *ppv = &This->IInternetProtocolEx_iface; + *ppv = &This->IUnknown_outer; }else if(IsEqualGUID(&IID_IInternetProtocolRoot, riid)) { TRACE("(%p)->(IID_IInternetProtocolRoot %p)\n", This, ppv); *ppv = &This->IInternetProtocolEx_iface; @@ -68,7 +76,7 @@ static HRESULT WINAPI FileProtocol_QueryInterface(IInternetProtocolEx *iface, RE } if(*ppv) { - IInternetProtocolEx_AddRef(iface); + IUnknown_AddRef((IUnknown*)*ppv); return S_OK; } @@ -76,17 +84,17 @@ static HRESULT WINAPI FileProtocol_QueryInterface(IInternetProtocolEx *iface, RE return E_NOINTERFACE; } -static ULONG WINAPI FileProtocol_AddRef(IInternetProtocolEx *iface) +static ULONG WINAPI FileProtocolUnk_AddRef(IUnknown *iface) { - FileProtocol *This = impl_from_IInternetProtocolEx(iface); + FileProtocol *This = impl_from_IUnknown(iface); LONG ref = InterlockedIncrement(&This->ref); TRACE("(%p) ref=%d\n", This, ref); return ref; } -static ULONG WINAPI FileProtocol_Release(IInternetProtocolEx *iface) +static ULONG WINAPI FileProtocolUnk_Release(IUnknown *iface) { - FileProtocol *This = impl_from_IInternetProtocolEx(iface); + FileProtocol *This = impl_from_IUnknown(iface); LONG ref = InterlockedDecrement(&This->ref); TRACE("(%p) ref=%d\n", This, ref); @@ -102,6 +110,33 @@ static ULONG WINAPI FileProtocol_Release(IInternetProtocolEx *iface) return ref; } +static const IUnknownVtbl FileProtocolUnkVtbl = { + FileProtocolUnk_QueryInterface, + FileProtocolUnk_AddRef, + FileProtocolUnk_Release +}; + +static HRESULT WINAPI FileProtocol_QueryInterface(IInternetProtocolEx *iface, REFIID riid, void **ppv) +{ + FileProtocol *This = impl_from_IInternetProtocolEx(iface); + TRACE("(%p)->(%s %p)\n", This, debugstr_guid(riid), ppv); + return IUnknown_QueryInterface(This->outer, riid, ppv); +} + +static ULONG WINAPI FileProtocol_AddRef(IInternetProtocolEx *iface) +{ + FileProtocol *This = impl_from_IInternetProtocolEx(iface); + TRACE("(%p)\n", This); + return IUnknown_AddRef(This->outer); +} + +static ULONG WINAPI FileProtocol_Release(IInternetProtocolEx *iface) +{ + FileProtocol *This = impl_from_IInternetProtocolEx(iface); + TRACE("(%p)\n", This); + return IUnknown_Release(This->outer); +} + static HRESULT WINAPI FileProtocol_Start(IInternetProtocolEx *iface, LPCWSTR szUrl, IInternetProtocolSink *pOIProtSink, IInternetBindInfo *pOIBindInfo, DWORD grfPI, HANDLE_PTR dwReserved) @@ -383,22 +418,24 @@ static const IInternetPriorityVtbl FilePriorityVtbl = { FilePriority_GetPriority }; -HRESULT FileProtocol_Construct(IUnknown *pUnkOuter, LPVOID *ppobj) +HRESULT FileProtocol_Construct(IUnknown *outer, LPVOID *ppobj) { FileProtocol *ret; - TRACE("(%p %p)\n", pUnkOuter, ppobj); + TRACE("(%p %p)\n", outer, ppobj); URLMON_LockModule(); ret = heap_alloc(sizeof(FileProtocol)); + ret->IUnknown_outer.lpVtbl = &FileProtocolUnkVtbl; ret->IInternetProtocolEx_iface.lpVtbl = &FileProtocolExVtbl; ret->IInternetPriority_iface.lpVtbl = &FilePriorityVtbl; ret->file = INVALID_HANDLE_VALUE; ret->priority = 0; ret->ref = 1; + ret->outer = outer ? outer : (IUnknown*)&ret->IUnknown_outer; - *ppobj = &ret->IInternetProtocolEx_iface; + *ppobj = &ret->IUnknown_outer; return S_OK; } diff --git a/dlls/urlmon/tests/protocol.c b/dlls/urlmon/tests/protocol.c index c153119ce43..f3710fca835 100644 --- a/dlls/urlmon/tests/protocol.c +++ b/dlls/urlmon/tests/protocol.c @@ -125,6 +125,7 @@ DEFINE_EXPECT(MimeFilter_Continue); DEFINE_EXPECT(Stream_Seek); DEFINE_EXPECT(Stream_Read); DEFINE_EXPECT(Redirect); +DEFINE_EXPECT(outer_QI_test); static const WCHAR wszIndexHtml[] = {'i','n','d','e','x','.','h','t','m','l',0}; static const WCHAR index_url[] = @@ -3964,6 +3965,68 @@ static void test_binding(int prot, DWORD grf_pi, DWORD test_flags) IInternetSession_Release(session); } +static const IID outer_test_iid = {0xabcabc00,0,0,{0,0,0,0,0,0,0,0x66}}; + +static HRESULT WINAPI outer_QueryInterface(IUnknown *iface, REFIID riid, void **ppv) +{ + if(IsEqualGUID(riid, &outer_test_iid)) { + CHECK_EXPECT(outer_QI_test); + *ppv = (IUnknown*)0xdeadbeef; + return S_OK; + } + ok(0, "unexpected call %s\n", wine_dbgstr_guid(riid)); + return E_NOINTERFACE; +} + +static ULONG WINAPI outer_AddRef(IUnknown *iface) +{ + return 2; +} + +static ULONG WINAPI outer_Release(IUnknown *iface) +{ + return 1; +} + +static const IUnknownVtbl outer_vtbl = { + outer_QueryInterface, + outer_AddRef, + outer_Release +}; + +static void test_com_aggregation(const CLSID *clsid) +{ + IUnknown outer = { &outer_vtbl }; + IClassFactory *class_factory; + IUnknown *unk, *unk2, *unk3; + HRESULT hres; + + hres = CoGetClassObject(clsid, CLSCTX_INPROC_SERVER, NULL, &IID_IClassFactory, (void**)&class_factory); + ok(hres == S_OK, "CoGetClassObject failed: %08x\n", hres); + + hres = IClassFactory_CreateInstance(class_factory, &outer, &IID_IUnknown, (void**)&unk); + ok(hres == S_OK, "CreateInstance returned: %08x\n", hres); + + hres = IUnknown_QueryInterface(unk, &IID_IInternetProtocol, (void**)&unk2); + ok(hres == S_OK, "Could not get IDispatch iface: %08x\n", hres); + + SET_EXPECT(outer_QI_test); + hres = IUnknown_QueryInterface(unk2, &outer_test_iid, (void**)&unk3); + CHECK_CALLED(outer_QI_test); + ok(hres == S_OK, "Could not get IInternetProtocol iface: %08x\n", hres); + ok(unk3 == (IUnknown*)0xdeadbeef, "unexpected unk2\n"); + + IUnknown_Release(unk2); + IUnknown_Release(unk); + + unk = (void*)0xdeadbeef; + hres = IClassFactory_CreateInstance(class_factory, &outer, &IID_IInternetProtocol, (void**)&unk); + ok(hres == CLASS_E_NOAGGREGATION, "CreateInstance returned: %08x\n", hres); + ok(!unk, "unk = %p\n", unk); + + IClassFactory_Release(class_factory); +} + START_TEST(protocol) { HMODULE hurlmon; @@ -4037,5 +4100,7 @@ START_TEST(protocol) CloseHandle(event_continue); CloseHandle(event_continue_done); + test_com_aggregation(&CLSID_FileProtocol); + OleUninitialize(); } diff --git a/dlls/urlmon/urlmon_main.c b/dlls/urlmon/urlmon_main.c index c4f1cec5fa0..5968a6ce55f 100644 --- a/dlls/urlmon/urlmon_main.c +++ b/dlls/urlmon/urlmon_main.c @@ -303,19 +303,31 @@ static ULONG WINAPI CF_Release(IClassFactory *iface) } -static HRESULT WINAPI CF_CreateInstance(IClassFactory *iface, IUnknown *pOuter, - REFIID riid, LPVOID *ppobj) +static HRESULT WINAPI CF_CreateInstance(IClassFactory *iface, IUnknown *outer, + REFIID riid, void **ppv) { ClassFactory *This = impl_from_IClassFactory(iface); + IUnknown *unk; HRESULT hres; - LPUNKNOWN punk; - TRACE("(%p)->(%p,%s,%p)\n",This,pOuter,debugstr_guid(riid),ppobj); + TRACE("(%p)->(%p %s %p)\n", This, outer, debugstr_guid(riid), ppv); - *ppobj = NULL; - if(SUCCEEDED(hres = This->pfnCreateInstance(pOuter, (LPVOID *) &punk))) { - hres = IUnknown_QueryInterface(punk, riid, ppobj); - IUnknown_Release(punk); + if(outer && !IsEqualGUID(riid, &IID_IUnknown)) { + *ppv = NULL; + return CLASS_E_NOAGGREGATION; + } + + hres = This->pfnCreateInstance(outer, (void**)&unk); + if(FAILED(hres)) { + *ppv = NULL; + return hres; + } + + if(!IsEqualGUID(riid, &IID_IUnknown)) { + hres = IUnknown_QueryInterface(unk, riid, ppv); + IUnknown_Release(unk); + }else { + *ppv = unk; } return hres; }