diff --git a/dlls/urlmon/binding.c b/dlls/urlmon/binding.c index 17bc0f214e6..aaadbd87638 100644 --- a/dlls/urlmon/binding.c +++ b/dlls/urlmon/binding.c @@ -45,6 +45,15 @@ typedef struct _task_header_t { struct _task_header_t *next; } task_header_t; +typedef struct { + const IHttpNegotiate2Vtbl *lpHttpNegotiate2Vtbl; + + LONG ref; + + IHttpNegotiate *http_negotiate; + IHttpNegotiate2 *http_negotiate2; +} HttpNegotiate2Wrapper; + typedef struct { const IStreamVtbl *lpStreamVtbl; @@ -76,6 +85,7 @@ struct Binding { IInternetProtocol *protocol; IServiceProvider *service_provider; ProtocolStream *stream; + HttpNegotiate2Wrapper *httpneg2_wrapper; BINDINFO bindinfo; DWORD bindf; @@ -101,6 +111,7 @@ struct Binding { #define SERVPROV(x) ((IServiceProvider*) &(x)->lpServiceProviderVtbl) #define STREAM(x) ((IStream*) &(x)->lpStreamVtbl) +#define HTTPNEG2(x) ((IHttpNegotiate2*) &(x)->lpHttpNegotiate2Vtbl) #define WM_MK_CONTINUE (WM_USER+101) @@ -248,84 +259,153 @@ static void dump_BINDINFO(BINDINFO *bi) ); } -static HRESULT WINAPI HttpNegotiate_QueryInterface(IHttpNegotiate2 *iface, - REFIID riid, void **ppv) +#define HTTPNEG2_THIS(iface) DEFINE_THIS(HttpNegotiate2Wrapper, HttpNegotiate2, iface) + +static HRESULT WINAPI HttpNegotiate2Wrapper_QueryInterface(IHttpNegotiate2 *iface, + REFIID riid, void **ppv) { + HttpNegotiate2Wrapper *This = HTTPNEG2_THIS(iface); + *ppv = NULL; if(IsEqualGUID(&IID_IUnknown, riid)) { TRACE("(IID_IUnknown %p)\n", ppv); - *ppv = iface; + *ppv = HTTPNEG2(This); }else if(IsEqualGUID(&IID_IHttpNegotiate, riid)) { TRACE("(IID_IHttpNegotiate %p)\n", ppv); - *ppv = iface; + *ppv = HTTPNEG2(This); }else if(IsEqualGUID(&IID_IHttpNegotiate2, riid)) { TRACE("(IID_IHttpNegotiate2 %p)\n", ppv); - *ppv = iface; + *ppv = HTTPNEG2(This); } if(*ppv) { - IHttpNegotiate2_AddRef(iface); + IHttpNegotiate2_AddRef(HTTPNEG2(This)); return S_OK; } - WARN("Unsupported interface %s\n", debugstr_guid(riid)); + WARN("(%p)->(%s %p)\n", This, debugstr_guid(riid), ppv); return E_NOINTERFACE; } -static ULONG WINAPI HttpNegotiate_AddRef(IHttpNegotiate2 *iface) +static ULONG WINAPI HttpNegotiate2Wrapper_AddRef(IHttpNegotiate2 *iface) { - URLMON_LockModule(); - return 2; + HttpNegotiate2Wrapper *This = HTTPNEG2_THIS(iface); + LONG ref = InterlockedIncrement(&This->ref); + + TRACE("(%p) ref=%d\n", This, ref); + + return ref; } -static ULONG WINAPI HttpNegotiate_Release(IHttpNegotiate2 *iface) +static ULONG WINAPI HttpNegotiate2Wrapper_Release(IHttpNegotiate2 *iface) { - URLMON_UnlockModule(); - return 1; + HttpNegotiate2Wrapper *This = HTTPNEG2_THIS(iface); + LONG ref = InterlockedDecrement(&This->ref); + + TRACE("(%p) ref=%d\n", This, ref); + + if(!ref) { + if (This->http_negotiate) + IHttpNegotiate_Release(This->http_negotiate); + if (This->http_negotiate2) + IHttpNegotiate2_Release(This->http_negotiate2); + HeapFree(GetProcessHeap(), 0, This); + + URLMON_UnlockModule(); + } + + return ref; } -static HRESULT WINAPI HttpNegotiate_BeginningTransaction(IHttpNegotiate2 *iface, +static HRESULT WINAPI HttpNegotiate2Wrapper_BeginningTransaction(IHttpNegotiate2 *iface, LPCWSTR szURL, LPCWSTR szHeaders, DWORD dwReserved, LPWSTR *pszAdditionalHeaders) { - TRACE("(%s %s %d %p)\n", debugstr_w(szURL), debugstr_w(szHeaders), dwReserved, + HttpNegotiate2Wrapper *This = HTTPNEG2_THIS(iface); + + TRACE("(%p)->(%s %s %d %p)\n", This, debugstr_w(szURL), debugstr_w(szHeaders), dwReserved, pszAdditionalHeaders); + if(This->http_negotiate) + return IHttpNegotiate_BeginningTransaction(This->http_negotiate, szURL, szHeaders, + dwReserved, pszAdditionalHeaders); + *pszAdditionalHeaders = NULL; return S_OK; } -static HRESULT WINAPI HttpNegotiate_OnResponse(IHttpNegotiate2 *iface, DWORD dwResponseCode, +static HRESULT WINAPI HttpNegotiate2Wrapper_OnResponse(IHttpNegotiate2 *iface, DWORD dwResponseCode, LPCWSTR szResponseHeaders, LPCWSTR szRequestHeaders, LPWSTR *pszAdditionalRequestHeaders) { - TRACE("(%d %s %s %p)\n", dwResponseCode, debugstr_w(szResponseHeaders), + HttpNegotiate2Wrapper *This = HTTPNEG2_THIS(iface); + LPWSTR szAdditionalRequestHeaders = NULL; + HRESULT hres = S_OK; + + TRACE("(%p)->(%d %s %s %p)\n", This, dwResponseCode, debugstr_w(szResponseHeaders), debugstr_w(szRequestHeaders), pszAdditionalRequestHeaders); - if(pszAdditionalRequestHeaders) + /* IHttpNegotiate2_OnResponse expects pszAdditionalHeaders to be non-NULL when it is + * implemented as part of IBindStatusCallback, but it is NULL when called directly from + * IProtocol */ + if(!pszAdditionalRequestHeaders) + pszAdditionalRequestHeaders = &szAdditionalRequestHeaders; + + if(This->http_negotiate) + { + hres = IHttpNegotiate_OnResponse(This->http_negotiate, dwResponseCode, szResponseHeaders, + szRequestHeaders, pszAdditionalRequestHeaders); + if(pszAdditionalRequestHeaders == &szAdditionalRequestHeaders && + szAdditionalRequestHeaders) + CoTaskMemFree(szAdditionalRequestHeaders); + } + else + { *pszAdditionalRequestHeaders = NULL; - return S_OK; + } + + return hres; } -static HRESULT WINAPI HttpNegotiate_GetRootSecurityId(IHttpNegotiate2 *iface, +static HRESULT WINAPI HttpNegotiate2Wrapper_GetRootSecurityId(IHttpNegotiate2 *iface, BYTE *pbSecurityId, DWORD *pcbSecurityId, DWORD_PTR dwReserved) { - TRACE("(%p %p %ld)\n", pbSecurityId, pcbSecurityId, dwReserved); + HttpNegotiate2Wrapper *This = HTTPNEG2_THIS(iface); + + TRACE("(%p)->(%p %p %ld)\n", This, pbSecurityId, pcbSecurityId, dwReserved); + + if (This->http_negotiate2) + return IHttpNegotiate2_GetRootSecurityId(This->http_negotiate2, pbSecurityId, + pcbSecurityId, dwReserved); /* That's all we have to do here */ return E_FAIL; } -static const IHttpNegotiate2Vtbl HttpNegotiate2Vtbl = { - HttpNegotiate_QueryInterface, - HttpNegotiate_AddRef, - HttpNegotiate_Release, - HttpNegotiate_BeginningTransaction, - HttpNegotiate_OnResponse, - HttpNegotiate_GetRootSecurityId +#undef HTTPNEG2_THIS + +static const IHttpNegotiate2Vtbl HttpNegotiate2WrapperVtbl = { + HttpNegotiate2Wrapper_QueryInterface, + HttpNegotiate2Wrapper_AddRef, + HttpNegotiate2Wrapper_Release, + HttpNegotiate2Wrapper_BeginningTransaction, + HttpNegotiate2Wrapper_OnResponse, + HttpNegotiate2Wrapper_GetRootSecurityId }; -static IHttpNegotiate2 HttpNegotiate = { &HttpNegotiate2Vtbl }; +static HttpNegotiate2Wrapper *create_httpneg2_wrapper(void) +{ + HttpNegotiate2Wrapper *ret = HeapAlloc(GetProcessHeap(), 0, sizeof(HttpNegotiate2Wrapper)); + + ret->lpHttpNegotiate2Vtbl = &HttpNegotiate2WrapperVtbl; + ret->ref = 1; + ret->http_negotiate = NULL; + ret->http_negotiate2 = NULL; + + URLMON_LockModule(); + + return ret; +} #define STREAM_THIS(iface) DEFINE_THIS(ProtocolStream, Stream, iface) @@ -601,6 +681,8 @@ static ULONG WINAPI Binding_Release(IBinding *iface) IServiceProvider_Release(This->service_provider); if(This->stream) IStream_Release(STREAM(This->stream)); + if(This->httpneg2_wrapper) + IHttpNegotiate2_Release(HTTPNEG2(This->httpneg2_wrapper)); ReleaseBindInfo(&This->bindinfo); This->section.DebugInfo->Spare[0] = 0; @@ -1092,8 +1174,20 @@ static HRESULT WINAPI ServiceProvider_QueryService(IServiceProvider *iface, } if(IsEqualGUID(&IID_IHttpNegotiate, guidService) - || IsEqualGUID(&IID_IHttpNegotiate2, guidService)) - return IHttpNegotiate2_QueryInterface(&HttpNegotiate, riid, ppv); + || IsEqualGUID(&IID_IHttpNegotiate2, guidService)) { + if(!This->httpneg2_wrapper) { + WARN("HttpNegotiate2Wrapper expected to be non-NULL\n"); + } else { + if(IsEqualGUID(&IID_IHttpNegotiate, guidService)) + IBindStatusCallback_QueryInterface(This->callback, riid, + (void **)&This->httpneg2_wrapper->http_negotiate); + else + IBindStatusCallback_QueryInterface(This->callback, riid, + (void **)&This->httpneg2_wrapper->http_negotiate2); + + return IHttpNegotiate2_QueryInterface(HTTPNEG2(This->httpneg2_wrapper), riid, ppv); + } + } WARN("unknown service %s\n", debugstr_guid(guidService)); return E_NOTIMPL; @@ -1208,6 +1302,7 @@ static HRESULT Binding_Create(LPCWSTR url, IBindCtx *pbc, REFIID riid, Binding * ret->protocol = NULL; ret->service_provider = NULL; ret->stream = NULL; + ret->httpneg2_wrapper = NULL; ret->mime = NULL; ret->url = NULL; ret->apartment_thread = GetCurrentThreadId(); @@ -1265,6 +1360,8 @@ static HRESULT Binding_Create(LPCWSTR url, IBindCtx *pbc, REFIID riid, Binding * ret->stgmed.u.pstm = STREAM(ret->stream); ret->stgmed.pUnkForRelease = (IUnknown*)BINDING(ret); /* NOTE: Windows uses other IUnknown */ + ret->httpneg2_wrapper = create_httpneg2_wrapper(); + *binding = ret; return S_OK; }