feat: add SODA dataset

* Very first prototype of SODA dataset support

I'm also bringing over the version of PromptConstants from the dev branch due to needing CHAT_START_TOKEN

* More flexibility when fetching speaker names

* Make SODA a PDM instead of a VDM

* Swap order of speakers based on relation

* Oh, and fix a typo too

* Bugfix
This commit is contained in:
TearGosling 2023-01-08 12:15:15 -06:00 committed by 0x000011b
parent eb997a3d3f
commit ea162de2e0
2 changed files with 91 additions and 0 deletions

43
waifu/datasets/soda.py Normal file
View File

@ -0,0 +1,43 @@
import os
import pickle
import typing as t
from dataclasses import dataclass
import mashumaro
import pandas as pd
from waifu.datasets import BaseDataset
from waifu.utils.dataset import get_data_path
@dataclass(frozen=True)
class SodaEpisode(mashumaro.DataClassDictMixin):
narrative: str
dialogue: t.List[str]
speakers: t.List[str]
relation: str
literal: str
class SodaDataset(BaseDataset[SodaEpisode]):
'''
SODA: Million-scale Dialogue Distillation with Social Commonsense
Contextualization
https://huggingface.co/datasets/allenai/soda
'''
def generator(self) -> t.Generator[SodaEpisode, None, None]:
root_data_path = get_data_path("soda")
file_path = os.path.join(root_data_path, "test.parquet")
df = pd.read_parquet(file_path)
# Iterate through the test part of the SODA dataset
for i in df.index:
yield SodaEpisode(
narrative=df['narrative'][i],
dialogue=df['dialogue'][i],
speakers=df['speakers'][i],
relation=df['relation'][i],
literal=df['literal'][i]
)

48
waifu/modules/soda_pdm.py Normal file
View File

@ -0,0 +1,48 @@
import typing as t
from waifu.core.consts import PromptConstants
from waifu.datasets.soda import SodaDataset
from waifu.modules import BaseModule
class SodaPDM(BaseModule):
'''Persona Dialogue Module based on the SODA dataset.'''
def generator(self) -> t.Generator[list[str], None, None]:
for episode in SodaDataset():
episode_messages = []
# NOTE(TG): We determine which order the speakers go on based on whether the relation is xAttr or not.
# This is because some speakers are more abstract concepts rather than concrete names,
# which would make them much more suitable as a bot
if episode.relation == "xAttr":
bot_name = episode.speakers[0]
user_name = episode.speakers[1]
else:
user_name = episode.speakers[0]
bot_name = episode.speakers[1]
# First, we would want to set the persona.
# However, the only acceptable description of a persona would be when episode.relation is "xAttr", since that directly describes
# a person in the conversation.
if episode.relation == "xAttr":
episode_messages.append(f"{PromptConstants.pdm_prefix_for(bot_name)}: {episode.literal}")
# Next, set the scenario.
# Make sure to replace any instance of the person representing the user in the conversation with the user token
replaced_narrative = episode.narrative.replace(user_name, PromptConstants.USER_TOKEN)
scenario = f"Scenario: {replaced_narrative}"
episode_messages.append(scenario)
# Next, the start token
episode_messages.append(PromptConstants.CHAT_START_TOKEN)
# I am going to assume that the length of episode.speakers is the same as the length of episode.dialogue
# Looked pretty clean to me in the data. Fuck it, TODO: account for the possibility of that happening
for i, utterance in enumerate(episode.dialogue):
# For now, just leave bot's name unreplaced.
if episode.speakers[i] == user_name:
name = PromptConstants.USER_PREFIX
else:
name = bot_name
episode_messages.append(f"{name}: {utterance.replace(user_name, PromptConstants.USER_TOKEN)}")
yield episode_messages