diff --git a/dlls/oleaut32/tests/usrmarshal.c b/dlls/oleaut32/tests/usrmarshal.c index b0f3808ebc0..4433f503bd3 100644 --- a/dlls/oleaut32/tests/usrmarshal.c +++ b/dlls/oleaut32/tests/usrmarshal.c @@ -227,20 +227,20 @@ static void test_marshal_LPSAFEARRAY(void) static void check_bstr(void *buffer, BSTR b) { DWORD *wireb = buffer; - DWORD len = SysStringLen(b); + DWORD len = SysStringByteLen(b); - ok(*wireb == len, "wv[0] %08lx\n", *wireb); + ok(*wireb == (len + 1) / 2, "wv[0] %08lx\n", *wireb); wireb++; - if(len) - ok(*wireb == len * 2, "wv[1] %08lx\n", *wireb); + if(b) + ok(*wireb == len, "wv[1] %08lx\n", *wireb); else ok(*wireb == 0xffffffff, "wv[1] %08lx\n", *wireb); wireb++; - ok(*wireb == len, "wv[2] %08lx\n", *wireb); + ok(*wireb == (len + 1) / 2, "wv[2] %08lx\n", *wireb); if(len) { wireb++; - ok(!memcmp(wireb, b, len * 2), "strings differ\n"); + ok(!memcmp(wireb, b, (len + 1) & ~1), "strings differ\n"); } return; } @@ -250,7 +250,7 @@ static void test_marshal_BSTR(void) unsigned long size; MIDL_STUB_MESSAGE stubMsg = { 0 }; USER_MARSHAL_CB umcb = { 0 }; - unsigned char *buffer; + unsigned char *buffer, *next; BSTR b, b2; WCHAR str[] = {'m','a','r','s','h','a','l',' ','t','e','s','t','1',0}; DWORD len; @@ -271,14 +271,16 @@ static void test_marshal_BSTR(void) ok(size == 38, "size %ld\n", size); buffer = HeapAlloc(GetProcessHeap(), 0, size); - BSTR_UserMarshal(&umcb.Flags, buffer, &b); + next = BSTR_UserMarshal(&umcb.Flags, buffer, &b); + ok(next == buffer + size, "got %p expect %p\n", next, buffer + size); check_bstr(buffer, b); if (BSTR_UNMARSHAL_WORKS) { b2 = NULL; - BSTR_UserUnmarshal(&umcb.Flags, buffer, &b2); - ok(b2 != NULL, "NULL LPSAFEARRAY didn't unmarshal\n"); + next = BSTR_UserUnmarshal(&umcb.Flags, buffer, &b2); + ok(next == buffer + size, "got %p expect %p\n", next, buffer + size); + ok(b2 != NULL, "BSTR didn't unmarshal\n"); ok(!memcmp(b, b2, (len + 1) * 2), "strings differ\n"); BSTR_UserFree(&umcb.Flags, &b2); } @@ -291,11 +293,75 @@ static void test_marshal_BSTR(void) ok(size == 12, "size %ld\n", size); buffer = HeapAlloc(GetProcessHeap(), 0, size); - BSTR_UserMarshal(&umcb.Flags, buffer, &b); + next = BSTR_UserMarshal(&umcb.Flags, buffer, &b); + ok(next == buffer + size, "got %p expect %p\n", next, buffer + size); check_bstr(buffer, b); + if (BSTR_UNMARSHAL_WORKS) + { + b2 = NULL; + next = BSTR_UserUnmarshal(&umcb.Flags, buffer, &b2); + ok(next == buffer + size, "got %p expect %p\n", next, buffer + size); + ok(b2 == NULL, "NULL BSTR didn't unmarshal\n"); + BSTR_UserFree(&umcb.Flags, &b2); + } HeapFree(GetProcessHeap(), 0, buffer); + b = SysAllocStringByteLen("abc", 3); + *(((char*)b) + 3) = 'd'; + len = SysStringLen(b); + ok(len == 1, "get %ld\n", len); + len = SysStringByteLen(b); + ok(len == 3, "get %ld\n", len); + + size = BSTR_UserSize(&umcb.Flags, 0, &b); + ok(size == 16, "size %ld\n", size); + + buffer = HeapAlloc(GetProcessHeap(), 0, size); + memset(buffer, 0xcc, size); + next = BSTR_UserMarshal(&umcb.Flags, buffer, &b); + ok(next == buffer + size, "got %p expect %p\n", next, buffer + size); + check_bstr(buffer, b); + ok(buffer[15] == 'd', "buffer[15] %02x\n", buffer[15]); + + if (BSTR_UNMARSHAL_WORKS) + { + b2 = NULL; + next = BSTR_UserUnmarshal(&umcb.Flags, buffer, &b2); + ok(next == buffer + size, "got %p expect %p\n", next, buffer + size); + ok(b2 != NULL, "BSTR didn't unmarshal\n"); + ok(!memcmp(b, b2, len), "strings differ\n"); + BSTR_UserFree(&umcb.Flags, &b2); + } + HeapFree(GetProcessHeap(), 0, buffer); + SysFreeString(b); + + b = SysAllocStringByteLen("", 0); + len = SysStringLen(b); + ok(len == 0, "get %ld\n", len); + len = SysStringByteLen(b); + ok(len == 0, "get %ld\n", len); + + size = BSTR_UserSize(&umcb.Flags, 0, &b); + ok(size == 12, "size %ld\n", size); + + buffer = HeapAlloc(GetProcessHeap(), 0, size); + next = BSTR_UserMarshal(&umcb.Flags, buffer, &b); + ok(next == buffer + size, "got %p expect %p\n", next, buffer + size); + check_bstr(buffer, b); + + if (BSTR_UNMARSHAL_WORKS) + { + b2 = NULL; + next = BSTR_UserUnmarshal(&umcb.Flags, buffer, &b2); + ok(next == buffer + size, "got %p expect %p\n", next, buffer + size); + ok(b2 != NULL, "NULL LPSAFEARRAY didn't unmarshal\n"); + len = SysStringByteLen(b2); + ok(len == 0, "byte len %ld\n", len); + BSTR_UserFree(&umcb.Flags, &b2); + } + HeapFree(GetProcessHeap(), 0, buffer); + SysFreeString(b); } static void check_variant_header(DWORD *wirev, VARIANT *v, unsigned long size) diff --git a/dlls/oleaut32/usrmarshal.c b/dlls/oleaut32/usrmarshal.c index 002df909857..dd397c0967a 100644 --- a/dlls/oleaut32/usrmarshal.c +++ b/dlls/oleaut32/usrmarshal.c @@ -152,7 +152,7 @@ unsigned long WINAPI BSTR_UserSize(unsigned long *pFlags, unsigned long Start, B TRACE("(%lx,%ld,%p) => %p\n", *pFlags, Start, pstr, *pstr); if (*pstr) TRACE("string=%s\n", debugstr_w(*pstr)); ALIGN_LENGTH(Start, 3); - Start += sizeof(bstr_wire_t) + sizeof(OLECHAR) * (SysStringLen(*pstr)); + Start += sizeof(bstr_wire_t) + ((SysStringByteLen(*pstr) + 1) & ~1); TRACE("returning %ld\n", Start); return Start; } @@ -160,19 +160,21 @@ unsigned long WINAPI BSTR_UserSize(unsigned long *pFlags, unsigned long Start, B unsigned char * WINAPI BSTR_UserMarshal(unsigned long *pFlags, unsigned char *Buffer, BSTR *pstr) { bstr_wire_t *header; + DWORD len = SysStringByteLen(*pstr); + TRACE("(%lx,%p,%p) => %p\n", *pFlags, Buffer, pstr, *pstr); if (*pstr) TRACE("string=%s\n", debugstr_w(*pstr)); ALIGN_POINTER(Buffer, 3); header = (bstr_wire_t*)Buffer; - header->len = header->len2 = SysStringLen(*pstr); - if (header->len) + header->len = header->len2 = (len + 1) / 2; + if (*pstr) { - header->byte_len = header->len * sizeof(OLECHAR); - memcpy(header + 1, *pstr, header->byte_len); + header->byte_len = len; + memcpy(header + 1, *pstr, header->len * 2); } else - header->byte_len = 0xffffffff; /* special case for an empty string */ + header->byte_len = 0xffffffff; /* special case for a null bstr */ return Buffer + sizeof(*header) + sizeof(OLECHAR) * header->len; } @@ -187,14 +189,15 @@ unsigned char * WINAPI BSTR_UserUnmarshal(unsigned long *pFlags, unsigned char * if(header->len != header->len2) FIXME("len %08lx != len2 %08lx\n", header->len, header->len2); - if(header->len) - SysReAllocStringLen(pstr, (OLECHAR*)(header + 1), header->len); - else if (*pstr) + if(*pstr) { SysFreeString(*pstr); *pstr = NULL; } + if(header->byte_len != 0xffffffff) + *pstr = SysAllocStringByteLen((char*)(header + 1), header->byte_len); + if (*pstr) TRACE("string=%s\n", debugstr_w(*pstr)); return Buffer + sizeof(*header) + sizeof(OLECHAR) * header->len; }