diff --git a/dlls/scrrun/dictionary.c b/dlls/scrrun/dictionary.c index 958d7fc98d8..5c1893ac60c 100644 --- a/dlls/scrrun/dictionary.c +++ b/dlls/scrrun/dictionary.c @@ -32,9 +32,20 @@ #include "wine/debug.h" #include "wine/unicode.h" +#include "wine/list.h" WINE_DEFAULT_DEBUG_CHANNEL(scrrun); +#define BUCKET_COUNT 509 + +struct keyitem_pair { + struct list entry; + DWORD bucket; + DWORD hash; + VARIANT key; + VARIANT item; +}; + typedef struct { IDictionary IDictionary_iface; @@ -42,6 +53,8 @@ typedef struct CompareMethod method; LONG count; + struct list pairs; + struct keyitem_pair *buckets[BUCKET_COUNT]; } dictionary; static inline dictionary *impl_from_IDictionary(IDictionary *iface) @@ -49,6 +62,133 @@ static inline dictionary *impl_from_IDictionary(IDictionary *iface) return CONTAINING_RECORD(iface, dictionary, IDictionary_iface); } +static inline struct keyitem_pair *get_bucket_head(const dictionary *dict, DWORD hash) +{ + return dict->buckets[hash % BUCKET_COUNT]; +} + +static inline BOOL is_string_key(const VARIANT *key) +{ + return V_VT(key) == VT_BSTR || V_VT(key) == (VT_BSTR|VT_BYREF); +} + +/* Only for VT_BSTR or VT_BSTR|VT_BYREF types */ +static inline WCHAR *get_key_strptr(const VARIANT *key) +{ + if (V_VT(key) == VT_BSTR) + return V_BSTR(key); + + if (V_BSTRREF(key)) + return *V_BSTRREF(key); + + return NULL; +} + +/* should be used only when both keys are of string type, it's not checked */ +static inline int strcmp_key(const dictionary *dict, const VARIANT *key1, const VARIANT *key2) +{ + const WCHAR *str1, *str2; + + str1 = get_key_strptr(key1); + str2 = get_key_strptr(key2); + return dict->method == BinaryCompare ? strcmpW(str1, str2) : strcmpiW(str1, str2); +} + +static BOOL is_matching_key(const dictionary *dict, const struct keyitem_pair *pair, const VARIANT *key, DWORD hash) +{ + if (is_string_key(key) && is_string_key(&pair->key)) { + if (hash != pair->hash) + return FALSE; + + return strcmp_key(dict, key, &pair->key) == 0; + } + + if ((is_string_key(key) && !is_string_key(&pair->key)) || + (!is_string_key(key) && is_string_key(&pair->key))) + return FALSE; + + /* for numeric keys only check hash */ + return hash == pair->hash; +} + +static struct keyitem_pair *get_keyitem_pair(dictionary *dict, VARIANT *key) +{ + struct keyitem_pair *pair; + DWORD bucket; + VARIANT hash; + HRESULT hr; + + hr = IDictionary_get_HashVal(&dict->IDictionary_iface, key, &hash); + if (FAILED(hr)) + return NULL; + + pair = get_bucket_head(dict, V_I4(&hash)); + if (!pair) + return NULL; + + bucket = pair->bucket; + + do { + if (is_matching_key(dict, pair, key, V_I4(&hash))) return pair; + pair = LIST_ENTRY(list_next(&dict->pairs, &pair->entry), struct keyitem_pair, entry); + if (pair && pair->bucket != bucket) break; + } while (pair != NULL); + + return NULL; +} + +static HRESULT add_keyitem_pair(dictionary *dict, VARIANT *key, VARIANT *item) +{ + struct keyitem_pair *pair, *head; + VARIANT hash; + HRESULT hr; + + hr = IDictionary_get_HashVal(&dict->IDictionary_iface, key, &hash); + if (FAILED(hr)) + return hr; + + pair = heap_alloc(sizeof(*pair)); + if (!pair) + return E_OUTOFMEMORY; + + pair->hash = V_I4(&hash); + pair->bucket = pair->hash % BUCKET_COUNT; + VariantInit(&pair->key); + VariantInit(&pair->item); + + hr = VariantCopyInd(&pair->key, key); + if (FAILED(hr)) + goto failed; + + hr = VariantCopyInd(&pair->item, item); + if (FAILED(hr)) + goto failed; + + head = get_bucket_head(dict, pair->hash); + if (head) + list_add_tail(&head->entry, &pair->entry); + else { + dict->buckets[pair->bucket] = pair; + list_add_tail(&dict->pairs, &pair->entry); + } + + dict->count++; + return S_OK; + +failed: + VariantClear(&pair->key); + VariantClear(&pair->item); + heap_free(pair); + return hr; +} + +static void free_keyitem_pair(struct keyitem_pair *pair) +{ + VariantClear(&pair->key); + VariantClear(&pair->item); + heap_free(pair); +} + static HRESULT WINAPI dictionary_QueryInterface(IDictionary *iface, REFIID riid, void **obj) { dictionary *This = impl_from_IDictionary(iface); @@ -100,8 +240,10 @@ static ULONG WINAPI dictionary_Release(IDictionary *iface) TRACE("(%p)\n", This); ref = InterlockedDecrement(&This->ref); - if(ref == 0) + if(ref == 0) { + IDictionary_RemoveAll(iface); heap_free(This); + } return ref; } @@ -192,13 +334,16 @@ static HRESULT WINAPI dictionary_get_Item(IDictionary *iface, VARIANT *Key, VARI return E_NOTIMPL; } -static HRESULT WINAPI dictionary_Add(IDictionary *iface, VARIANT *Key, VARIANT *Item) +static HRESULT WINAPI dictionary_Add(IDictionary *iface, VARIANT *key, VARIANT *item) { dictionary *This = impl_from_IDictionary(iface); - FIXME("(%p)->(%p %p)\n", This, Key, Item); + TRACE("(%p)->(%s %s)\n", This, debugstr_variant(key), debugstr_variant(item)); - return E_NOTIMPL; + if (get_keyitem_pair(This, key)) + return CTL_E_KEY_ALREADY_EXISTS; + + return add_keyitem_pair(This, key, item); } static HRESULT WINAPI dictionary_get_Count(IDictionary *iface, LONG *count) @@ -259,10 +404,21 @@ static HRESULT WINAPI dictionary_Remove(IDictionary *iface, VARIANT *Key) static HRESULT WINAPI dictionary_RemoveAll(IDictionary *iface) { dictionary *This = impl_from_IDictionary(iface); + struct keyitem_pair *pair, *pair2; - FIXME("(%p)->()\n", This); + TRACE("(%p)\n", This); - return E_NOTIMPL; + if (This->count == 0) + return S_OK; + + LIST_FOR_EACH_ENTRY_SAFE(pair, pair2, &This->pairs, struct keyitem_pair, entry) { + list_remove(&pair->entry); + free_keyitem_pair(pair); + } + memset(This->buckets, 0, sizeof(This->buckets)); + This->count = 0; + + return S_OK; } static HRESULT WINAPI dictionary_put_CompareMode(IDictionary *iface, CompareMethod method) @@ -415,6 +571,8 @@ HRESULT WINAPI Dictionary_CreateInstance(IClassFactory *factory,IUnknown *outer, This->ref = 1; This->method = BinaryCompare; This->count = 0; + list_init(&This->pairs); + memset(This->buckets, 0, sizeof(This->buckets)); *obj = &This->IDictionary_iface; diff --git a/dlls/scrrun/tests/dictionary.c b/dlls/scrrun/tests/dictionary.c index 25cb0a6406b..066a5ab0da9 100644 --- a/dlls/scrrun/tests/dictionary.c +++ b/dlls/scrrun/tests/dictionary.c @@ -64,7 +64,7 @@ static void test_interfaces(void) V_VT(&value) = VT_BSTR; V_BSTR(&value) = SysAllocString(key_add_value); hr = IDictionary_Add(dict, &key, &value); - todo_wine ok(hr == S_OK, "got 0x%08x, expected 0x%08x\n", hr, S_OK); + ok(hr == S_OK, "got 0x%08x, expected 0x%08x\n", hr, S_OK); VariantClear(&value); exists = VARIANT_FALSE; @@ -83,7 +83,7 @@ static void test_interfaces(void) hr = IDictionary_get_Count(dict, &count); ok(hr == S_OK, "got 0x%08x, expected 0x%08x\n", hr, S_OK); - todo_wine ok(count == 1, "got %d, expected 1\n", count); + ok(count == 1, "got %d, expected 1\n", count); IDictionary_Release(dict); IDispatch_Release(disp); @@ -128,11 +128,9 @@ if (0) /* crashes on native */ V_I2(&key) = 0; VariantInit(&item); hr = IDictionary_Add(dict, &key, &item); -todo_wine ok(hr == S_OK, "got 0x%08x\n", hr); hr = IDictionary_put_CompareMode(dict, BinaryCompare); -todo_wine ok(hr == CTL_E_ILLEGALFUNCTIONCALL, "got 0x%08x\n", hr); IDictionary_Release(dict); @@ -457,13 +455,11 @@ todo_wine { } VariantInit(&item); hr = IDictionary_Add(dict, &key, &item); -todo_wine ok(hr == S_OK, "got 0x%08x\n", hr); V_VT(&key) = VT_R4; V_R4(&key) = 0.0; hr = IDictionary_Add(dict, &key, &item); -todo_wine ok(hr == CTL_E_KEY_ALREADY_EXISTS, "got 0x%08x\n", hr); V_VT(&key) = VT_I2; @@ -519,7 +515,6 @@ todo_wine { V_R4(&key) = 0.0; VariantInit(&item); hr = IDictionary_Add(dict, &key, &item); -todo_wine ok(hr == S_OK, "got 0x%08x\n", hr); VariantInit(&keys); @@ -570,7 +565,6 @@ todo_wine VariantInit(&item); hr = IDictionary_Add(dict, &key, &item); -todo_wine ok(hr == S_OK, "got 0x%08x\n", hr); hr = IDictionary_Remove(dict, &key);