From 2e8a74d520311e60d70f40fd0af69a57c6d00ab6 Mon Sep 17 00:00:00 2001 From: Huw Davies Date: Sat, 29 Apr 2006 14:40:28 +0100 Subject: [PATCH] oleaut32: Fix BSTR marshaling to be wire compatible with Windows. --- dlls/oleaut32/tests/usrmarshal.c | 68 +++++++++++++++++++++++++++ dlls/oleaut32/usrmarshal.c | 81 ++++++++++++++++++++------------ 2 files changed, 120 insertions(+), 29 deletions(-) diff --git a/dlls/oleaut32/tests/usrmarshal.c b/dlls/oleaut32/tests/usrmarshal.c index b5ca2fb3c08..5d793159f14 100644 --- a/dlls/oleaut32/tests/usrmarshal.c +++ b/dlls/oleaut32/tests/usrmarshal.c @@ -30,6 +30,7 @@ /* doesn't work on Windows due to needing more of the * MIDL_STUB_MESSAGE structure to be filled out */ #define LPSAFEARRAY_UNMARSHAL_WORKS 0 +#define BSTR_UNMARSHAL_WORKS 0 static void test_marshal_LPSAFEARRAY(void) { @@ -115,11 +116,78 @@ static void test_marshal_LPSAFEARRAY(void) HeapFree(GetProcessHeap(), 0, buffer); } +static void test_marshal_BSTR(void) +{ + unsigned long size; + MIDL_STUB_MESSAGE stubMsg = { 0 }; + USER_MARSHAL_CB umcb = { 0 }; + unsigned char *buffer; + BSTR b, b2; + WCHAR str[] = {'m','a','r','s','h','a','l',' ','t','e','s','t','1',0}; + DWORD *wireb, len; + + umcb.Flags = MAKELONG(MSHCTX_DIFFERENTMACHINE, NDR_LOCAL_DATA_REPRESENTATION); + umcb.pReserve = NULL; + umcb.pStubMsg = &stubMsg; + + b = SysAllocString(str); + len = SysStringLen(b); + ok(len == 13, "get %ld\n", len); + + /* BSTRs are DWORD aligned */ + size = BSTR_UserSize(&umcb.Flags, 1, &b); + ok(size == 42, "size %ld\n", size); + + size = BSTR_UserSize(&umcb.Flags, 0, &b); + ok(size == 38, "size %ld\n", size); + + buffer = HeapAlloc(GetProcessHeap(), 0, size); + BSTR_UserMarshal(&umcb.Flags, buffer, &b); + wireb = (DWORD*)buffer; + + ok(*wireb == len, "wv[0] %08lx\n", *wireb); + wireb++; + ok(*wireb == len * 2, "wv[1] %08lx\n", *wireb); + wireb++; + ok(*wireb == len, "wv[2] %08lx\n", *wireb); + wireb++; + ok(!memcmp(wireb, str, len * 2), "strings differ\n"); + + if (BSTR_UNMARSHAL_WORKS) + { + b2 = NULL; + BSTR_UserUnmarshal(&umcb.Flags, buffer, &b2); + ok(b2 != NULL, "NULL LPSAFEARRAY didn't unmarshal\n"); + ok(!memcmp(b, b2, (len + 1) * 2), "strings differ\n"); + BSTR_UserFree(&umcb.Flags, &b2); + } + + HeapFree(GetProcessHeap(), 0, buffer); + SysFreeString(b); + + b = NULL; + size = BSTR_UserSize(&umcb.Flags, 0, &b); + ok(size == 12, "size %ld\n", size); + + buffer = HeapAlloc(GetProcessHeap(), 0, size); + BSTR_UserMarshal(&umcb.Flags, buffer, &b); + wireb = (DWORD*)buffer; + ok(*wireb == 0, "wv[0] %08lx\n", *wireb); + wireb++; + ok(*wireb == 0xffffffff, "wv[1] %08lx\n", *wireb); + wireb++; + ok(*wireb == 0, "wv[2] %08lx\n", *wireb); + + HeapFree(GetProcessHeap(), 0, buffer); + +} + START_TEST(usrmarshal) { CoInitialize(NULL); test_marshal_LPSAFEARRAY(); + test_marshal_BSTR(); CoUninitialize(); } diff --git a/dlls/oleaut32/usrmarshal.c b/dlls/oleaut32/usrmarshal.c index 9140c9ee9e5..05aef80450b 100644 --- a/dlls/oleaut32/usrmarshal.c +++ b/dlls/oleaut32/usrmarshal.c @@ -140,50 +140,73 @@ void WINAPI CLEANLOCALSTORAGE_UserFree(unsigned long *pFlags, CLEANLOCALSTORAGE /* BSTR */ +typedef struct +{ + DWORD len; /* No. of chars not including trailing '\0' */ + DWORD byte_len; /* len * 2 or 0xffffffff if len == 0 */ + DWORD len2; /* == len */ +} bstr_wire_t; + unsigned long WINAPI BSTR_UserSize(unsigned long *pFlags, unsigned long Start, BSTR *pstr) { - TRACE("(%lx,%ld,%p) => %p\n", *pFlags, Start, pstr, *pstr); - if (*pstr) TRACE("string=%s\n", debugstr_w(*pstr)); - Start += sizeof(FLAGGED_WORD_BLOB) + sizeof(OLECHAR) * (SysStringLen(*pstr) - 1); - TRACE("returning %ld\n", Start); - return Start; + TRACE("(%lx,%ld,%p) => %p\n", *pFlags, Start, pstr, *pstr); + if (*pstr) TRACE("string=%s\n", debugstr_w(*pstr)); + ALIGN_LENGTH(Start, 3); + Start += sizeof(bstr_wire_t) + sizeof(OLECHAR) * (SysStringLen(*pstr)); + TRACE("returning %ld\n", Start); + return Start; } unsigned char * WINAPI BSTR_UserMarshal(unsigned long *pFlags, unsigned char *Buffer, BSTR *pstr) { - wireBSTR str = (wireBSTR)Buffer; + bstr_wire_t *header; + TRACE("(%lx,%p,%p) => %p\n", *pFlags, Buffer, pstr, *pstr); + if (*pstr) TRACE("string=%s\n", debugstr_w(*pstr)); - TRACE("(%lx,%p,%p) => %p\n", *pFlags, Buffer, pstr, *pstr); - if (*pstr) TRACE("string=%s\n", debugstr_w(*pstr)); - str->fFlags = 0; - str->clSize = SysStringLen(*pstr); - if (str->clSize) - memcpy(&str->asData, *pstr, sizeof(OLECHAR) * str->clSize); - return Buffer + sizeof(FLAGGED_WORD_BLOB) + sizeof(OLECHAR) * (str->clSize - 1); + ALIGN_POINTER(Buffer, 3); + header = (bstr_wire_t*)Buffer; + header->len = header->len2 = SysStringLen(*pstr); + if (header->len) + { + header->byte_len = header->len * sizeof(OLECHAR); + memcpy(header + 1, *pstr, header->byte_len); + } + else + header->byte_len = 0xffffffff; /* special case for an empty string */ + + return Buffer + sizeof(*header) + sizeof(OLECHAR) * header->len; } unsigned char * WINAPI BSTR_UserUnmarshal(unsigned long *pFlags, unsigned char *Buffer, BSTR *pstr) { - wireBSTR str = (wireBSTR)Buffer; - TRACE("(%lx,%p,%p) => %p\n", *pFlags, Buffer, pstr, *pstr); - if (str->clSize) { - SysReAllocStringLen(pstr, (OLECHAR*)&str->asData, str->clSize); - } - else if (*pstr) { - SysFreeString(*pstr); - *pstr = NULL; - } - if (*pstr) TRACE("string=%s\n", debugstr_w(*pstr)); - return Buffer + sizeof(FLAGGED_WORD_BLOB) + sizeof(OLECHAR) * (str->clSize - 1); + bstr_wire_t *header; + TRACE("(%lx,%p,%p) => %p\n", *pFlags, Buffer, pstr, *pstr); + + ALIGN_POINTER(Buffer, 3); + header = (bstr_wire_t*)Buffer; + if(header->len != header->len2) + FIXME("len %08lx != len2 %08lx\n", header->len, header->len2); + + if(header->len) + SysReAllocStringLen(pstr, (OLECHAR*)(header + 1), header->len); + else if (*pstr) + { + SysFreeString(*pstr); + *pstr = NULL; + } + + if (*pstr) TRACE("string=%s\n", debugstr_w(*pstr)); + return Buffer + sizeof(*header) + sizeof(OLECHAR) * header->len; } void WINAPI BSTR_UserFree(unsigned long *pFlags, BSTR *pstr) { - TRACE("(%lx,%p) => %p\n", *pFlags, pstr, *pstr); - if (*pstr) { - SysFreeString(*pstr); - *pstr = NULL; - } + TRACE("(%lx,%p) => %p\n", *pFlags, pstr, *pstr); + if (*pstr) + { + SysFreeString(*pstr); + *pstr = NULL; + } } /* VARIANT */