diff --git a/dlls/urlmon/tests/protocol.c b/dlls/urlmon/tests/protocol.c index f5543fe7ac9..0ca8feebcb1 100644 --- a/dlls/urlmon/tests/protocol.c +++ b/dlls/urlmon/tests/protocol.c @@ -20,6 +20,7 @@ #define CONST_VTABLE #include +#include #include #include @@ -126,6 +127,7 @@ DEFINE_EXPECT(Stream_Seek); DEFINE_EXPECT(Stream_Read); DEFINE_EXPECT(Redirect); DEFINE_EXPECT(outer_QI_test); +DEFINE_EXPECT(Protocol_destructor); static const WCHAR wszIndexHtml[] = {'i','n','d','e','x','.','h','t','m','l',0}; static const WCHAR index_url[] = @@ -175,6 +177,17 @@ static enum { BIND_TEST } tested_protocol; +typedef struct { + IUnknown IUnknown_inner; + IInternetProtocolEx IInternetProtocolEx_iface; + IInternetPriority IInternetPriority_iface; + IUnknown *outer; + LONG inner_ref; + LONG outer_ref; +} Protocol; + +static Protocol *protocol_emul; + static const WCHAR protocol_names[][10] = { {'f','i','l','e',0}, {'h','t','t','p',0}, @@ -1504,8 +1517,10 @@ static IInternetBindInfoVtbl bind_info_vtbl = { static IInternetBindInfo bind_info = { &bind_info_vtbl }; -static IUnknown *protocol_outer; -static ULONG protocol_outer_ref; +static Protocol *impl_from_IInternetPriority(IInternetPriority *iface) +{ + return CONTAINING_RECORD(iface, Protocol, IInternetPriority_iface); +} static HRESULT WINAPI InternetPriority_QueryInterface(IInternetPriority *iface, REFIID riid, void **ppv) @@ -1516,12 +1531,16 @@ static HRESULT WINAPI InternetPriority_QueryInterface(IInternetPriority *iface, static ULONG WINAPI InternetPriority_AddRef(IInternetPriority *iface) { - return ++protocol_outer_ref; + Protocol *This = impl_from_IInternetPriority(iface); + This->outer_ref++; + return IUnknown_AddRef(This->outer); } static ULONG WINAPI InternetPriority_Release(IInternetPriority *iface) { - return --protocol_outer_ref; + Protocol *This = impl_from_IInternetPriority(iface); + This->outer_ref--; + return IUnknown_Release(This->outer); } static HRESULT WINAPI InternetPriority_SetPriority(IInternetPriority *iface, LONG nPriority) @@ -1546,16 +1565,14 @@ static const IInternetPriorityVtbl InternetPriorityVtbl = { InternetPriority_GetPriority }; -static IInternetPriority InternetPriority = { &InternetPriorityVtbl }; - static ULONG WINAPI Protocol_AddRef(IInternetProtocolEx *iface) { - return ++protocol_outer_ref; + return 2; } static ULONG WINAPI Protocol_Release(IInternetProtocolEx *iface) { - return --protocol_outer_ref; + return 1; } static HRESULT WINAPI Protocol_Abort(IInternetProtocolEx *iface, HRESULT hrReason, @@ -1592,15 +1609,22 @@ static HRESULT WINAPI Protocol_Seek(IInternetProtocolEx *iface, return E_NOTIMPL; } +static Protocol *impl_from_IInternetProtocolEx(IInternetProtocolEx *iface) +{ + return CONTAINING_RECORD(iface, Protocol, IInternetProtocolEx_iface); +} + static HRESULT WINAPI ProtocolEmul_QueryInterface(IInternetProtocolEx *iface, REFIID riid, void **ppv) { + Protocol *This = impl_from_IInternetProtocolEx(iface); + static const IID unknown_iid = {0x7daf9908,0x8415,0x4005,{0x95,0xae, 0xbd,0x27,0xf6,0xe3,0xdc,0x00}}; static const IID unknown_iid2 = {0x5b7ebc0c,0xf630,0x4cea,{0x89,0xd3,0x5a,0xf0,0x38,0xed,0x05,0x5c}}; /* FIXME: Why is it calling here instead of outer IUnknown? */ if(IsEqualGUID(riid, &IID_IInternetPriority)) { - *ppv = &InternetPriority; - IInternetPriority_AddRef(&InternetPriority); + *ppv = &This->IInternetPriority_iface; + IInternetPriority_AddRef(&This->IInternetPriority_iface); return S_OK; } if(!IsEqualGUID(riid, &unknown_iid) && !IsEqualGUID(riid, &unknown_iid2)) /* IE10 */ @@ -1609,6 +1633,20 @@ static HRESULT WINAPI ProtocolEmul_QueryInterface(IInternetProtocolEx *iface, RE return E_NOINTERFACE; } +static ULONG WINAPI ProtocolEmul_AddRef(IInternetProtocolEx *iface) +{ + Protocol *This = impl_from_IInternetProtocolEx(iface); + This->outer_ref++; + return IUnknown_AddRef(This->outer); +} + +static ULONG WINAPI ProtocolEmul_Release(IInternetProtocolEx *iface) +{ + Protocol *This = impl_from_IInternetProtocolEx(iface); + This->outer_ref--; + return IUnknown_Release(This->outer); +} + static DWORD WINAPI thread_proc(PVOID arg) { BOOL redirect = redirect_on_continue; @@ -1917,6 +1955,7 @@ static HRESULT WINAPI ProtocolEmul_Continue(IInternetProtocolEx *iface, SET_EXPECT(Redirect); SET_EXPECT(ReportProgress_REDIRECTING); SET_EXPECT(Terminate); + SET_EXPECT(Protocol_destructor); SET_EXPECT(QueryService_InternetProtocol); SET_EXPECT(CreateInstance); SET_EXPECT(ReportProgress_PROTOCOLCLASSID); @@ -1928,6 +1967,7 @@ static HRESULT WINAPI ProtocolEmul_Continue(IInternetProtocolEx *iface, CHECK_CALLED(Redirect); CHECK_CALLED(ReportProgress_REDIRECTING); CHECK_CALLED(Terminate); + CHECK_CALLED(Protocol_destructor); CHECK_CALLED(QueryService_InternetProtocol); CHECK_CALLED(CreateInstance); CHECK_CALLED(ReportProgress_PROTOCOLCLASSID); @@ -2153,8 +2193,8 @@ static HRESULT WINAPI ProtocolEmul_StartEx(IInternetProtocolEx *iface, IUri *pUr static const IInternetProtocolExVtbl ProtocolVtbl = { ProtocolEmul_QueryInterface, - Protocol_AddRef, - Protocol_Release, + ProtocolEmul_AddRef, + ProtocolEmul_Release, ProtocolEmul_Start, ProtocolEmul_Continue, Protocol_Abort, @@ -2168,28 +2208,31 @@ static const IInternetProtocolExVtbl ProtocolVtbl = { ProtocolEmul_StartEx }; -static IInternetProtocolEx Protocol = { &ProtocolVtbl }; -static ULONG protocol_inner_ref; +static Protocol *impl_from_IUnknown(IUnknown *iface) +{ + return CONTAINING_RECORD(iface, Protocol, IUnknown_inner); +} static HRESULT WINAPI ProtocolUnk_QueryInterface(IUnknown *iface, REFIID riid, void **ppv) { + Protocol *This = impl_from_IUnknown(iface); + if(IsEqualGUID(&IID_IUnknown, riid)) { trace("QI(IUnknown)\n"); - *ppv = iface; + *ppv = &This->IUnknown_inner; }else if(IsEqualGUID(&IID_IInternetProtocol, riid)) { trace("QI(InternetProtocol)\n"); - *ppv = &Protocol; + *ppv = &This->IInternetProtocolEx_iface; }else if(IsEqualGUID(&IID_IInternetProtocolEx, riid)) { trace("QI(InternetProtocolEx)\n"); if(!impl_protex) { *ppv = NULL; return E_NOINTERFACE; } - *ppv = &Protocol; - return S_OK; + *ppv = &This->IInternetProtocolEx_iface; }else if(IsEqualGUID(&IID_IInternetPriority, riid)) { trace("QI(InternetPriority)\n"); - *ppv = &InternetPriority; + *ppv = &This->IInternetPriority_iface; }else if(IsEqualGUID(&IID_IWinInetInfo, riid)) { trace("QI(IWinInetInfo)\n"); CHECK_EXPECT(QueryInterface_IWinInetInfo); @@ -2212,12 +2255,26 @@ static HRESULT WINAPI ProtocolUnk_QueryInterface(IUnknown *iface, REFIID riid, v static ULONG WINAPI ProtocolUnk_AddRef(IUnknown *iface) { - return ++protocol_inner_ref; + Protocol *This = impl_from_IUnknown(iface); + return ++This->inner_ref; } static ULONG WINAPI ProtocolUnk_Release(IUnknown *iface) { - return --protocol_inner_ref; + Protocol *This = impl_from_IUnknown(iface); + LONG ref = --This->inner_ref; + if(!ref) { + /* IE9 is broken on redirects. It will cause -1 outer_ref on original protocol handler + * and 1 on redirected handler. */ + ok(!This->outer_ref + || broken(test_redirect && (This->outer_ref == -1 || This->outer_ref == 1)), + "outer_ref = %d\n", This->outer_ref); + if(This->outer_ref) + trace("outer_ref %d\n", This->outer_ref); + CHECK_EXPECT(Protocol_destructor); + heap_free(This); + } + return ref; } static const IUnknownVtbl ProtocolUnkVtbl = { @@ -2226,8 +2283,6 @@ static const IUnknownVtbl ProtocolUnkVtbl = { ProtocolUnk_Release }; -static IUnknown ProtocolUnk = { &ProtocolUnkVtbl }; - static HRESULT WINAPI MimeProtocol_QueryInterface(IInternetProtocolEx *iface, REFIID riid, void **ppv) { if(IsEqualGUID(&IID_IUnknown, riid) || IsEqualGUID(&IID_IInternetProtocol, riid)) { @@ -2512,15 +2567,24 @@ static ULONG WINAPI ClassFactory_Release(IClassFactory *iface) static HRESULT WINAPI ClassFactory_CreateInstance(IClassFactory *iface, IUnknown *pOuter, REFIID riid, void **ppv) { + Protocol *ret; + CHECK_EXPECT(CreateInstance); ok(pOuter == (IUnknown*)prot_bind_info, "pOuter != protocol_unk\n"); ok(IsEqualGUID(&IID_IUnknown, riid), "unexpected riid %s\n", wine_dbgstr_guid(riid)); ok(ppv != NULL, "ppv == NULL\n"); - protocol_outer = pOuter; - *ppv = &ProtocolUnk; - IUnknown_AddRef(&ProtocolUnk); + ret = heap_alloc(sizeof(*ret)); + ret->IUnknown_inner.lpVtbl = &ProtocolUnkVtbl; + ret->IInternetProtocolEx_iface.lpVtbl = &ProtocolVtbl; + ret->IInternetPriority_iface.lpVtbl = &InternetPriorityVtbl; + ret->outer = pOuter; + ret->inner_ref = 1; + ret->outer_ref = 0; + + protocol_emul = ret; + *ppv = &ret->IUnknown_inner; return S_OK; } @@ -3708,7 +3772,6 @@ static void test_CreateBinding(void) hres = IInternetSession_RegisterNameSpace(session, &ClassFactory, &IID_NULL, wsz_test, 0, NULL, 0); ok(hres == S_OK, "RegisterNameSpace failed: %08x\n", hres); - protocol_inner_ref = 0; hres = IInternetSession_CreateBinding(session, NULL, test_url, NULL, NULL, &protocol, 0); binding_protocol = protocol; ok(hres == S_OK, "CreateBinding failed: %08x\n", hres); @@ -3759,8 +3822,6 @@ static void test_CreateBinding(void) ok(hres == S_OK, "Start failed: %08x\n", hres); trace("Start <\n"); - ok(protocol_inner_ref == 1, "protocol_inner_ref = %u\n", protocol_inner_ref); - CHECK_CALLED(QueryService_InternetProtocol); CHECK_CALLED(CreateInstance); CHECK_CALLED(ReportProgress_PROTOCOLCLASSID); @@ -3809,8 +3870,7 @@ static void test_CreateBinding(void) ok(hres == S_OK, "Terminate failed: %08x\n", hres); CHECK_CALLED(Terminate); - ok(protocol_inner_ref == 1, "protocol_inner_ref = %u\n", protocol_inner_ref); - ok(protocol_outer_ref == 0, "protocol_outer_ref = %u\n", protocol_outer_ref); + ok(protocol_emul->outer_ref == 0, "protocol_outer_ref = %u\n", protocol_emul->outer_ref); SET_EXPECT(Continue); hres = IInternetProtocolSink_Switch(binding_sink, &protocoldata); @@ -3840,10 +3900,10 @@ static void test_CreateBinding(void) IInternetProtocolSink_Release(binding_sink); IInternetPriority_Release(priority); IInternetBindInfo_Release(prot_bind_info); - IInternetProtocol_Release(protocol); - ok(protocol_inner_ref == 0, "protocol_inner_ref = %u\n", protocol_inner_ref); - ok(protocol_outer_ref == 0, "protocol_outer_ref = %u\n", protocol_outer_ref); + SET_EXPECT(Protocol_destructor); + IInternetProtocol_Release(protocol); + CHECK_CALLED(Protocol_destructor); hres = IInternetSession_CreateBinding(session, NULL, test_url, NULL, NULL, &protocol, 0); ok(hres == S_OK, "CreateBinding failed: %08x\n", hres); @@ -4015,8 +4075,11 @@ static void test_binding(int prot, DWORD grf_pi, DWORD test_flags) IInternetProtocol_Release(filtered_protocol); IInternetBindInfo_Release(prot_bind_info); IInternetProtocolSink_Release(binding_sink); + + SET_EXPECT(Protocol_destructor); ref = IInternetProtocol_Release(protocol); ok(!ref, "ref=%u, expected 0\n", ref); + CHECK_CALLED(Protocol_destructor); if(test_flags & TEST_EMULATEPROT) { hres = IInternetSession_UnregisterNameSpace(session, &ClassFactory, protocol_names[prot]);