diff --git a/dlls/oleaut32/tests/usrmarshal.c b/dlls/oleaut32/tests/usrmarshal.c index e3f0065f38d..9bc282d0fa2 100644 --- a/dlls/oleaut32/tests/usrmarshal.c +++ b/dlls/oleaut32/tests/usrmarshal.c @@ -46,7 +46,16 @@ static inline SF_TYPE get_union_type(SAFEARRAY *psa) hr = SafeArrayGetVartype(psa, &vt); if (FAILED(hr)) - return 0; + { + switch(psa->cbElements) + { + case 1: vt = VT_I1; break; + case 2: vt = VT_I2; break; + case 4: vt = VT_I4; break; + case 8: vt = VT_I8; break; + default: return 0; + } + } if (psa->fFeatures & FADF_HAVEIID) return SF_HAVEIID; @@ -111,7 +120,9 @@ static void check_safearray(void *buffer, LPSAFEARRAY lpsa) return; } - SafeArrayGetVartype(lpsa, &vt); + if(FAILED(SafeArrayGetVartype(lpsa, &vt))) + vt = 0; + sftype = get_union_type(lpsa); cell_count = get_cell_count(lpsa); @@ -159,6 +170,8 @@ static void test_marshal_LPSAFEARRAY(void) SAFEARRAYBOUND sab; MIDL_STUB_MESSAGE stubMsg = { 0 }; USER_MARSHAL_CB umcb = { 0 }; + HRESULT hr; + VARTYPE vt; umcb.Flags = MAKELONG(MSHCTX_DIFFERENTMACHINE, NDR_LOCAL_DATA_REPRESENTATION); umcb.pReserve = NULL; @@ -228,6 +241,27 @@ static void test_marshal_LPSAFEARRAY(void) HeapFree(GetProcessHeap(), 0, buffer); SafeArrayDestroy(lpsa); + + /* VARTYPE-less arrays can be marshaled if cbElements is 1,2,4 or 8 as type SF_In */ + hr = SafeArrayAllocDescriptor(1, &lpsa); + ok(hr == S_OK, "saad failed %08x\n", hr); + lpsa->cbElements = 8; + lpsa->rgsabound[0].lLbound = 2; + lpsa->rgsabound[0].cElements = 48; + hr = SafeArrayAllocData(lpsa); + ok(hr == S_OK, "saad failed %08x\n", hr); + + hr = SafeArrayGetVartype(lpsa, &vt); + ok(hr == E_INVALIDARG, "ret %08x\n", hr); + + size = LPSAFEARRAY_UserSize(&umcb.Flags, 0, &lpsa); + ok(size == 432, "size %ld\n", size); + buffer = (unsigned char *)HeapAlloc(GetProcessHeap(), 0, size); + LPSAFEARRAY_UserMarshal(&umcb.Flags, buffer, &lpsa); + check_safearray(buffer, lpsa); + HeapFree(GetProcessHeap(), 0, buffer); + SafeArrayDestroyData(lpsa); + SafeArrayDestroyDescriptor(lpsa); } static void check_bstr(void *buffer, BSTR b) diff --git a/dlls/oleaut32/usrmarshal.c b/dlls/oleaut32/usrmarshal.c index 02a3ed9fa8e..71d7d6b1593 100644 --- a/dlls/oleaut32/usrmarshal.c +++ b/dlls/oleaut32/usrmarshal.c @@ -696,7 +696,17 @@ static inline SF_TYPE SAFEARRAY_GetUnionType(SAFEARRAY *psa) hr = SafeArrayGetVartype(psa, &vt); if (FAILED(hr)) - RpcRaiseException(hr); + { + switch(psa->cbElements) + { + case 1: vt = VT_I1; break; + case 2: vt = VT_I2; break; + case 4: vt = VT_I4; break; + case 8: vt = VT_I8; break; + default: + RpcRaiseException(hr); + } + } if (psa->fFeatures & FADF_HAVEIID) return SF_HAVEIID; @@ -846,8 +856,8 @@ unsigned char * WINAPI LPSAFEARRAY_UserMarshal(ULONG *pFlags, unsigned char *Buf wiresa->cbElements = psa->cbElements; hr = SafeArrayGetVartype(psa, &vt); - if (FAILED(hr)) - RpcRaiseException(hr); + if (FAILED(hr)) vt = 0; + wiresa->cLocks = (USHORT)psa->cLocks | (vt << 16); Buffer += FIELD_OFFSET(struct _wireSAFEARRAY, uArrayStructs); @@ -996,7 +1006,14 @@ unsigned char * WINAPI LPSAFEARRAY_UserUnmarshal(ULONG *pFlags, unsigned char *B wiresab = (SAFEARRAYBOUND *)Buffer; Buffer += sizeof(wiresab[0]) * wiresa->cDims; - *ppsa = SafeArrayCreateEx(vt, wiresa->cDims, wiresab, NULL); + if(vt) + *ppsa = SafeArrayCreateEx(vt, wiresa->cDims, wiresab, NULL); + else + { + SafeArrayAllocDescriptor(wiresa->cDims, ppsa); + if(*ppsa) + memcpy((*ppsa)->rgsabound, wiresab, sizeof(SAFEARRAYBOUND) * wiresa->cDims); + } if (!*ppsa) RpcRaiseException(E_OUTOFMEMORY);