From 385e693e4422383d55de2fc8433abf577c3abab2 Mon Sep 17 00:00:00 2001 From: Robert Shearman Date: Fri, 27 Jan 2006 12:54:22 +0100 Subject: [PATCH] ole: Test and implement LPSAFEARRAY marshaling. --- dlls/oleaut32/oleaut32.spec | 8 +- dlls/oleaut32/tests/Makefile.in | 1 + dlls/oleaut32/tests/usrmarshal.c | 125 ++++++++++ dlls/oleaut32/usrmarshal.c | 412 +++++++++++++++++++++++++++++++ 4 files changed, 542 insertions(+), 4 deletions(-) create mode 100644 dlls/oleaut32/tests/usrmarshal.c diff --git a/dlls/oleaut32/oleaut32.spec b/dlls/oleaut32/oleaut32.spec index 2d4c3f65ba6..9da60ec1558 100644 --- a/dlls/oleaut32/oleaut32.spec +++ b/dlls/oleaut32/oleaut32.spec @@ -285,10 +285,10 @@ 288 stdcall VARIANT_UserMarshal(ptr ptr ptr) 289 stdcall VARIANT_UserUnmarshal(ptr ptr ptr) 290 stdcall VARIANT_UserFree(ptr ptr) -291 stub LPSAFEARRAY_UserSize -292 stub LPSAFEARRAY_UserMarshal -293 stub LPSAFEARRAY_UserUnmarshal -294 stub LPSAFEARRAY_UserFree +291 stdcall LPSAFEARRAY_UserSize(ptr long ptr) +292 stdcall LPSAFEARRAY_UserMarshal(ptr ptr ptr) +293 stdcall LPSAFEARRAY_UserUnmarshal(ptr ptr ptr) +294 stdcall LPSAFEARRAY_UserFree(ptr ptr) 295 stub LPSAFEARRAY_Size 296 stub LPSAFEARRAY_Marshal 297 stub LPSAFEARRAY_Unmarshal diff --git a/dlls/oleaut32/tests/Makefile.in b/dlls/oleaut32/tests/Makefile.in index dbd80667737..c0ad3261973 100644 --- a/dlls/oleaut32/tests/Makefile.in +++ b/dlls/oleaut32/tests/Makefile.in @@ -11,6 +11,7 @@ CTESTS = \ olepicture.c \ safearray.c \ typelib.c \ + usrmarshal.c \ vartest.c \ vartype.c diff --git a/dlls/oleaut32/tests/usrmarshal.c b/dlls/oleaut32/tests/usrmarshal.c new file mode 100644 index 00000000000..b5ca2fb3c08 --- /dev/null +++ b/dlls/oleaut32/tests/usrmarshal.c @@ -0,0 +1,125 @@ +/* + * Marshaling Tests + * + * Copyright 2004 Robert Shearman + * + * This library is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 2.1 of the License, or (at your option) any later version. + * + * This library is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with this library; if not, write to the Free Software + * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA + */ + +#include + +#include "windef.h" +#include "winbase.h" +#include "objbase.h" +#include "propidl.h" /* for LPSAFEARRAY_User* routines */ + +#include "wine/test.h" + +/* doesn't work on Windows due to needing more of the + * MIDL_STUB_MESSAGE structure to be filled out */ +#define LPSAFEARRAY_UNMARSHAL_WORKS 0 + +static void test_marshal_LPSAFEARRAY(void) +{ + unsigned char *buffer; + unsigned long size; + LPSAFEARRAY lpsa; + LPSAFEARRAY lpsa2 = NULL; + unsigned char *wiresa; + SAFEARRAYBOUND sab; + MIDL_STUB_MESSAGE stubMsg = { 0 }; + USER_MARSHAL_CB umcb = { 0 }; + + umcb.Flags = MAKELONG(MSHCTX_DIFFERENTMACHINE, NDR_LOCAL_DATA_REPRESENTATION); + umcb.pReserve = NULL; + umcb.pStubMsg = &stubMsg; + + sab.lLbound = 5; + sab.cElements = 10; + + lpsa = SafeArrayCreate(VT_I2, 1, &sab); + *(DWORD *)lpsa->pvData = 0xcafebabe; + + lpsa->cLocks = 7; + size = LPSAFEARRAY_UserSize(&umcb.Flags, 0, &lpsa); + ok(size == 64, "size should be 64 bytes, not %ld\n", size); + buffer = (unsigned char *)HeapAlloc(GetProcessHeap(), 0, size); + LPSAFEARRAY_UserMarshal(&umcb.Flags, buffer, &lpsa); + wiresa = buffer; + ok(*(DWORD *)wiresa == TRUE, "wiresa + 0x0 should be TRUE instead of 0x%08lx\n", *(DWORD *)wiresa); + wiresa += sizeof(DWORD); + ok(*(DWORD *)wiresa == lpsa->cDims, "wiresa + 0x4 should be lpsa->cDims instead of 0x%08lx\n", *(DWORD *)wiresa); + wiresa += sizeof(DWORD); + ok(*(WORD *)wiresa == lpsa->cDims, "wiresa + 0x8 should be lpsa->cDims instead of 0x%04x\n", *(WORD *)wiresa); + wiresa += sizeof(WORD); + ok(*(WORD *)wiresa == lpsa->fFeatures, "wiresa + 0xc should be lpsa->fFeatures instead of 0x%08x\n", *(WORD *)wiresa); + wiresa += sizeof(WORD); + ok(*(DWORD *)wiresa == lpsa->cbElements, "wiresa + 0x10 should be lpsa->cbElements instead of 0x%08lx\n", *(DWORD *)wiresa); + wiresa += sizeof(DWORD); + ok(*(WORD *)wiresa == lpsa->cLocks, "wiresa + 0x16 should be lpsa->cLocks instead of 0x%04x\n", *(WORD *)wiresa); + wiresa += sizeof(WORD); + ok(*(WORD *)wiresa == VT_I2, "wiresa + 0x14 should be VT_I2 instead of 0x%04x\n", *(WORD *)wiresa); + wiresa += sizeof(WORD); + ok(*(DWORD *)wiresa == VT_I2, "wiresa + 0x18 should be VT_I2 instead of 0x%08lx\n", *(DWORD *)wiresa); + wiresa += sizeof(DWORD); + ok(*(DWORD *)wiresa == sab.cElements, "wiresa + 0x1c should be sab.cElements instead of %lu\n", *(DWORD *)wiresa); + wiresa += sizeof(DWORD); + ok(*(DWORD_PTR *)wiresa == (DWORD_PTR)lpsa->pvData, "wirestgm + 0x20 should be lpsa->pvData instead of 0x%08lx\n", *(DWORD_PTR *)wiresa); + wiresa += sizeof(DWORD_PTR); + ok(*(DWORD *)wiresa == sab.cElements, "wiresa + 0x24 should be sab.cElements instead of %lu\n", *(DWORD *)wiresa); + wiresa += sizeof(DWORD); + ok(*(LONG *)wiresa == sab.lLbound, "wiresa + 0x28 should be sab.clLbound instead of %ld\n", *(LONG *)wiresa); + wiresa += sizeof(LONG); + ok(*(DWORD *)wiresa == sab.cElements, "wiresa + 0x2c should be sab.cElements instead of %lu\n", *(DWORD *)wiresa); + wiresa += sizeof(DWORD); + /* elements are now pointed to by wiresa */ + + if (LPSAFEARRAY_UNMARSHAL_WORKS) + { + LPSAFEARRAY_UserUnmarshal(&umcb.Flags, buffer, &lpsa2); + ok(lpsa2 != NULL, "LPSAFEARRAY didn't unmarshal\n"); + LPSAFEARRAY_UserFree(&umcb.Flags, &lpsa2); + } + HeapFree(GetProcessHeap(), 0, buffer); + SafeArrayDestroy(lpsa); + + /* test NULL safe array */ + lpsa = NULL; + + size = LPSAFEARRAY_UserSize(&umcb.Flags, 0, &lpsa); + ok(size == 4, "size should be 4 bytes, not %ld\n", size); + buffer = (unsigned char *)HeapAlloc(GetProcessHeap(), 0, size); + LPSAFEARRAY_UserMarshal(&umcb.Flags, buffer, &lpsa); + wiresa = buffer; + ok(*(DWORD *)wiresa == FALSE, "wiresa + 0x0 should be FALSE instead of 0x%08lx\n", *(DWORD *)wiresa); + wiresa += sizeof(DWORD); + + if (LPSAFEARRAY_UNMARSHAL_WORKS) + { + LPSAFEARRAY_UserUnmarshal(&umcb.Flags, buffer, &lpsa2); + ok(lpsa2 == NULL, "NULL LPSAFEARRAY didn't unmarshal\n"); + LPSAFEARRAY_UserFree(&umcb.Flags, &lpsa2); + } + HeapFree(GetProcessHeap(), 0, buffer); +} + +START_TEST(usrmarshal) +{ + CoInitialize(NULL); + + test_marshal_LPSAFEARRAY(); + + CoUninitialize(); +} diff --git a/dlls/oleaut32/usrmarshal.c b/dlls/oleaut32/usrmarshal.c index 9e942da8e15..242c3b1960a 100644 --- a/dlls/oleaut32/usrmarshal.c +++ b/dlls/oleaut32/usrmarshal.c @@ -62,6 +62,22 @@ HRESULT OLEAUTPS_DllGetClassObject(REFCLSID rclsid, REFIID riid, LPVOID *ppv) &CLSID_PSDispatch, &PSFactoryBuffer); } +static void dump_user_flags(unsigned long *pFlags) +{ + if (HIWORD(*pFlags) == NDR_LOCAL_DATA_REPRESENTATION) + TRACE("MAKELONG(NDR_LOCAL_REPRESENTATION, "); + else + TRACE("MAKELONG(0x%04x, ", HIWORD(*pFlags)); + switch (LOWORD(*pFlags)) + { + case MSHCTX_LOCAL: TRACE("MSHCTX_LOCAL)"); break; + case MSHCTX_NOSHAREDMEM: TRACE("MSHCTX_NOSHAREDMEM)"); break; + case MSHCTX_DIFFERENTMACHINE: TRACE("MSHCTX_DIFFERENTMACHINE)"); break; + case MSHCTX_INPROC: TRACE("MSHCTX_INPROC)"); break; + default: TRACE("%d)", LOWORD(*pFlags)); + } +} + /* CLEANLOCALSTORAGE */ /* I'm not sure how this is supposed to work yet */ @@ -516,6 +532,402 @@ void WINAPI VARIANT_UserFree(unsigned long *pFlags, VARIANT *pvar) CoTaskMemFree(ref); } +/* LPSAFEARRAY */ + +/* Get the number of cells in a SafeArray */ +static ULONG SAFEARRAY_GetCellCount(const SAFEARRAY *psa) +{ + const SAFEARRAYBOUND* psab = psa->rgsabound; + USHORT cCount = psa->cDims; + ULONG ulNumCells = 1; + + while (cCount--) + { + /* This is a valid bordercase. See testcases. -Marcus */ + if (!psab->cElements) + return 0; + ulNumCells *= psab->cElements; + psab++; + } + return ulNumCells; +} + +static inline SF_TYPE SAFEARRAY_GetUnionType(SAFEARRAY *psa) +{ + VARTYPE vt; + HRESULT hr; + + hr = SafeArrayGetVartype(psa, &vt); + if (FAILED(hr)) + RpcRaiseException(hr); + + if (psa->fFeatures & FADF_HAVEIID) + return SF_HAVEIID; + + switch (vt) + { + case VT_I1: + case VT_UI1: return SF_I1; + case VT_BOOL: + case VT_I2: + case VT_UI2: return SF_I2; + case VT_INT: + case VT_UINT: + case VT_I4: + case VT_UI4: + case VT_R4: return SF_I4; + case VT_DATE: + case VT_CY: + case VT_R8: + case VT_I8: + case VT_UI8: return SF_I8; + case VT_INT_PTR: + case VT_UINT_PTR: return (sizeof(UINT_PTR) == 4 ? SF_I4 : SF_I8); + case VT_BSTR: return SF_BSTR; + case VT_DISPATCH: return SF_DISPATCH; + case VT_VARIANT: return SF_VARIANT; + case VT_UNKNOWN: return SF_UNKNOWN; + /* Note: Return a non-zero size to indicate vt is valid. The actual size + * of a UDT is taken from the result of IRecordInfo_GetSize(). + */ + case VT_RECORD: return SF_RECORD; + default: return SF_ERROR; + } +} + +unsigned long WINAPI LPSAFEARRAY_UserSize(unsigned long *pFlags, unsigned long StartingSize, LPSAFEARRAY *ppsa) +{ + unsigned long size = StartingSize; + + TRACE("("); dump_user_flags(pFlags); TRACE(", %ld, %p\n", StartingSize, *ppsa); + + size += sizeof(ULONG_PTR); + if (*ppsa) + { + SAFEARRAY *psa = *ppsa; + ULONG ulCellCount = SAFEARRAY_GetCellCount(psa); + SF_TYPE sftype; + HRESULT hr; + + size += sizeof(ULONG); + size += FIELD_OFFSET(struct _wireSAFEARRAY, uArrayStructs); + + sftype = SAFEARRAY_GetUnionType(psa); + size += sizeof(ULONG); + + size += sizeof(ULONG); + size += sizeof(ULONG_PTR); + if (sftype == SF_HAVEIID) + size += sizeof(IID); + + size += sizeof(psa->rgsabound[0]) * psa->cDims; + + size += sizeof(ULONG); + + switch (sftype) + { + case SF_BSTR: + { + BSTR* lpBstr; + + for (lpBstr = (BSTR*)psa->pvData; ulCellCount; ulCellCount--, lpBstr++) + size = BSTR_UserSize(pFlags, size, lpBstr); + + break; + } + case SF_DISPATCH: + case SF_UNKNOWN: + case SF_HAVEIID: + FIXME("size interfaces\n"); + break; + case SF_VARIANT: + { + VARIANT* lpVariant; + + for (lpVariant = (VARIANT*)psa->pvData; ulCellCount; ulCellCount--, lpVariant++) + size = VARIANT_UserSize(pFlags, size, lpVariant); + + break; + } + case SF_RECORD: + { + IRecordInfo* pRecInfo = NULL; + + hr = SafeArrayGetRecordInfo(psa, &pRecInfo); + if (FAILED(hr)) + RpcRaiseException(hr); + + if (pRecInfo) + { + FIXME("size record info %p\n", pRecInfo); + + IRecordInfo_Release(pRecInfo); + } + break; + } + case SF_I1: + case SF_I2: + case SF_I4: + case SF_I8: + size += ulCellCount * psa->cbElements; + break; + default: + break; + } + + } + + return size; +} + +unsigned char * WINAPI LPSAFEARRAY_UserMarshal(unsigned long *pFlags, unsigned char *Buffer, LPSAFEARRAY *ppsa) +{ + HRESULT hr; + + TRACE("("); dump_user_flags(pFlags); TRACE(", %p, &%p\n", Buffer, *ppsa); + + *(ULONG_PTR *)Buffer = *ppsa ? TRUE : FALSE; + Buffer += sizeof(ULONG_PTR); + if (*ppsa) + { + VARTYPE vt; + SAFEARRAY *psa = *ppsa; + ULONG ulCellCount = SAFEARRAY_GetCellCount(psa); + wireSAFEARRAY wiresa; + SF_TYPE sftype; + GUID guid; + + *(ULONG *)Buffer = psa->cDims; + Buffer += sizeof(ULONG); + wiresa = (wireSAFEARRAY)Buffer; + wiresa->cDims = psa->cDims; + wiresa->fFeatures = psa->fFeatures; + wiresa->cbElements = psa->cbElements; + + hr = SafeArrayGetVartype(psa, &vt); + if (FAILED(hr)) + RpcRaiseException(hr); + wiresa->cLocks = (USHORT)psa->cLocks | (vt << 16); + + Buffer += FIELD_OFFSET(struct _wireSAFEARRAY, uArrayStructs); + + sftype = SAFEARRAY_GetUnionType(psa); + *(ULONG *)Buffer = sftype; + Buffer += sizeof(ULONG); + + *(ULONG *)Buffer = ulCellCount; + Buffer += sizeof(ULONG); + *(ULONG_PTR *)Buffer = (ULONG_PTR)psa->pvData; + Buffer += sizeof(ULONG_PTR); + if (sftype == SF_HAVEIID) + { + SafeArrayGetIID(psa, &guid); + memcpy(Buffer, &guid, sizeof(guid)); + Buffer += sizeof(guid); + } + + memcpy(Buffer, psa->rgsabound, sizeof(psa->rgsabound[0]) * psa->cDims); + Buffer += sizeof(psa->rgsabound[0]) * psa->cDims; + + *(ULONG *)Buffer = ulCellCount; + Buffer += sizeof(ULONG); + + if (psa->pvData) + { + switch (sftype) + { + case SF_BSTR: + { + BSTR* lpBstr; + + for (lpBstr = (BSTR*)psa->pvData; ulCellCount; ulCellCount--, lpBstr++) + Buffer = BSTR_UserMarshal(pFlags, Buffer, lpBstr); + + break; + } + case SF_DISPATCH: + case SF_UNKNOWN: + case SF_HAVEIID: + FIXME("marshal interfaces\n"); + break; + case SF_VARIANT: + { + VARIANT* lpVariant; + + for (lpVariant = (VARIANT*)psa->pvData; ulCellCount; ulCellCount--, lpVariant++) + Buffer = VARIANT_UserMarshal(pFlags, Buffer, lpVariant); + + break; + } + case SF_RECORD: + { + IRecordInfo* pRecInfo = NULL; + + hr = SafeArrayGetRecordInfo(psa, &pRecInfo); + if (FAILED(hr)) + RpcRaiseException(hr); + + if (pRecInfo) + { + FIXME("write record info %p\n", pRecInfo); + + IRecordInfo_Release(pRecInfo); + } + break; + } + case SF_I1: + case SF_I2: + case SF_I4: + case SF_I8: + /* Just copy the data over */ + memcpy(Buffer, psa->pvData, ulCellCount * psa->cbElements); + Buffer += ulCellCount * psa->cbElements; + break; + default: + break; + } + } + + } + return Buffer; +} + +#define FADF_AUTOSETFLAGS (FADF_HAVEIID | FADF_RECORD | FADF_HAVEVARTYPE | \ + FADF_BSTR | FADF_UNKNOWN | FADF_DISPATCH | \ + FADF_VARIANT | FADF_CREATEVECTOR) + +unsigned char * WINAPI LPSAFEARRAY_UserUnmarshal(unsigned long *pFlags, unsigned char *Buffer, LPSAFEARRAY *ppsa) +{ + ULONG_PTR ptr; + wireSAFEARRAY wiresa; + ULONG cDims; + HRESULT hr; + SF_TYPE sftype; + ULONG cell_count; + GUID guid; + VARTYPE vt; + SAFEARRAYBOUND *wiresab; + + TRACE("("); dump_user_flags(pFlags); TRACE(", %p, %p\n", Buffer, ppsa); + + ptr = *(ULONG_PTR *)Buffer; + Buffer += sizeof(ULONG_PTR); + + if (!ptr) + { + *ppsa = NULL; + + TRACE("NULL safe array unmarshaled\n"); + + return Buffer; + } + + cDims = *(ULONG *)Buffer; + Buffer += sizeof(ULONG); + + wiresa = (wireSAFEARRAY)Buffer; + Buffer += FIELD_OFFSET(struct _wireSAFEARRAY, uArrayStructs); + + if (cDims != wiresa->cDims) + RpcRaiseException(RPC_S_INVALID_BOUND); + + /* FIXME: there should be a limit on how large cDims can be */ + + vt = HIWORD(wiresa->cLocks); + + sftype = *(ULONG *)Buffer; + Buffer += sizeof(ULONG); + + cell_count = *(ULONG *)Buffer; + Buffer += sizeof(ULONG); + ptr = *(ULONG_PTR *)Buffer; + Buffer += sizeof(ULONG_PTR); + if (sftype == SF_HAVEIID) + { + memcpy(&guid, Buffer, sizeof(guid)); + Buffer += sizeof(guid); + } + + wiresab = (SAFEARRAYBOUND *)Buffer; + Buffer += sizeof(wiresab[0]) * wiresa->cDims; + + *ppsa = SafeArrayCreateEx(vt, wiresa->cDims, wiresab, NULL); + if (!ppsa) + RpcRaiseException(E_OUTOFMEMORY); + + /* be careful about which flags we set since they could be a security + * risk */ + (*ppsa)->fFeatures = wiresa->fFeatures & ~(FADF_AUTOSETFLAGS); + /* FIXME: there should be a limit on how large wiresa->cbElements can be */ + (*ppsa)->cbElements = wiresa->cbElements; + (*ppsa)->cLocks = LOWORD(wiresa->cLocks); + + hr = SafeArrayAllocData(*ppsa); + if (FAILED(hr)) + RpcRaiseException(hr); + + if ((*(ULONG *)Buffer != cell_count) || (SAFEARRAY_GetCellCount(*ppsa) != cell_count)) + RpcRaiseException(RPC_S_INVALID_BOUND); + Buffer += sizeof(ULONG); + + if (ptr) + { + switch (sftype) + { + case SF_BSTR: + { + BSTR* lpBstr; + + for (lpBstr = (BSTR*)(*ppsa)->pvData; cell_count; cell_count--, lpBstr++) + Buffer = BSTR_UserUnmarshal(pFlags, Buffer, lpBstr); + + break; + } + case SF_DISPATCH: + case SF_UNKNOWN: + case SF_HAVEIID: + FIXME("marshal interfaces\n"); + break; + case SF_VARIANT: + { + VARIANT* lpVariant; + + for (lpVariant = (VARIANT*)(*ppsa)->pvData; cell_count; cell_count--, lpVariant++) + Buffer = VARIANT_UserUnmarshal(pFlags, Buffer, lpVariant); + + break; + } + case SF_RECORD: + { + FIXME("set record info\n"); + + break; + } + case SF_I1: + case SF_I2: + case SF_I4: + case SF_I8: + /* Just copy the data over */ + memcpy((*ppsa)->pvData, Buffer, cell_count * (*ppsa)->cbElements); + Buffer += cell_count * (*ppsa)->cbElements; + break; + default: + break; + } + } + + TRACE("safe array unmarshaled: %p\n", *ppsa); + + return Buffer; +} + +void WINAPI LPSAFEARRAY_UserFree(unsigned long *pFlags, LPSAFEARRAY *ppsa) +{ + TRACE("("); dump_user_flags(pFlags); TRACE(", &%p\n", *ppsa); + + SafeArrayDestroy(*ppsa); +} + /* IDispatch */ /* exactly how Invoke is marshalled is not very clear to me yet, * but the way I've done it seems to work for me */