diff --git a/dlls/mshtml/protocol.c b/dlls/mshtml/protocol.c index 41b13190e21..739234292b8 100644 --- a/dlls/mshtml/protocol.c +++ b/dlls/mshtml/protocol.c @@ -41,6 +41,7 @@ WINE_DEFAULT_DEBUG_CHANNEL(mshtml); #define PROTOCOLINFO(x) ((IInternetProtocolInfo*) &(x)->lpInternetProtocolInfoVtbl) #define CLASSFACTORY(x) ((IClassFactory*) &(x)->lpClassFactoryVtbl) +#define PROTOCOL(x) ((IInternetProtocol*) &(x)->lpInternetProtocolVtbl) typedef struct { const IInternetProtocolInfoVtbl *lpInternetProtocolInfoVtbl; @@ -140,21 +141,27 @@ typedef struct { BYTE *data; ULONG data_len; ULONG cur; + + IUnknown *pUnkOuter; } AboutProtocol; static HRESULT WINAPI AboutProtocol_QueryInterface(IInternetProtocol *iface, REFIID riid, void **ppv) { + AboutProtocol *This = (AboutProtocol*)iface; + *ppv = NULL; if(IsEqualGUID(&IID_IUnknown, riid)) { TRACE("(%p)->(IID_IUnknown %p)\n", iface, ppv); - *ppv = iface; + if(This->pUnkOuter) + return IUnknown_QueryInterface(This->pUnkOuter, riid, ppv); + *ppv = PROTOCOL(This); }else if(IsEqualGUID(&IID_IInternetProtocolRoot, riid)) { TRACE("(%p)->(IID_IInternetProtocolRoot %p)\n", iface, ppv); - *ppv = iface; + *ppv = PROTOCOL(This); }else if(IsEqualGUID(&IID_IInternetProtocol, riid)) { TRACE("(%p)->(IID_IInternetProtocol %p)\n", iface, ppv); - *ppv = iface; + *ppv = PROTOCOL(This); }else if(IsEqualGUID(&IID_IServiceProvider, riid)) { FIXME("IServiceProvider is not implemented\n"); return E_NOINTERFACE; @@ -174,7 +181,7 @@ static ULONG WINAPI AboutProtocol_AddRef(IInternetProtocol *iface) AboutProtocol *This = (AboutProtocol*)iface; ULONG ref = InterlockedIncrement(&This->ref); TRACE("(%p) ref=%ld\n", iface, ref); - return ref; + return This->pUnkOuter ? IUnknown_AddRef(This->pUnkOuter) : ref; } static ULONG WINAPI AboutProtocol_Release(IInternetProtocol *iface) @@ -190,7 +197,7 @@ static ULONG WINAPI AboutProtocol_Release(IInternetProtocol *iface) UNLOCK_MODULE(); } - return ref; + return This->pUnkOuter ? IUnknown_Release(This->pUnkOuter) : ref; } static HRESULT WINAPI AboutProtocol_Start(IInternetProtocol *iface, LPCWSTR szUrl, @@ -352,9 +359,9 @@ static HRESULT WINAPI AboutProtocolFactory_CreateInstance(IClassFactory *iface, REFIID riid, void **ppv) { AboutProtocol *ret; - HRESULT hres; + HRESULT hres = S_OK; - TRACE("(%p)->(%s %p)\n", iface, debugstr_guid(riid), ppv); + TRACE("(%p)->(%p %s %p)\n", iface, pUnkOuter, debugstr_guid(riid), ppv); ret = HeapAlloc(GetProcessHeap(), 0, sizeof(AboutProtocol)); ret->lpInternetProtocolVtbl = &AboutProtocolVtbl; @@ -363,8 +370,17 @@ static HRESULT WINAPI AboutProtocolFactory_CreateInstance(IClassFactory *iface, ret->data = NULL; ret->data_len = 0; ret->cur = 0; + ret->pUnkOuter = pUnkOuter; - hres = IUnknown_QueryInterface((IUnknown*)ret, riid, ppv); + if(pUnkOuter) { + ret->ref = 1; + if(IsEqualGUID(&IID_IUnknown, riid)) + *ppv = PROTOCOL(ret); + else + hres = E_INVALIDARG; + }else { + hres = IInternetProtocol_QueryInterface(PROTOCOL(ret), riid, ppv); + } if(SUCCEEDED(hres)) LOCK_MODULE(); @@ -683,7 +699,7 @@ static HRESULT WINAPI ResProtocolFactory_CreateInstance(IClassFactory *iface, IU ResProtocol *ret; HRESULT hres; - TRACE("(%p)->(%s %p)\n", iface, debugstr_guid(riid), ppv); + TRACE("(%p)->(%p %s %p)\n", iface, pUnkOuter, debugstr_guid(riid), ppv); ret = HeapAlloc(GetProcessHeap(), 0, sizeof(ResProtocol)); ret->lpInternetProtocolVtbl = &ResProtocolVtbl;