From ea162de2e0b5e263876897af27892ff594bdee54 Mon Sep 17 00:00:00 2001
From: TearGosling <119216654+TearGosling@users.noreply.github.com>
Date: Sun, 8 Jan 2023 12:15:15 -0600
Subject: [PATCH] feat: add SODA dataset

* Very first prototype of SODA dataset support

I'm also bringing over the version of PromptConstants from the dev branch
due to needing CHAT_START_TOKEN

* More flexibility when fetching speaker names

* Make SODA a PDM instead of a VDM

* Swap order of speakers based on relation

* Oh, and fix a typo too

* Bugfix
---
 waifu/datasets/soda.py    | 43 +++++++++++++++++++++++++++++++++++
 waifu/modules/soda_pdm.py | 48 +++++++++++++++++++++++++++++++++++++++
 2 files changed, 91 insertions(+)
 create mode 100644 waifu/datasets/soda.py
 create mode 100644 waifu/modules/soda_pdm.py

diff --git a/waifu/datasets/soda.py b/waifu/datasets/soda.py
new file mode 100644
index 0000000..14f4f33
--- /dev/null
+++ b/waifu/datasets/soda.py
@@ -0,0 +1,43 @@
+import os
+import pickle
+import typing as t
+from dataclasses import dataclass
+
+import mashumaro
+import pandas as pd
+
+from waifu.datasets import BaseDataset
+from waifu.utils.dataset import get_data_path
+
+@dataclass(frozen=True)
+class SodaEpisode(mashumaro.DataClassDictMixin):
+    narrative: str
+    dialogue: t.List[str]
+    speakers: t.List[str]
+    relation: str
+    literal: str
+
+class SodaDataset(BaseDataset[SodaEpisode]):
+    '''
+    SODA: Million-scale Dialogue Distillation with Social Commonsense
+    Contextualization
+
+    https://huggingface.co/datasets/allenai/soda
+    '''
+
+    def generator(self) -> t.Generator[SodaEpisode, None, None]:
+        root_data_path = get_data_path("soda")
+        file_path = os.path.join(root_data_path, "test.parquet")
+        df = pd.read_parquet(file_path)
+
+        # Iterate through the test part of the SODA dataset
+        for i in df.index:
+            yield SodaEpisode(
+                narrative=df['narrative'][i],
+                dialogue=df['dialogue'][i],
+                speakers=df['speakers'][i],
+                relation=df['relation'][i],
+                literal=df['literal'][i]
+            )
+
+        
\ No newline at end of file
diff --git a/waifu/modules/soda_pdm.py b/waifu/modules/soda_pdm.py
new file mode 100644
index 0000000..14d3212
--- /dev/null
+++ b/waifu/modules/soda_pdm.py
@@ -0,0 +1,48 @@
+import typing as t
+
+from waifu.core.consts import PromptConstants
+from waifu.datasets.soda import SodaDataset
+from waifu.modules import BaseModule
+
+
+class SodaPDM(BaseModule):
+    '''Persona Dialogue Module based on the SODA dataset.'''
+
+    def generator(self) -> t.Generator[list[str], None, None]:
+        for episode in SodaDataset():
+            episode_messages = []
+            # NOTE(TG): We determine which order the speakers go on based on whether the relation is xAttr or not.
+            # This is because some speakers are more abstract concepts rather than concrete names,
+            # which would make them much more suitable as a bot
+            if episode.relation == "xAttr":
+                bot_name = episode.speakers[0]
+                user_name = episode.speakers[1]
+            else:
+                user_name = episode.speakers[0]
+                bot_name = episode.speakers[1]
+
+            # First, we would want to set the persona.
+            # However, the only acceptable description of a persona would be when episode.relation is "xAttr", since that directly describes
+            # a person in the conversation.
+            if episode.relation == "xAttr":
+                episode_messages.append(f"{PromptConstants.pdm_prefix_for(bot_name)}: {episode.literal}")
+
+            # Next, set the scenario.
+            # Make sure to replace any instance of the person representing the user in the conversation with the user token
+            replaced_narrative = episode.narrative.replace(user_name, PromptConstants.USER_TOKEN)
+            scenario = f"Scenario: {replaced_narrative}"
+            episode_messages.append(scenario)
+            # Next, the start token
+            episode_messages.append(PromptConstants.CHAT_START_TOKEN)
+
+            # I am going to assume that the length of episode.speakers is the same as the length of episode.dialogue
+            # Looked pretty clean to me in the data, but TODO: account for the possibility of the lengths differing
+            for i, utterance in enumerate(episode.dialogue):
+                # For now, just leave bot's name unreplaced.
+                if episode.speakers[i] == user_name:
+                    name = PromptConstants.USER_PREFIX
+                else:
+                    name = bot_name
+                episode_messages.append(f"{name}: {utterance.replace(user_name, PromptConstants.USER_TOKEN)}")
+
+            yield episode_messages
\ No newline at end of file
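
Below is a minimal, self-contained sketch of the episode-to-prompt transformation SodaPDM performs, runnable without the waifu package or the SODA parquet files. USER_TOKEN, USER_PREFIX, CHAT_START_TOKEN, the persona prefix format, and the example episode are stand-ins for illustration only; the real values come from PromptConstants and the dataset.

# Standalone sketch mirroring SodaPDM.generator() for a single fabricated episode.
from dataclasses import dataclass
from typing import List

USER_TOKEN = "<USER>"         # stand-in for PromptConstants.USER_TOKEN
USER_PREFIX = "You"           # stand-in for PromptConstants.USER_PREFIX
CHAT_START_TOKEN = "<START>"  # stand-in for PromptConstants.CHAT_START_TOKEN

@dataclass(frozen=True)
class Episode:
    narrative: str
    dialogue: List[str]
    speakers: List[str]
    relation: str
    literal: str

def build_messages(ep: Episode) -> List[str]:
    # xAttr relations describe a person directly, so that speaker is kept as the bot;
    # otherwise the first listed speaker becomes the user.
    if ep.relation == "xAttr":
        bot_name, user_name = ep.speakers[0], ep.speakers[1]
    else:
        user_name, bot_name = ep.speakers[0], ep.speakers[1]

    messages = []
    if ep.relation == "xAttr":
        # Stand-in for PromptConstants.pdm_prefix_for(bot_name).
        messages.append(f"{bot_name}'s Persona: {ep.literal}")
    # Scenario line, with the user's name swapped for the user token.
    messages.append(f"Scenario: {ep.narrative.replace(user_name, USER_TOKEN)}")
    messages.append(CHAT_START_TOKEN)
    # One message per utterance; the bot keeps its name, the user gets the prefix.
    for speaker, utterance in zip(ep.speakers, ep.dialogue):
        name = USER_PREFIX if speaker == user_name else bot_name
        messages.append(f"{name}: {utterance.replace(user_name, USER_TOKEN)}")
    return messages

example = Episode(
    narrative="Ayaka wants to cheer Jordan up after a rough week at work.",
    dialogue=["Hey Jordan, rough week, huh?", "Yeah... thanks for checking in, Ayaka."],
    speakers=["Ayaka", "Jordan"],
    relation="oWant",
    literal="Ayaka wants Jordan to feel better.",
)
print("\n".join(build_messages(example)))

With the "oWant" relation above, Ayaka is treated as the user and Jordan as the bot, so no persona line is emitted; changing the relation to "xAttr" would flip the roles and prepend a persona line built from the literal field.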