From b0218db90a8b3c81ecb6529c089acc0074cd2f9c Mon Sep 17 00:00:00 2001 From: Robert Shearman Date: Tue, 7 Feb 2006 16:26:23 +0100 Subject: [PATCH] oleaut32: Fix circular reference count in Typelib marshaler. The current method of handling typelib-marshaled interfaces that derive from IDispatch is to query for an IDispatch pointer from the proxy, but this causes a circular reference count. Fix the reference counting by loading using the IRpcProxyBuffer of IDispatch without an outer unknown, so that the lifetime is controlled by the typelib-marshaled interface's proxy. The IDispatch proxy now shares the same channel as the typelib-marshaled interface, so fix up the stub side to handle this. --- dlls/oleaut32/tmarshal.c | 153 ++++++++++++++++++++++++--------------- 1 file changed, 95 insertions(+), 58 deletions(-) diff --git a/dlls/oleaut32/tmarshal.c b/dlls/oleaut32/tmarshal.c index 1e1cea6b32d..9e4b612fc27 100644 --- a/dlls/oleaut32/tmarshal.c +++ b/dlls/oleaut32/tmarshal.c @@ -355,6 +355,7 @@ typedef struct _TMProxyImpl { CRITICAL_SECTION crit; IUnknown *outerunknown; IDispatch *dispatch; + IRpcProxyBuffer *dispatch_proxy; } TMProxyImpl; static HRESULT WINAPI @@ -391,7 +392,7 @@ TMProxyImpl_Release(LPRPCPROXYBUFFER iface) if (!refCount) { - if (This->dispatch) IDispatch_Release(This->dispatch); + if (This->dispatch_proxy) IRpcProxyBuffer_Release(This->dispatch_proxy); DeleteCriticalSection(&This->crit); if (This->chanbuf) IRpcChannelBuffer_Release(This->chanbuf); VirtualFree(This->asmstubs, 0, MEM_RELEASE); @@ -415,6 +416,9 @@ TMProxyImpl_Connect( LeaveCriticalSection(&This->crit); + if (This->dispatch_proxy) + IRpcProxyBuffer_Connect(This->dispatch_proxy, pRpcChannelBuffer); + return S_OK; } @@ -431,6 +435,9 @@ TMProxyImpl_Disconnect(LPRPCPROXYBUFFER iface) This->chanbuf = NULL; LeaveCriticalSection(&This->crit); + + if (This->dispatch_proxy) + IRpcProxyBuffer_Disconnect(This->dispatch_proxy); } @@ -1380,56 +1387,29 @@ ULONG WINAPI ProxyIUnknown_Release(IUnknown *iface) static HRESULT WINAPI ProxyIDispatch_GetTypeInfoCount(LPDISPATCH iface, UINT * pctinfo) { TMProxyImpl *This = (TMProxyImpl *)iface; - HRESULT hr; TRACE("(%p)\n", pctinfo); - if (!This->dispatch) - { - hr = IUnknown_QueryInterface(This->outerunknown, &IID_IDispatch, - (LPVOID *)&This->dispatch); - } - if (This->dispatch) - hr = IDispatch_GetTypeInfoCount(This->dispatch, pctinfo); - - return hr; + return IDispatch_GetTypeInfoCount(This->dispatch, pctinfo); } static HRESULT WINAPI ProxyIDispatch_GetTypeInfo(LPDISPATCH iface, UINT iTInfo, LCID lcid, ITypeInfo** ppTInfo) { TMProxyImpl *This = (TMProxyImpl *)iface; - HRESULT hr = S_OK; TRACE("(%d, %lx, %p)\n", iTInfo, lcid, ppTInfo); - if (!This->dispatch) - { - hr = IUnknown_QueryInterface(This->outerunknown, &IID_IDispatch, - (LPVOID *)&This->dispatch); - } - if (This->dispatch) - hr = IDispatch_GetTypeInfo(This->dispatch, iTInfo, lcid, ppTInfo); - - return hr; + return IDispatch_GetTypeInfo(This->dispatch, iTInfo, lcid, ppTInfo); } static HRESULT WINAPI ProxyIDispatch_GetIDsOfNames(LPDISPATCH iface, REFIID riid, LPOLESTR * rgszNames, UINT cNames, LCID lcid, DISPID * rgDispId) { TMProxyImpl *This = (TMProxyImpl *)iface; - HRESULT hr; TRACE("(%s, %p, %d, 0x%lx, %p)\n", debugstr_guid(riid), rgszNames, cNames, lcid, rgDispId); - if (!This->dispatch) - { - hr = IUnknown_QueryInterface(This->outerunknown, &IID_IDispatch, - (LPVOID *)&This->dispatch); - } - if (This->dispatch) - hr = IDispatch_GetIDsOfNames(This->dispatch, riid, rgszNames, - cNames, lcid, rgDispId); - - return hr; + return IDispatch_GetIDsOfNames(This->dispatch, riid, rgszNames, + cNames, lcid, rgDispId); } static HRESULT WINAPI ProxyIDispatch_Invoke(LPDISPATCH iface, DISPID dispIdMember, REFIID riid, LCID lcid, @@ -1437,21 +1417,25 @@ static HRESULT WINAPI ProxyIDispatch_Invoke(LPDISPATCH iface, DISPID dispIdMembe EXCEPINFO * pExcepInfo, UINT * puArgErr) { TMProxyImpl *This = (TMProxyImpl *)iface; - HRESULT hr; - TRACE("(%ld, %s, 0x%lx, 0x%x, %p, %p, %p, %p)\n", dispIdMember, debugstr_guid(riid), lcid, wFlags, pDispParams, pVarResult, pExcepInfo, puArgErr); + TRACE("(%ld, %s, 0x%lx, 0x%x, %p, %p, %p, %p)\n", dispIdMember, + debugstr_guid(riid), lcid, wFlags, pDispParams, pVarResult, + pExcepInfo, puArgErr); - if (!This->dispatch) - { - hr = IUnknown_QueryInterface(This->outerunknown, &IID_IDispatch, - (LPVOID *)&This->dispatch); - } - if (This->dispatch) - hr = IDispatch_Invoke(This->dispatch, dispIdMember, riid, lcid, - wFlags, pDispParams, pVarResult, pExcepInfo, - puArgErr); + return IDispatch_Invoke(This->dispatch, dispIdMember, riid, lcid, + wFlags, pDispParams, pVarResult, pExcepInfo, + puArgErr); +} - return hr; +static inline HRESULT get_facbuf_for_iid(REFIID riid, IPSFactoryBuffer **facbuf) +{ + HRESULT hr; + CLSID clsid; + + if ((hr = CoGetPSClsid(riid, &clsid))) + return hr; + return CoGetClassObject(&clsid, CLSCTX_INPROC_SERVER, NULL, + &IID_IPSFactoryBuffer, (LPVOID*)facbuf); } static HRESULT WINAPI @@ -1479,6 +1463,7 @@ PSFacBuf_CreateProxy( assert(sizeof(TMAsmProxy) == 12); proxy->dispatch = NULL; + proxy->dispatch_proxy = NULL; proxy->outerunknown = pUnkOuter; proxy->asmstubs = VirtualAlloc(NULL, sizeof(TMAsmProxy) * nroffuncs, MEM_COMMIT, PAGE_EXECUTE_READWRITE); if (!proxy->asmstubs) { @@ -1558,10 +1543,22 @@ PSFacBuf_CreateProxy( { if (typeattr->wTypeFlags & TYPEFLAG_FDISPATCHABLE) { - proxy->lpvtbl[3] = ProxyIDispatch_GetTypeInfoCount; - proxy->lpvtbl[4] = ProxyIDispatch_GetTypeInfo; - proxy->lpvtbl[5] = ProxyIDispatch_GetIDsOfNames; - proxy->lpvtbl[6] = ProxyIDispatch_Invoke; + IPSFactoryBuffer *factory_buffer; + hres = get_facbuf_for_iid(&IID_IDispatch, &factory_buffer); + if (hres == S_OK) + { + hres = IPSFactoryBuffer_CreateProxy(factory_buffer, NULL, + &IID_IDispatch, &proxy->dispatch_proxy, + (void **)&proxy->dispatch); + IPSFactoryBuffer_Release(factory_buffer); + } + if (hres == S_OK) + { + proxy->lpvtbl[3] = ProxyIDispatch_GetTypeInfoCount; + proxy->lpvtbl[4] = ProxyIDispatch_GetTypeInfo; + proxy->lpvtbl[5] = ProxyIDispatch_GetIDsOfNames; + proxy->lpvtbl[6] = ProxyIDispatch_Invoke; + } } ITypeInfo_ReleaseTypeAttr(tinfo, typeattr); } @@ -1572,10 +1569,16 @@ PSFacBuf_CreateProxy( proxy->tinfo = tinfo; memcpy(&proxy->iid,riid,sizeof(*riid)); proxy->chanbuf = 0; - *ppv = (LPVOID)proxy; - *ppProxy = (IRpcProxyBuffer *)&(proxy->lpvtbl2); - IUnknown_AddRef((IUnknown *)*ppv); - return S_OK; + if (hres == S_OK) + { + *ppv = (LPVOID)proxy; + *ppProxy = (IRpcProxyBuffer *)&(proxy->lpvtbl2); + IUnknown_AddRef((IUnknown *)*ppv); + return S_OK; + } + else + TMProxyImpl_Release((IRpcProxyBuffer *)&proxy->lpvtbl2); + return hres; } typedef struct _TMStubImpl { @@ -1585,6 +1588,7 @@ typedef struct _TMStubImpl { LPUNKNOWN pUnk; ITypeInfo *tinfo; IID iid; + IRpcStubBuffer *dispatch_stub; } TMStubImpl; static HRESULT WINAPI @@ -1636,6 +1640,10 @@ TMStubImpl_Connect(LPRPCSTUBBUFFER iface, LPUNKNOWN pUnkServer) IUnknown_AddRef(pUnkServer); This->pUnk = pUnkServer; + + if (This->dispatch_stub) + IRpcStubBuffer_Connect(This->dispatch_stub, pUnkServer); + return S_OK; } @@ -1651,6 +1659,9 @@ TMStubImpl_Disconnect(LPRPCSTUBBUFFER iface) IUnknown_Release(This->pUnk); This->pUnk = NULL; } + + if (This->dispatch_stub) + IRpcStubBuffer_Disconnect(This->dispatch_stub); } static HRESULT WINAPI @@ -1668,12 +1679,6 @@ TMStubImpl_Invoke( BSTR iname = NULL; ITypeInfo *tinfo; - memset(&buf,0,sizeof(buf)); - buf.size = xmsg->cbBuffer; - buf.base = HeapAlloc(GetProcessHeap(), 0, xmsg->cbBuffer); - memcpy(buf.base, xmsg->Buffer, xmsg->cbBuffer); - buf.curoff = 0; - TRACE("...\n"); if (xmsg->iMethod < 3) { @@ -1681,6 +1686,15 @@ TMStubImpl_Invoke( return E_UNEXPECTED; } + if (This->dispatch_stub && xmsg->iMethod < sizeof(IDispatchVtbl)/sizeof(void *)) + return IRpcStubBuffer_Invoke(This->dispatch_stub, xmsg, rpcchanbuf); + + memset(&buf,0,sizeof(buf)); + buf.size = xmsg->cbBuffer; + buf.base = HeapAlloc(GetProcessHeap(), 0, xmsg->cbBuffer); + memcpy(buf.base, xmsg->Buffer, xmsg->cbBuffer); + buf.curoff = 0; + hres = _get_funcdesc(This->tinfo,xmsg->iMethod,&tinfo,&fdesc,&iname,NULL); if (hres) { ERR("GetFuncDesc on method %ld failed with %lx\n",xmsg->iMethod,hres); @@ -1839,25 +1853,48 @@ PSFacBuf_CreateStub( HRESULT hres; ITypeInfo *tinfo; TMStubImpl *stub; + TYPEATTR *typeattr; TRACE("(%s,%p,%p)\n",debugstr_guid(riid),pUnkServer,ppStub); + hres = _get_typeinfo_for_iid(riid,&tinfo); if (hres) { ERR("No typeinfo for %s?\n",debugstr_guid(riid)); return hres; } + stub = CoTaskMemAlloc(sizeof(TMStubImpl)); if (!stub) return E_OUTOFMEMORY; stub->lpvtbl = &tmstubvtbl; stub->ref = 1; stub->tinfo = tinfo; + stub->dispatch_stub = NULL; memcpy(&(stub->iid),riid,sizeof(*riid)); hres = IRpcStubBuffer_Connect((LPRPCSTUBBUFFER)stub,pUnkServer); *ppStub = (LPRPCSTUBBUFFER)stub; TRACE("IRpcStubBuffer: %p\n", stub); if (hres) ERR("Connect to pUnkServer failed?\n"); + + /* if we derive from IDispatch then defer to its stub for some of its methods */ + hres = ITypeInfo_GetTypeAttr(tinfo, &typeattr); + if (hres == S_OK) + { + if (typeattr->wTypeFlags & TYPEFLAG_FDISPATCHABLE) + { + IPSFactoryBuffer *factory_buffer; + hres = get_facbuf_for_iid(&IID_IDispatch, &factory_buffer); + if (hres == S_OK) + { + hres = IPSFactoryBuffer_CreateStub(factory_buffer, &IID_IDispatch, + pUnkServer, &stub->dispatch_stub); + IPSFactoryBuffer_Release(factory_buffer); + } + } + ITypeInfo_ReleaseTypeAttr(tinfo, typeattr); + } + return hres; }