diff --git a/dlls/urlmon/bindprot.c b/dlls/urlmon/bindprot.c index 4b8069f39bd..3d332c280e2 100644 --- a/dlls/urlmon/bindprot.c +++ b/dlls/urlmon/bindprot.c @@ -681,13 +681,24 @@ static HRESULT WINAPI ProtocolHandler_Start(IInternetProtocol *iface, LPCWSTR sz static HRESULT WINAPI ProtocolHandler_Continue(IInternetProtocol *iface, PROTOCOLDATA *pProtocolData) { BindProtocol *This = impl_from_IInternetProtocol(iface); + IInternetProtocol *protocol; HRESULT hres; TRACE("(%p)->(%p)\n", This, pProtocolData); - hres = IInternetProtocol_Continue(This->protocol, pProtocolData); + /* FIXME: This should not be needed. */ + if(!This->protocol && This->protocol_unk) { + hres = IUnknown_QueryInterface(This->protocol_unk, &IID_IInternetProtocol, (void**)&protocol); + if(FAILED(hres)) + return E_FAIL; + }else { + IInternetProtocol_AddRef(protocol = This->protocol); + } + + hres = IInternetProtocol_Continue(protocol, pProtocolData); heap_free(pProtocolData); + IInternetProtocol_Release(protocol); return hres; } @@ -716,7 +727,11 @@ static HRESULT WINAPI ProtocolHandler_Terminate(IInternetProtocol *iface, DWORD /* This may get released in Terminate call. */ IInternetProtocolEx_AddRef(&This->IInternetProtocolEx_iface); - IInternetProtocol_Terminate(This->protocol, 0); + if(This->protocol) { + IInternetProtocol_Terminate(This->protocol, 0); + IInternetProtocol_Release(This->protocol); + This->protocol = NULL; + } set_binding_sink(This, NULL, NULL); @@ -772,14 +787,28 @@ static HRESULT WINAPI ProtocolHandler_Read(IInternetProtocol *iface, void *pv, } if(read < cb) { + IInternetProtocol *protocol; ULONG cread = 0; + /* FIXME: We shouldn't need it, but out binding code currently depends on it. */ + if(!This->protocol && This->protocol_unk) { + hres = IUnknown_QueryInterface(This->protocol_unk, &IID_IInternetProtocol, + (void**)&protocol); + if(FAILED(hres)) + return E_ABORT; + }else { + protocol = This->protocol; + } + if(is_apartment_thread(This)) This->continue_call++; - hres = IInternetProtocol_Read(This->protocol, (BYTE*)pv+read, cb-read, &cread); + hres = IInternetProtocol_Read(protocol, (BYTE*)pv+read, cb-read, &cread); if(is_apartment_thread(This)) This->continue_call--; read += cread; + + if(!This->protocol) + IInternetProtocol_Release(protocol); } *pcbRead = read; diff --git a/dlls/urlmon/tests/protocol.c b/dlls/urlmon/tests/protocol.c index 4ca43272426..c153119ce43 100644 --- a/dlls/urlmon/tests/protocol.c +++ b/dlls/urlmon/tests/protocol.c @@ -3758,6 +3758,16 @@ static void test_CreateBinding(void) ok(hres == S_OK, "Switch failed: %08x\n", hres); CHECK_CALLED(Continue); + SET_EXPECT(Read); + read = 0xdeadbeef; + hres = IInternetProtocol_Read(protocol, expect_pv = buf, sizeof(buf), &read); + todo_wine + ok(hres == E_ABORT, "Read failed: %08x\n", hres); + todo_wine + ok(read == 0, "read = %d\n", read); + todo_wine + CHECK_NOT_CALLED(Read); + hres = IInternetProtocolSink_ReportProgress(binding_sink, BINDSTATUS_CACHEFILENAMEAVAILABLE, expect_wsz = emptyW); ok(hres == S_OK, "ReportProgress(BINDSTATUS_CACHEFILENAMEAVAILABLE) failed: %08x\n", hres);