oleaut32: Fix circular reference count in Typelib marshaler.

The current method of handling typelib-marshaled interfaces that derive
from IDispatch is to query for an IDispatch pointer from the proxy, but
this causes a circular reference count.
Fix the reference counting by loading using the IRpcProxyBuffer of
IDispatch without an outer unknown, so that the lifetime is controlled
by the typelib-marshaled interface's proxy. The IDispatch proxy now
shares the same channel as the typelib-marshaled interface, so fix up
the stub side to handle this.
This commit is contained in:
Robert Shearman 2006-02-07 16:26:23 +01:00 committed by Alexandre Julliard
parent fd81d9c56e
commit b0218db90a

View File

@ -355,6 +355,7 @@ typedef struct _TMProxyImpl {
CRITICAL_SECTION crit; CRITICAL_SECTION crit;
IUnknown *outerunknown; IUnknown *outerunknown;
IDispatch *dispatch; IDispatch *dispatch;
IRpcProxyBuffer *dispatch_proxy;
} TMProxyImpl; } TMProxyImpl;
static HRESULT WINAPI static HRESULT WINAPI
@ -391,7 +392,7 @@ TMProxyImpl_Release(LPRPCPROXYBUFFER iface)
if (!refCount) if (!refCount)
{ {
if (This->dispatch) IDispatch_Release(This->dispatch); if (This->dispatch_proxy) IRpcProxyBuffer_Release(This->dispatch_proxy);
DeleteCriticalSection(&This->crit); DeleteCriticalSection(&This->crit);
if (This->chanbuf) IRpcChannelBuffer_Release(This->chanbuf); if (This->chanbuf) IRpcChannelBuffer_Release(This->chanbuf);
VirtualFree(This->asmstubs, 0, MEM_RELEASE); VirtualFree(This->asmstubs, 0, MEM_RELEASE);
@ -415,6 +416,9 @@ TMProxyImpl_Connect(
LeaveCriticalSection(&This->crit); LeaveCriticalSection(&This->crit);
if (This->dispatch_proxy)
IRpcProxyBuffer_Connect(This->dispatch_proxy, pRpcChannelBuffer);
return S_OK; return S_OK;
} }
@ -431,6 +435,9 @@ TMProxyImpl_Disconnect(LPRPCPROXYBUFFER iface)
This->chanbuf = NULL; This->chanbuf = NULL;
LeaveCriticalSection(&This->crit); LeaveCriticalSection(&This->crit);
if (This->dispatch_proxy)
IRpcProxyBuffer_Disconnect(This->dispatch_proxy);
} }
@ -1380,56 +1387,29 @@ ULONG WINAPI ProxyIUnknown_Release(IUnknown *iface)
static HRESULT WINAPI ProxyIDispatch_GetTypeInfoCount(LPDISPATCH iface, UINT * pctinfo) static HRESULT WINAPI ProxyIDispatch_GetTypeInfoCount(LPDISPATCH iface, UINT * pctinfo)
{ {
TMProxyImpl *This = (TMProxyImpl *)iface; TMProxyImpl *This = (TMProxyImpl *)iface;
HRESULT hr;
TRACE("(%p)\n", pctinfo); TRACE("(%p)\n", pctinfo);
if (!This->dispatch) return IDispatch_GetTypeInfoCount(This->dispatch, pctinfo);
{
hr = IUnknown_QueryInterface(This->outerunknown, &IID_IDispatch,
(LPVOID *)&This->dispatch);
}
if (This->dispatch)
hr = IDispatch_GetTypeInfoCount(This->dispatch, pctinfo);
return hr;
} }
static HRESULT WINAPI ProxyIDispatch_GetTypeInfo(LPDISPATCH iface, UINT iTInfo, LCID lcid, ITypeInfo** ppTInfo) static HRESULT WINAPI ProxyIDispatch_GetTypeInfo(LPDISPATCH iface, UINT iTInfo, LCID lcid, ITypeInfo** ppTInfo)
{ {
TMProxyImpl *This = (TMProxyImpl *)iface; TMProxyImpl *This = (TMProxyImpl *)iface;
HRESULT hr = S_OK;
TRACE("(%d, %lx, %p)\n", iTInfo, lcid, ppTInfo); TRACE("(%d, %lx, %p)\n", iTInfo, lcid, ppTInfo);
if (!This->dispatch) return IDispatch_GetTypeInfo(This->dispatch, iTInfo, lcid, ppTInfo);
{
hr = IUnknown_QueryInterface(This->outerunknown, &IID_IDispatch,
(LPVOID *)&This->dispatch);
}
if (This->dispatch)
hr = IDispatch_GetTypeInfo(This->dispatch, iTInfo, lcid, ppTInfo);
return hr;
} }
static HRESULT WINAPI ProxyIDispatch_GetIDsOfNames(LPDISPATCH iface, REFIID riid, LPOLESTR * rgszNames, UINT cNames, LCID lcid, DISPID * rgDispId) static HRESULT WINAPI ProxyIDispatch_GetIDsOfNames(LPDISPATCH iface, REFIID riid, LPOLESTR * rgszNames, UINT cNames, LCID lcid, DISPID * rgDispId)
{ {
TMProxyImpl *This = (TMProxyImpl *)iface; TMProxyImpl *This = (TMProxyImpl *)iface;
HRESULT hr;
TRACE("(%s, %p, %d, 0x%lx, %p)\n", debugstr_guid(riid), rgszNames, cNames, lcid, rgDispId); TRACE("(%s, %p, %d, 0x%lx, %p)\n", debugstr_guid(riid), rgszNames, cNames, lcid, rgDispId);
if (!This->dispatch) return IDispatch_GetIDsOfNames(This->dispatch, riid, rgszNames,
{ cNames, lcid, rgDispId);
hr = IUnknown_QueryInterface(This->outerunknown, &IID_IDispatch,
(LPVOID *)&This->dispatch);
}
if (This->dispatch)
hr = IDispatch_GetIDsOfNames(This->dispatch, riid, rgszNames,
cNames, lcid, rgDispId);
return hr;
} }
static HRESULT WINAPI ProxyIDispatch_Invoke(LPDISPATCH iface, DISPID dispIdMember, REFIID riid, LCID lcid, static HRESULT WINAPI ProxyIDispatch_Invoke(LPDISPATCH iface, DISPID dispIdMember, REFIID riid, LCID lcid,
@ -1437,21 +1417,25 @@ static HRESULT WINAPI ProxyIDispatch_Invoke(LPDISPATCH iface, DISPID dispIdMembe
EXCEPINFO * pExcepInfo, UINT * puArgErr) EXCEPINFO * pExcepInfo, UINT * puArgErr)
{ {
TMProxyImpl *This = (TMProxyImpl *)iface; TMProxyImpl *This = (TMProxyImpl *)iface;
HRESULT hr;
TRACE("(%ld, %s, 0x%lx, 0x%x, %p, %p, %p, %p)\n", dispIdMember, debugstr_guid(riid), lcid, wFlags, pDispParams, pVarResult, pExcepInfo, puArgErr); TRACE("(%ld, %s, 0x%lx, 0x%x, %p, %p, %p, %p)\n", dispIdMember,
debugstr_guid(riid), lcid, wFlags, pDispParams, pVarResult,
pExcepInfo, puArgErr);
if (!This->dispatch) return IDispatch_Invoke(This->dispatch, dispIdMember, riid, lcid,
{ wFlags, pDispParams, pVarResult, pExcepInfo,
hr = IUnknown_QueryInterface(This->outerunknown, &IID_IDispatch, puArgErr);
(LPVOID *)&This->dispatch); }
}
if (This->dispatch)
hr = IDispatch_Invoke(This->dispatch, dispIdMember, riid, lcid,
wFlags, pDispParams, pVarResult, pExcepInfo,
puArgErr);
return hr; static inline HRESULT get_facbuf_for_iid(REFIID riid, IPSFactoryBuffer **facbuf)
{
HRESULT hr;
CLSID clsid;
if ((hr = CoGetPSClsid(riid, &clsid)))
return hr;
return CoGetClassObject(&clsid, CLSCTX_INPROC_SERVER, NULL,
&IID_IPSFactoryBuffer, (LPVOID*)facbuf);
} }
static HRESULT WINAPI static HRESULT WINAPI
@ -1479,6 +1463,7 @@ PSFacBuf_CreateProxy(
assert(sizeof(TMAsmProxy) == 12); assert(sizeof(TMAsmProxy) == 12);
proxy->dispatch = NULL; proxy->dispatch = NULL;
proxy->dispatch_proxy = NULL;
proxy->outerunknown = pUnkOuter; proxy->outerunknown = pUnkOuter;
proxy->asmstubs = VirtualAlloc(NULL, sizeof(TMAsmProxy) * nroffuncs, MEM_COMMIT, PAGE_EXECUTE_READWRITE); proxy->asmstubs = VirtualAlloc(NULL, sizeof(TMAsmProxy) * nroffuncs, MEM_COMMIT, PAGE_EXECUTE_READWRITE);
if (!proxy->asmstubs) { if (!proxy->asmstubs) {
@ -1558,10 +1543,22 @@ PSFacBuf_CreateProxy(
{ {
if (typeattr->wTypeFlags & TYPEFLAG_FDISPATCHABLE) if (typeattr->wTypeFlags & TYPEFLAG_FDISPATCHABLE)
{ {
proxy->lpvtbl[3] = ProxyIDispatch_GetTypeInfoCount; IPSFactoryBuffer *factory_buffer;
proxy->lpvtbl[4] = ProxyIDispatch_GetTypeInfo; hres = get_facbuf_for_iid(&IID_IDispatch, &factory_buffer);
proxy->lpvtbl[5] = ProxyIDispatch_GetIDsOfNames; if (hres == S_OK)
proxy->lpvtbl[6] = ProxyIDispatch_Invoke; {
hres = IPSFactoryBuffer_CreateProxy(factory_buffer, NULL,
&IID_IDispatch, &proxy->dispatch_proxy,
(void **)&proxy->dispatch);
IPSFactoryBuffer_Release(factory_buffer);
}
if (hres == S_OK)
{
proxy->lpvtbl[3] = ProxyIDispatch_GetTypeInfoCount;
proxy->lpvtbl[4] = ProxyIDispatch_GetTypeInfo;
proxy->lpvtbl[5] = ProxyIDispatch_GetIDsOfNames;
proxy->lpvtbl[6] = ProxyIDispatch_Invoke;
}
} }
ITypeInfo_ReleaseTypeAttr(tinfo, typeattr); ITypeInfo_ReleaseTypeAttr(tinfo, typeattr);
} }
@ -1572,10 +1569,16 @@ PSFacBuf_CreateProxy(
proxy->tinfo = tinfo; proxy->tinfo = tinfo;
memcpy(&proxy->iid,riid,sizeof(*riid)); memcpy(&proxy->iid,riid,sizeof(*riid));
proxy->chanbuf = 0; proxy->chanbuf = 0;
*ppv = (LPVOID)proxy; if (hres == S_OK)
*ppProxy = (IRpcProxyBuffer *)&(proxy->lpvtbl2); {
IUnknown_AddRef((IUnknown *)*ppv); *ppv = (LPVOID)proxy;
return S_OK; *ppProxy = (IRpcProxyBuffer *)&(proxy->lpvtbl2);
IUnknown_AddRef((IUnknown *)*ppv);
return S_OK;
}
else
TMProxyImpl_Release((IRpcProxyBuffer *)&proxy->lpvtbl2);
return hres;
} }
typedef struct _TMStubImpl { typedef struct _TMStubImpl {
@ -1585,6 +1588,7 @@ typedef struct _TMStubImpl {
LPUNKNOWN pUnk; LPUNKNOWN pUnk;
ITypeInfo *tinfo; ITypeInfo *tinfo;
IID iid; IID iid;
IRpcStubBuffer *dispatch_stub;
} TMStubImpl; } TMStubImpl;
static HRESULT WINAPI static HRESULT WINAPI
@ -1636,6 +1640,10 @@ TMStubImpl_Connect(LPRPCSTUBBUFFER iface, LPUNKNOWN pUnkServer)
IUnknown_AddRef(pUnkServer); IUnknown_AddRef(pUnkServer);
This->pUnk = pUnkServer; This->pUnk = pUnkServer;
if (This->dispatch_stub)
IRpcStubBuffer_Connect(This->dispatch_stub, pUnkServer);
return S_OK; return S_OK;
} }
@ -1651,6 +1659,9 @@ TMStubImpl_Disconnect(LPRPCSTUBBUFFER iface)
IUnknown_Release(This->pUnk); IUnknown_Release(This->pUnk);
This->pUnk = NULL; This->pUnk = NULL;
} }
if (This->dispatch_stub)
IRpcStubBuffer_Disconnect(This->dispatch_stub);
} }
static HRESULT WINAPI static HRESULT WINAPI
@ -1668,12 +1679,6 @@ TMStubImpl_Invoke(
BSTR iname = NULL; BSTR iname = NULL;
ITypeInfo *tinfo; ITypeInfo *tinfo;
memset(&buf,0,sizeof(buf));
buf.size = xmsg->cbBuffer;
buf.base = HeapAlloc(GetProcessHeap(), 0, xmsg->cbBuffer);
memcpy(buf.base, xmsg->Buffer, xmsg->cbBuffer);
buf.curoff = 0;
TRACE("...\n"); TRACE("...\n");
if (xmsg->iMethod < 3) { if (xmsg->iMethod < 3) {
@ -1681,6 +1686,15 @@ TMStubImpl_Invoke(
return E_UNEXPECTED; return E_UNEXPECTED;
} }
if (This->dispatch_stub && xmsg->iMethod < sizeof(IDispatchVtbl)/sizeof(void *))
return IRpcStubBuffer_Invoke(This->dispatch_stub, xmsg, rpcchanbuf);
memset(&buf,0,sizeof(buf));
buf.size = xmsg->cbBuffer;
buf.base = HeapAlloc(GetProcessHeap(), 0, xmsg->cbBuffer);
memcpy(buf.base, xmsg->Buffer, xmsg->cbBuffer);
buf.curoff = 0;
hres = _get_funcdesc(This->tinfo,xmsg->iMethod,&tinfo,&fdesc,&iname,NULL); hres = _get_funcdesc(This->tinfo,xmsg->iMethod,&tinfo,&fdesc,&iname,NULL);
if (hres) { if (hres) {
ERR("GetFuncDesc on method %ld failed with %lx\n",xmsg->iMethod,hres); ERR("GetFuncDesc on method %ld failed with %lx\n",xmsg->iMethod,hres);
@ -1839,25 +1853,48 @@ PSFacBuf_CreateStub(
HRESULT hres; HRESULT hres;
ITypeInfo *tinfo; ITypeInfo *tinfo;
TMStubImpl *stub; TMStubImpl *stub;
TYPEATTR *typeattr;
TRACE("(%s,%p,%p)\n",debugstr_guid(riid),pUnkServer,ppStub); TRACE("(%s,%p,%p)\n",debugstr_guid(riid),pUnkServer,ppStub);
hres = _get_typeinfo_for_iid(riid,&tinfo); hres = _get_typeinfo_for_iid(riid,&tinfo);
if (hres) { if (hres) {
ERR("No typeinfo for %s?\n",debugstr_guid(riid)); ERR("No typeinfo for %s?\n",debugstr_guid(riid));
return hres; return hres;
} }
stub = CoTaskMemAlloc(sizeof(TMStubImpl)); stub = CoTaskMemAlloc(sizeof(TMStubImpl));
if (!stub) if (!stub)
return E_OUTOFMEMORY; return E_OUTOFMEMORY;
stub->lpvtbl = &tmstubvtbl; stub->lpvtbl = &tmstubvtbl;
stub->ref = 1; stub->ref = 1;
stub->tinfo = tinfo; stub->tinfo = tinfo;
stub->dispatch_stub = NULL;
memcpy(&(stub->iid),riid,sizeof(*riid)); memcpy(&(stub->iid),riid,sizeof(*riid));
hres = IRpcStubBuffer_Connect((LPRPCSTUBBUFFER)stub,pUnkServer); hres = IRpcStubBuffer_Connect((LPRPCSTUBBUFFER)stub,pUnkServer);
*ppStub = (LPRPCSTUBBUFFER)stub; *ppStub = (LPRPCSTUBBUFFER)stub;
TRACE("IRpcStubBuffer: %p\n", stub); TRACE("IRpcStubBuffer: %p\n", stub);
if (hres) if (hres)
ERR("Connect to pUnkServer failed?\n"); ERR("Connect to pUnkServer failed?\n");
/* if we derive from IDispatch then defer to its stub for some of its methods */
hres = ITypeInfo_GetTypeAttr(tinfo, &typeattr);
if (hres == S_OK)
{
if (typeattr->wTypeFlags & TYPEFLAG_FDISPATCHABLE)
{
IPSFactoryBuffer *factory_buffer;
hres = get_facbuf_for_iid(&IID_IDispatch, &factory_buffer);
if (hres == S_OK)
{
hres = IPSFactoryBuffer_CreateStub(factory_buffer, &IID_IDispatch,
pUnkServer, &stub->dispatch_stub);
IPSFactoryBuffer_Release(factory_buffer);
}
}
ITypeInfo_ReleaseTypeAttr(tinfo, typeattr);
}
return hres; return hres;
} }