itss: Support COM aggregation in its protocol handler.

Signed-off-by: Jacek Caban <jacek@codeweavers.com>
Signed-off-by: Alexandre Julliard <julliard@winehq.org>
This commit is contained in:
Jacek Caban 2018-05-22 11:31:59 +02:00 committed by Alexandre Julliard
parent 7305e5fd8c
commit cd5570d9ef
3 changed files with 138 additions and 31 deletions

View File

@ -105,20 +105,31 @@ static ULONG WINAPI ITSSCF_Release(LPCLASSFACTORY iface)
}
static HRESULT WINAPI ITSSCF_CreateInstance(LPCLASSFACTORY iface, LPUNKNOWN pOuter,
REFIID riid, LPVOID *ppobj)
static HRESULT WINAPI ITSSCF_CreateInstance(IClassFactory *iface, IUnknown *outer,
REFIID riid, void **ppv)
{
IClassFactoryImpl *This = impl_from_IClassFactory(iface);
IUnknown *unk;
HRESULT hres;
LPUNKNOWN punk;
TRACE("(%p)->(%p,%s,%p)\n", This, pOuter, debugstr_guid(riid), ppobj);
TRACE("(%p)->(%p %s %p)\n", This, outer, debugstr_guid(riid), ppv);
*ppobj = NULL;
hres = This->pfnCreateInstance(pOuter, (LPVOID *) &punk);
if (SUCCEEDED(hres)) {
hres = IUnknown_QueryInterface(punk, riid, ppobj);
IUnknown_Release(punk);
if(outer && !IsEqualGUID(riid, &IID_IUnknown)) {
*ppv = NULL;
return CLASS_E_NOAGGREGATION;
}
hres = This->pfnCreateInstance(outer, (void**)&unk);
if(FAILED(hres)) {
*ppv = NULL;
return hres;
}
if(!IsEqualGUID(riid, &IID_IUnknown)) {
hres = IUnknown_QueryInterface(unk, riid, ppv);
IUnknown_Release(unk);
}else {
*ppv = unk;
}
return hres;
}

View File

@ -36,16 +36,23 @@
WINE_DEFAULT_DEBUG_CHANNEL(itss);
typedef struct {
IUnknown IUnknown_inner;
IInternetProtocol IInternetProtocol_iface;
IInternetProtocolInfo IInternetProtocolInfo_iface;
LONG ref;
IUnknown *outer;
ULONG offset;
struct chmFile *chm_file;
struct chmUnitInfo chm_object;
} ITSProtocol;
static inline ITSProtocol *impl_from_IUnknown(IUnknown *iface)
{
return CONTAINING_RECORD(iface, ITSProtocol, IUnknown_inner);
}
static inline ITSProtocol *impl_from_IInternetProtocol(IInternetProtocol *iface)
{
return CONTAINING_RECORD(iface, ITSProtocol, IInternetProtocol_iface);
@ -65,14 +72,13 @@ static void release_chm(ITSProtocol *This)
This->offset = 0;
}
static HRESULT WINAPI ITSProtocol_QueryInterface(IInternetProtocol *iface, REFIID riid, void **ppv)
static HRESULT WINAPI ITSProtocol_QueryInterface(IUnknown *iface, REFIID riid, void **ppv)
{
ITSProtocol *This = impl_from_IInternetProtocol(iface);
ITSProtocol *This = impl_from_IUnknown(iface);
*ppv = NULL;
if(IsEqualGUID(&IID_IUnknown, riid)) {
TRACE("(%p)->(IID_IUnknown %p)\n", This, ppv);
*ppv = &This->IInternetProtocol_iface;
*ppv = &This->IUnknown_inner;
}else if(IsEqualGUID(&IID_IInternetProtocolRoot, riid)) {
TRACE("(%p)->(IID_IInternetProtocolRoot %p)\n", This, ppv);
*ppv = &This->IInternetProtocol_iface;
@ -82,28 +88,27 @@ static HRESULT WINAPI ITSProtocol_QueryInterface(IInternetProtocol *iface, REFII
}else if(IsEqualGUID(&IID_IInternetProtocolInfo, riid)) {
TRACE("(%p)->(IID_IInternetProtocolInfo %p)\n", This, ppv);
*ppv = &This->IInternetProtocolInfo_iface;
}else {
*ppv = NULL;
WARN("not supported interface %s\n", debugstr_guid(riid));
return E_NOINTERFACE;
}
if(*ppv) {
IInternetProtocol_AddRef(iface);
return S_OK;
}
WARN("not supported interface %s\n", debugstr_guid(riid));
return E_NOINTERFACE;
IUnknown_AddRef((IUnknown*)*ppv);
return S_OK;
}
static ULONG WINAPI ITSProtocol_AddRef(IInternetProtocol *iface)
static ULONG WINAPI ITSProtocol_AddRef(IUnknown *iface)
{
ITSProtocol *This = impl_from_IInternetProtocol(iface);
ITSProtocol *This = impl_from_IUnknown(iface);
LONG ref = InterlockedIncrement(&This->ref);
TRACE("(%p) ref=%d\n", This, ref);
return ref;
}
static ULONG WINAPI ITSProtocol_Release(IInternetProtocol *iface)
static ULONG WINAPI ITSProtocol_Release(IUnknown *iface)
{
ITSProtocol *This = impl_from_IInternetProtocol(iface);
ITSProtocol *This = impl_from_IUnknown(iface);
LONG ref = InterlockedDecrement(&This->ref);
TRACE("(%p) ref=%d\n", This, ref);
@ -118,6 +123,30 @@ static ULONG WINAPI ITSProtocol_Release(IInternetProtocol *iface)
return ref;
}
static const IUnknownVtbl ITSProtocolUnkVtbl = {
ITSProtocol_QueryInterface,
ITSProtocol_AddRef,
ITSProtocol_Release
};
static HRESULT WINAPI ITSInternetProtocol_QueryInterface(IInternetProtocol *iface, REFIID riid, void **ppv)
{
ITSProtocol *This = impl_from_IInternetProtocol(iface);
return IUnknown_QueryInterface(This->outer, riid, ppv);
}
static ULONG WINAPI ITSInternetProtocol_AddRef(IInternetProtocol *iface)
{
ITSProtocol *This = impl_from_IInternetProtocol(iface);
return IUnknown_AddRef(This->outer);
}
static ULONG WINAPI ITSInternetProtocol_Release(IInternetProtocol *iface)
{
ITSProtocol *This = impl_from_IInternetProtocol(iface);
return IUnknown_Release(This->outer);
}
static LPCWSTR skip_schema(LPCWSTR url)
{
static const WCHAR its_schema[] = {'i','t','s',':'};
@ -387,9 +416,9 @@ static HRESULT WINAPI ITSProtocol_UnlockRequest(IInternetProtocol *iface)
}
static const IInternetProtocolVtbl ITSProtocolVtbl = {
ITSProtocol_QueryInterface,
ITSProtocol_AddRef,
ITSProtocol_Release,
ITSInternetProtocol_QueryInterface,
ITSInternetProtocol_AddRef,
ITSInternetProtocol_Release,
ITSProtocol_Start,
ITSProtocol_Continue,
ITSProtocol_Abort,
@ -520,21 +549,24 @@ static const IInternetProtocolInfoVtbl ITSProtocolInfoVtbl = {
ITSProtocolInfo_QueryInfo
};
HRESULT ITSProtocol_create(IUnknown *pUnkOuter, LPVOID *ppobj)
HRESULT ITSProtocol_create(IUnknown *outer, void **ppv)
{
ITSProtocol *ret;
TRACE("(%p %p)\n", pUnkOuter, ppobj);
TRACE("(%p %p)\n", outer, ppv);
ITSS_LockModule();
ret = HeapAlloc(GetProcessHeap(), HEAP_ZERO_MEMORY, sizeof(ITSProtocol));
if(!ret)
return E_OUTOFMEMORY;
ret->IUnknown_inner.lpVtbl = &ITSProtocolUnkVtbl;
ret->IInternetProtocol_iface.lpVtbl = &ITSProtocolVtbl;
ret->IInternetProtocolInfo_iface.lpVtbl = &ITSProtocolInfoVtbl;
ret->ref = 1;
ret->outer = outer ? outer : &ret->IUnknown_inner;
*ppobj = &ret->IInternetProtocol_iface;
*ppv = &ret->IUnknown_inner;
return S_OK;
}

View File

@ -60,6 +60,7 @@ DEFINE_EXPECT(ReportProgress_CACHEFILENAMEAVAILABLE);
DEFINE_EXPECT(ReportProgress_DIRECTBIND);
DEFINE_EXPECT(ReportData);
DEFINE_EXPECT(ReportResult);
DEFINE_EXPECT(outer_QI_test);
static HRESULT expect_hrResult;
static IInternetProtocol *read_protocol = NULL;
@ -660,6 +661,68 @@ static void delete_chm(void)
ok(ret, "DeleteFileA failed: %d\n", GetLastError());
}
static const IID outer_test_iid = {0xabcabc00,0,0,{0,0,0,0,0,0,0,0x66}};
static HRESULT WINAPI outer_QueryInterface(IUnknown *iface, REFIID riid, void **ppv)
{
if(IsEqualGUID(riid, &outer_test_iid)) {
CHECK_EXPECT(outer_QI_test);
*ppv = (IUnknown*)0xdeadbeef;
return S_OK;
}
ok(0, "unexpected call %s\n", wine_dbgstr_guid(riid));
return E_NOINTERFACE;
}
static ULONG WINAPI outer_AddRef(IUnknown *iface)
{
return 2;
}
static ULONG WINAPI outer_Release(IUnknown *iface)
{
return 1;
}
static const IUnknownVtbl outer_vtbl = {
outer_QueryInterface,
outer_AddRef,
outer_Release
};
static void test_com_aggregation(const CLSID *clsid)
{
IUnknown outer = { &outer_vtbl };
IClassFactory *class_factory;
IUnknown *unk, *unk2, *unk3;
HRESULT hres;
hres = CoGetClassObject(clsid, CLSCTX_INPROC_SERVER, NULL, &IID_IClassFactory, (void**)&class_factory);
ok(hres == S_OK, "CoGetClassObject failed: %08x\n", hres);
hres = IClassFactory_CreateInstance(class_factory, &outer, &IID_IUnknown, (void**)&unk);
ok(hres == S_OK, "CreateInstance returned: %08x\n", hres);
hres = IUnknown_QueryInterface(unk, &IID_IInternetProtocol, (void**)&unk2);
ok(hres == S_OK, "Could not get IInternetProtocol iface: %08x\n", hres);
SET_EXPECT(outer_QI_test);
hres = IUnknown_QueryInterface(unk2, &outer_test_iid, (void**)&unk3);
CHECK_CALLED(outer_QI_test);
ok(hres == S_OK, "Could not get IInternetProtocol iface: %08x\n", hres);
ok(unk3 == (IUnknown*)0xdeadbeef, "unexpected unk2\n");
IUnknown_Release(unk2);
IUnknown_Release(unk);
unk = (void*)0xdeadbeef;
hres = IClassFactory_CreateInstance(class_factory, &outer, &IID_IInternetProtocol, (void**)&unk);
ok(hres == CLASS_E_NOAGGREGATION, "CreateInstance returned: %08x\n", hres);
ok(!unk, "unk = %p\n", unk);
IClassFactory_Release(class_factory);
}
START_TEST(protocol)
{
OleInitialize(NULL);
@ -669,6 +732,7 @@ START_TEST(protocol)
test_its_protocol();
test_mk_protocol();
test_com_aggregation(&CLSID_ITSProtocol);
delete_chm();
OleUninitialize();