From 6e95bfe85a0b7c1f0591ea436c391ef930a8d7dd Mon Sep 17 00:00:00 2001 From: Robert Shearman Date: Sat, 10 Jun 2006 12:31:45 +0100 Subject: [PATCH] rpcrt4: Introduce a new function, safe_multiply, which will raise an exception if a multiply overflows a 4-byte integer. This will protect the unmarshaling code against attacks specifying a large variance. Use this new function in the conformant string functions to harden them against attack. --- dlls/rpcrt4/ndr_marshall.c | 48 ++++++++++++++++++++++++++++---------- 1 file changed, 36 insertions(+), 12 deletions(-) diff --git a/dlls/rpcrt4/ndr_marshall.c b/dlls/rpcrt4/ndr_marshall.c index 6143c0c32db..8fbb61f0d23 100644 --- a/dlls/rpcrt4/ndr_marshall.c +++ b/dlls/rpcrt4/ndr_marshall.c @@ -518,6 +518,19 @@ finish_conf: return pFormat+4; } +/* multiply two numbers together, raising an RPC_S_INVALID_BOUND exception if + * the result overflows 32-bits */ +static ULONG inline safe_multiply(ULONG a, ULONG b) +{ + ULONGLONG ret = (ULONGLONG)a * b; + if (ret > 0xffffffff) + { + RpcRaiseException(RPC_S_INVALID_BOUND); + return 0; + } + return ret; +} + /* * NdrConformantString: @@ -542,7 +555,7 @@ finish_conf: unsigned char *WINAPI NdrConformantStringMarshall(MIDL_STUB_MESSAGE *pStubMsg, unsigned char *pszMessage, PFORMAT_STRING pFormat) { - unsigned long esize; + ULONG esize, size; TRACE("(pStubMsg == ^%p, pszMessage == ^%p, pFormat == ^%p)\n", pStubMsg, pszMessage, pFormat); @@ -570,8 +583,9 @@ unsigned char *WINAPI NdrConformantStringMarshall(MIDL_STUB_MESSAGE *pStubMsg, WriteConformance(pStubMsg); WriteVariance(pStubMsg); - memcpy(pStubMsg->Buffer, pszMessage, pStubMsg->ActualCount*esize); /* the string itself */ - pStubMsg->Buffer += pStubMsg->ActualCount*esize; + size = safe_multiply(esize, pStubMsg->ActualCount); + memcpy(pStubMsg->Buffer, pszMessage, size); /* the string itself */ + pStubMsg->Buffer += size; STD_OVERFLOW_CHECK(pStubMsg); @@ -585,25 +599,35 @@ unsigned char *WINAPI NdrConformantStringMarshall(MIDL_STUB_MESSAGE *pStubMsg, void WINAPI NdrConformantStringBufferSize(PMIDL_STUB_MESSAGE pStubMsg, unsigned char* pMemory, PFORMAT_STRING pFormat) { + ULONG esize; + TRACE("(pStubMsg == ^%p, pMemory == ^%p, pFormat == ^%p)\n", pStubMsg, pMemory, pFormat); SizeConformance(pStubMsg); SizeVariance(pStubMsg); if (*pFormat == RPC_FC_C_CSTRING) { - /* we need + 1 octet for '\0' */ TRACE("string=%s\n", debugstr_a((char*)pMemory)); - pStubMsg->BufferLength += strlen((char*)pMemory) + 1; + pStubMsg->ActualCount = strlen((char*)pMemory)+1; + esize = 1; } else if (*pFormat == RPC_FC_C_WSTRING) { - /* we need + 2 octets for L'\0' */ TRACE("string=%s\n", debugstr_w((LPWSTR)pMemory)); - pStubMsg->BufferLength += strlenW((LPWSTR)pMemory)*2 + 2; + pStubMsg->ActualCount = strlenW((LPWSTR)pMemory)+1; + esize = 2; } else { ERR("Unhandled string type: %#x\n", *pFormat); /* FIXME: raise an exception */ + return; } + + if (pFormat[1] == RPC_FC_STRING_SIZED) + pFormat = ComputeConformance(pStubMsg, pMemory, pFormat + 2, 0); + else + pStubMsg->MaxCount = pStubMsg->ActualCount; + + pStubMsg->BufferLength += safe_multiply(esize, pStubMsg->ActualCount); } /************************************************************************ @@ -643,7 +667,7 @@ unsigned long WINAPI NdrConformantStringMemorySize( PMIDL_STUB_MESSAGE pStubMsg, unsigned char *WINAPI NdrConformantStringUnmarshall( PMIDL_STUB_MESSAGE pStubMsg, unsigned char** ppMemory, PFORMAT_STRING pFormat, unsigned char fMustAlloc ) { - unsigned long len, esize; + unsigned long size, esize; TRACE("(pStubMsg == ^%p, *pMemory == ^%p, pFormat == ^%p, fMustAlloc == %u)\n", pStubMsg, *ppMemory, pFormat, fMustAlloc); @@ -661,14 +685,14 @@ unsigned char *WINAPI NdrConformantStringUnmarshall( PMIDL_STUB_MESSAGE pStubMsg esize = 0; } - len = pStubMsg->ActualCount; + size = safe_multiply(esize, pStubMsg->ActualCount); if (fMustAlloc || !*ppMemory) - *ppMemory = NdrAllocate(pStubMsg, len*esize); + *ppMemory = NdrAllocate(pStubMsg, size); - memcpy(*ppMemory, pStubMsg->Buffer, len*esize); + memcpy(*ppMemory, pStubMsg->Buffer, size); - pStubMsg->Buffer += len*esize; + pStubMsg->Buffer += size; if (*pFormat == RPC_FC_C_CSTRING) { TRACE("string=%s\n", debugstr_a((char*)*ppMemory));