diff --git a/dlls/combase/combase.c b/dlls/combase/combase.c index bcf9912cbc1..a757407c71e 100644 --- a/dlls/combase/combase.c +++ b/dlls/combase/combase.c @@ -20,14 +20,179 @@ #define COBJMACROS #define NONAMELESSUNION +#include "ntstatus.h" +#define WIN32_NO_STATUS #define USE_COM_CONTEXT_DEF #include "objbase.h" #include "oleauto.h" +#include "winternl.h" #include "wine/debug.h" WINE_DEFAULT_DEBUG_CHANNEL(ole); +#define CHARS_IN_GUID 39 + +static NTSTATUS create_key(HKEY *retkey, ACCESS_MASK access, OBJECT_ATTRIBUTES *attr) +{ + NTSTATUS status = NtCreateKey((HANDLE *)retkey, access, attr, 0, NULL, 0, NULL); + + if (status == STATUS_OBJECT_NAME_NOT_FOUND) + { + HANDLE subkey, root = attr->RootDirectory; + WCHAR *buffer = attr->ObjectName->Buffer; + DWORD attrs, pos = 0, i = 0, len = attr->ObjectName->Length / sizeof(WCHAR); + UNICODE_STRING str; + + while (i < len && buffer[i] != '\\') i++; + if (i == len) return status; + + attrs = attr->Attributes; + attr->ObjectName = &str; + + while (i < len) + { + str.Buffer = buffer + pos; + str.Length = (i - pos) * sizeof(WCHAR); + status = NtCreateKey(&subkey, access, attr, 0, NULL, 0, NULL); + if (attr->RootDirectory != root) NtClose(attr->RootDirectory); + if (status) return status; + attr->RootDirectory = subkey; + while (i < len && buffer[i] == '\\') i++; + pos = i; + while (i < len && buffer[i] != '\\') i++; + } + str.Buffer = buffer + pos; + str.Length = (i - pos) * sizeof(WCHAR); + attr->Attributes = attrs; + status = NtCreateKey((HANDLE *)retkey, access, attr, 0, NULL, 0, NULL); + if (attr->RootDirectory != root) NtClose(attr->RootDirectory); + } + return status; +} + +static HKEY classes_root_hkey; + +static HKEY create_classes_root_hkey(DWORD access) +{ + HKEY hkey, ret = 0; + OBJECT_ATTRIBUTES attr; + UNICODE_STRING name; + + attr.Length = sizeof(attr); + attr.RootDirectory = 0; + attr.ObjectName = &name; + attr.Attributes = 0; + attr.SecurityDescriptor = NULL; + attr.SecurityQualityOfService = NULL; + RtlInitUnicodeString(&name, L"\\Registry\\Machine\\Software\\Classes"); + + if (create_key( &hkey, access, &attr )) return 0; + TRACE( "%s -> %p\n", debugstr_w(attr.ObjectName->Buffer), hkey ); + + if (!(access & KEY_WOW64_64KEY)) + { + if (!(ret = InterlockedCompareExchangePointer( (void **)&classes_root_hkey, hkey, 0 ))) + ret = hkey; + else + NtClose( hkey ); /* somebody beat us to it */ + } + else + ret = hkey; + return ret; +} + +static HKEY get_classes_root_hkey(HKEY hkey, REGSAM access); + +static LSTATUS create_classes_key(HKEY hkey, const WCHAR *name, REGSAM access, HKEY *retkey) +{ + OBJECT_ATTRIBUTES attr; + UNICODE_STRING nameW; + + if (!(hkey = get_classes_root_hkey(hkey, access))) + return ERROR_INVALID_HANDLE; + + attr.Length = sizeof(attr); + attr.RootDirectory = hkey; + attr.ObjectName = &nameW; + attr.Attributes = 0; + attr.SecurityDescriptor = NULL; + attr.SecurityQualityOfService = NULL; + RtlInitUnicodeString( &nameW, name ); + + return RtlNtStatusToDosError(create_key(retkey, access, &attr)); +} + +static HKEY get_classes_root_hkey(HKEY hkey, REGSAM access) +{ + HKEY ret = hkey; + const BOOL is_win64 = sizeof(void*) > sizeof(int); + const BOOL force_wow32 = is_win64 && (access & KEY_WOW64_32KEY); + + if (hkey == HKEY_CLASSES_ROOT && + ((access & KEY_WOW64_64KEY) || !(ret = classes_root_hkey))) + ret = create_classes_root_hkey(MAXIMUM_ALLOWED | (access & KEY_WOW64_64KEY)); + if (force_wow32 && ret && ret == classes_root_hkey) + { + access &= ~KEY_WOW64_32KEY; + if (create_classes_key(classes_root_hkey, L"Wow6432Node", access, &hkey)) + return 0; + ret = hkey; + } + + return ret; +} + +static LSTATUS open_classes_key(HKEY hkey, const WCHAR *name, REGSAM access, HKEY *retkey) +{ + OBJECT_ATTRIBUTES attr; + UNICODE_STRING nameW; + + if (!(hkey = get_classes_root_hkey(hkey, access))) + return ERROR_INVALID_HANDLE; + + attr.Length = sizeof(attr); + attr.RootDirectory = hkey; + attr.ObjectName = &nameW; + attr.Attributes = 0; + attr.SecurityDescriptor = NULL; + attr.SecurityQualityOfService = NULL; + RtlInitUnicodeString( &nameW, name ); + + return RtlNtStatusToDosError(NtOpenKey((HANDLE *)retkey, access, &attr)); +} + +static HRESULT open_key_for_clsid(REFCLSID clsid, const WCHAR *keyname, REGSAM access, HKEY *subkey) +{ + static const WCHAR clsidW[] = L"CLSID\\"; + WCHAR path[CHARS_IN_GUID + ARRAY_SIZE(clsidW) - 1]; + LONG res; + HKEY key; + + lstrcpyW(path, clsidW); + StringFromGUID2(clsid, path + lstrlenW(clsidW), CHARS_IN_GUID); + res = open_classes_key(HKEY_CLASSES_ROOT, path, keyname ? KEY_READ : access, &key); + if (res == ERROR_FILE_NOT_FOUND) + return REGDB_E_CLASSNOTREG; + else if (res != ERROR_SUCCESS) + return REGDB_E_READREGDB; + + if (!keyname) + { + *subkey = key; + return S_OK; + } + + res = open_classes_key(key, keyname, access, subkey); + RegCloseKey(key); + if (res == ERROR_FILE_NOT_FOUND) + return REGDB_E_KEYMISSING; + else if (res != ERROR_SUCCESS) + return REGDB_E_READREGDB; + + return S_OK; +} + /*********************************************************************** * FreePropVariantArray (combase.@) */ @@ -628,6 +793,44 @@ HRESULT WINAPI CoGetActivationState(GUID guid, DWORD arg2, DWORD *arg3) return E_NOTIMPL; } +/****************************************************************************** + * CoGetTreatAsClass (combase.@) + */ +HRESULT WINAPI CoGetTreatAsClass(REFCLSID clsidOld, CLSID *clsidNew) +{ + WCHAR buffW[CHARS_IN_GUID]; + LONG len = sizeof(buffW); + HRESULT hr = S_OK; + HKEY hkey = NULL; + + TRACE("%s, %p.\n", debugstr_guid(clsidOld), clsidNew); + + if (!clsidOld || !clsidNew) + return E_INVALIDARG; + + *clsidNew = *clsidOld; + + hr = open_key_for_clsid(clsidOld, L"TreatAs", KEY_READ, &hkey); + if (FAILED(hr)) + { + hr = S_FALSE; + goto done; + } + + if (RegQueryValueW(hkey, NULL, buffW, &len)) + { + hr = S_FALSE; + goto done; + } + + hr = CLSIDFromString(buffW, clsidNew); + if (FAILED(hr)) + ERR("Failed to get CLSID from string %s, hr %#x.\n", debugstr_w(buffW), hr); +done: + if (hkey) RegCloseKey(hkey); + return hr; +} + static void init_multi_qi(DWORD count, MULTI_QI *mqi, HRESULT hr) { ULONG i; diff --git a/dlls/combase/combase.spec b/dlls/combase/combase.spec index b8e374924aa..497af6da8c0 100644 --- a/dlls/combase/combase.spec +++ b/dlls/combase/combase.spec @@ -119,7 +119,7 @@ @ stdcall CoGetStandardMarshal(ptr ptr long ptr long ptr) ole32.CoGetStandardMarshal @ stub CoGetStdMarshalEx @ stub CoGetSystemSecurityPermissions -@ stdcall CoGetTreatAsClass(ptr ptr) ole32.CoGetTreatAsClass +@ stdcall CoGetTreatAsClass(ptr ptr) @ stdcall CoImpersonateClient() @ stdcall CoIncrementMTAUsage(ptr) ole32.CoIncrementMTAUsage @ stdcall CoInitializeEx(ptr long) ole32.CoInitializeEx diff --git a/dlls/ole32/compobj.c b/dlls/ole32/compobj.c index cc64d7c036f..1f6a952f15a 100644 --- a/dlls/ole32/compobj.c +++ b/dlls/ole32/compobj.c @@ -3707,56 +3707,6 @@ done: return res; } -/****************************************************************************** - * CoGetTreatAsClass [OLE32.@] - * - * Gets the TreatAs value of a class. - * - * PARAMS - * clsidOld [I] Class to get the TreatAs value of. - * clsidNew [I] The class the clsidOld should be treated as. - * - * RETURNS - * Success: S_OK. - * Failure: HRESULT code. - * - * SEE ALSO - * CoSetTreatAsClass - */ -HRESULT WINAPI CoGetTreatAsClass(REFCLSID clsidOld, LPCLSID clsidNew) -{ - static const WCHAR wszTreatAs[] = {'T','r','e','a','t','A','s',0}; - HKEY hkey = NULL; - WCHAR szClsidNew[CHARS_IN_GUID]; - HRESULT res = S_OK; - LONG len = sizeof(szClsidNew); - - TRACE("(%s,%p)\n", debugstr_guid(clsidOld), clsidNew); - - if (!clsidOld || !clsidNew) - return E_INVALIDARG; - - *clsidNew = *clsidOld; /* copy over old value */ - - res = COM_OpenKeyForCLSID(clsidOld, wszTreatAs, KEY_READ, &hkey); - if (FAILED(res)) - { - res = S_FALSE; - goto done; - } - if (RegQueryValueW(hkey, NULL, szClsidNew, &len)) - { - res = S_FALSE; - goto done; - } - res = CLSIDFromString(szClsidNew,clsidNew); - if (FAILED(res)) - ERR("Failed CLSIDFromStringA(%s), hres 0x%08x\n", debugstr_w(szClsidNew), res); -done: - if (hkey) RegCloseKey(hkey); - return res; -} - /****************************************************************************** * CoGetCurrentProcess [OLE32.@] */ diff --git a/dlls/ole32/ole32.spec b/dlls/ole32/ole32.spec index 550b3149d49..e51b2b5887d 100644 --- a/dlls/ole32/ole32.spec +++ b/dlls/ole32/ole32.spec @@ -46,7 +46,7 @@ @ stdcall CoGetStandardMarshal(ptr ptr long ptr long ptr) @ stdcall CoGetState(ptr) @ stub CoGetTIDFromIPID -@ stdcall CoGetTreatAsClass(ptr ptr) +@ stdcall CoGetTreatAsClass(ptr ptr) combase.CoGetTreatAsClass @ stdcall CoImpersonateClient() combase.CoImpersonateClient @ stdcall CoIncrementMTAUsage(ptr) @ stdcall CoInitialize(ptr)