From f5552cde74f2237a386e89dbe090315ae08b8cc0 Mon Sep 17 00:00:00 2001 From: 0x000011b <0x000011b@waifu.club> Date: Sat, 17 Dec 2022 21:32:34 -0300 Subject: [PATCH] feat: add the LIGHT dataset and VDM --- waifu/datasets/light_dialogue.py | 61 ++++++++++++++++++++++++++++ waifu/modules/light_dialogue_vdm.py | 47 ++++++++++++++++++++++ waifu/utils/dataset.py | 19 +++++++++ waifu/utils/strings.py | 62 +++++++++++++++++++++++++++++ 4 files changed, 189 insertions(+) create mode 100644 waifu/datasets/light_dialogue.py create mode 100644 waifu/modules/light_dialogue_vdm.py create mode 100644 waifu/utils/dataset.py create mode 100644 waifu/utils/strings.py diff --git a/waifu/datasets/light_dialogue.py b/waifu/datasets/light_dialogue.py new file mode 100644 index 0000000..2f2c86f --- /dev/null +++ b/waifu/datasets/light_dialogue.py @@ -0,0 +1,61 @@ +import os +import pickle +import typing as t +from dataclasses import dataclass + +import mashumaro + +from waifu.datasets import BaseDataset +from waifu.utils.dataset import get_data_path + + +@dataclass(frozen=True) +class LightDialogueAgent(mashumaro.DataClassDictMixin): + name: str + persona: str + + +@dataclass(frozen=True) +class LightDialogueSetting(mashumaro.DataClassDictMixin): + name: str + category: str + description: str + background: str + + +@dataclass(frozen=True) +class LightDialogueEpisode(mashumaro.DataClassDictMixin): + agents: t.List[LightDialogueAgent] + setting: LightDialogueSetting + character: t.List[str] + context: t.List[str] + room_objects: t.List[t.List[str]] + room_agents: t.List[t.List[str]] + all_descriptions: t.Dict[str, str] + available_actions: t.List[t.List[str]] + carrying: t.List[t.List[str]] + wielding: t.List[t.List[str]] + speech: t.List[str] + emote: t.List[str] + action: t.List[str] + + +class LightDialogueDataset(BaseDataset[LightDialogueEpisode]): + ''' + LIGHT: Learning in Interactive Games with Humans and Text + + The LIGHT project is a large-scale fantasy text adventure game research + platform for training agents that can both talk and act, interacting either + with other models or with humans. + + https://parl.ai/projects/light/ + ''' + + def generator(self) -> t.Generator[LightDialogueEpisode, None, None]: + root_data_path = get_data_path("light_dialogue") + light_data_path = os.path.join(root_data_path, "light_data.pkl") + + with open(light_data_path, "rb") as light_data_file: + light_data = pickle.load(light_data_file) + for episode in light_data: + yield LightDialogueEpisode.from_dict(episode) diff --git a/waifu/modules/light_dialogue_vdm.py b/waifu/modules/light_dialogue_vdm.py new file mode 100644 index 0000000..7e9588f --- /dev/null +++ b/waifu/modules/light_dialogue_vdm.py @@ -0,0 +1,47 @@ +import typing as t + +from waifu.datasets.light_dialogue import LightDialogueDataset +from waifu.modules import BaseModule +from waifu.utils.strings import normalize_string, title_case + + +class LightDialogueVDM(BaseModule): + '''Vanilla Dialogue Module based on the LIGHT dialogue dataset.''' + + def generator(self) -> t.Generator[str, None, None]: + for episode in LightDialogueDataset(): + # TODO(11b): Context and persona don't belong in a vanilla dialogue + # module. + context_message = f"Context: {episode.context[0]}\n" + + persona_message = "" + for agent in episode.agents: + persona_message += f"{title_case(agent.name)}'s Description: {agent.persona}\n" + + episode_messages: t.List[str] = [context_message, persona_message] + turn_count = len(episode.speech) + + for idx in range(turn_count): + character = title_case(episode.character[idx]) + speech = normalize_string(episode.speech[idx]) + + # Start off with just the actual speech dialogue. + message = speech + + # If there was an action performed in that turn, add it to the + # string. + action = episode.action[idx] + if action is not None: + message += f" *{action}*" + + # If there was an emote in that turn, add it to the string. + emote = episode.emote[idx] + if emote is not None: + message = f"*{emote}* {message}" + + # Finally, prepend the turn character's name. + message = f"{character}: {message}" + + episode_messages.append(message) + + yield "\n".join(episode_messages) diff --git a/waifu/utils/dataset.py b/waifu/utils/dataset.py new file mode 100644 index 0000000..62809b3 --- /dev/null +++ b/waifu/utils/dataset.py @@ -0,0 +1,19 @@ +import os +import typing as t + +HERE = os.path.realpath(os.path.dirname(__file__)) + + +def get_data_path(dataset_name: t.Optional[str] = None) -> str: + ''' + Returns an absolute path to either the data folder, or a specific dataset if + `dataset_name` is supplied. + ''' + if 'WAIFU_DATA_PATH' in os.environ: + return os.environ['WAIFU_DATA_PATH'] + + components = [HERE, "..", "..", "data"] + if dataset_name: + components.append(dataset_name) + + return os.path.join(*components) diff --git a/waifu/utils/strings.py b/waifu/utils/strings.py new file mode 100644 index 0000000..eb92fca --- /dev/null +++ b/waifu/utils/strings.py @@ -0,0 +1,62 @@ +'''Utility functions to clean up text strings.''' + +# Some of this is pasta from Meta's ParlAI. See: +# https://github.com/facebookresearch/ParlAI/blob/main/parlai/utils/strings.py + + +def normalize_string(text: str, version: int = 1) -> str: + ''' + Standardize the capitalization and punctuation spacing of the input text. + - Version 1: Fix sentence start casing and punctuation. + - Version 2: Add trailing period, if missing. + ''' + + switch_list = [(' .', '.'), (' ,', ','), (' ?', '?'), (' !', '!'), + (" ' ", "'")] + + # add spaces so that words and punctuation can be seaprated + new_text = text.lower() + + # normalize in case of human: + for new, old in switch_list: + new_text = new_text.replace(old, new).replace(' ', ' ') + + # split on punctuation to find sentence boundaries + # capitalize stuff + tokens = new_text.split(' ') + for i in range(len(tokens)): + if i == 0: + tokens[i] = uppercase(tokens[i]) + elif tokens[i] in ('i', "i'm", "i've", "i'll", "i'd"): + tokens[i] = uppercase(tokens[i]) + elif tokens[i] in '?.!' and i < len(tokens) - 1: + tokens[i + 1] = uppercase(tokens[i + 1]) + new_text = ' '.join(tokens) + new_text = ' ' + new_text + ' ' + + for tup in switch_list: + new_text = new_text.replace(tup[0], tup[1]) + + # get rid of surrounding whitespace + new_text = new_text.strip() + new_text = new_text.replace(' ', ' ') + + if version > 1 and new_text and new_text[-1] not in '!.?)"\'': + new_text += '.' + + return new_text + + +def title_case(string: str) -> str: + '''Converts a string into Title Case.''' + return " ".join([uppercase(word) for word in string.split(" ")]) + + +def uppercase(string: str) -> str: + ''' + Make the first character of the string uppercase, if the string is non-empty. + ''' + if len(string) == 0: + return string + else: + return string[0].upper() + string[1:]