feat: add the LIGHT dataset and VDM

This commit is contained in:
11b 2022-12-17 21:32:34 -03:00
parent cb1d3dd68e
commit f5552cde74
4 changed files with 189 additions and 0 deletions

View File

@ -0,0 +1,61 @@
import os
import pickle
import typing as t
from dataclasses import dataclass
import mashumaro
from waifu.datasets import BaseDataset
from waifu.utils.dataset import get_data_path
@dataclass(frozen=True)
class LightDialogueAgent(mashumaro.DataClassDictMixin):
name: str
persona: str
@dataclass(frozen=True)
class LightDialogueSetting(mashumaro.DataClassDictMixin):
name: str
category: str
description: str
background: str
@dataclass(frozen=True)
class LightDialogueEpisode(mashumaro.DataClassDictMixin):
agents: t.List[LightDialogueAgent]
setting: LightDialogueSetting
character: t.List[str]
context: t.List[str]
room_objects: t.List[t.List[str]]
room_agents: t.List[t.List[str]]
all_descriptions: t.Dict[str, str]
available_actions: t.List[t.List[str]]
carrying: t.List[t.List[str]]
wielding: t.List[t.List[str]]
speech: t.List[str]
emote: t.List[str]
action: t.List[str]
class LightDialogueDataset(BaseDataset[LightDialogueEpisode]):
'''
LIGHT: Learning in Interactive Games with Humans and Text
The LIGHT project is a large-scale fantasy text adventure game research
platform for training agents that can both talk and act, interacting either
with other models or with humans.
https://parl.ai/projects/light/
'''
def generator(self) -> t.Generator[LightDialogueEpisode, None, None]:
root_data_path = get_data_path("light_dialogue")
light_data_path = os.path.join(root_data_path, "light_data.pkl")
with open(light_data_path, "rb") as light_data_file:
light_data = pickle.load(light_data_file)
for episode in light_data:
yield LightDialogueEpisode.from_dict(episode)

View File

@ -0,0 +1,47 @@
import typing as t
from waifu.datasets.light_dialogue import LightDialogueDataset
from waifu.modules import BaseModule
from waifu.utils.strings import normalize_string, title_case
class LightDialogueVDM(BaseModule):
'''Vanilla Dialogue Module based on the LIGHT dialogue dataset.'''
def generator(self) -> t.Generator[str, None, None]:
for episode in LightDialogueDataset():
# TODO(11b): Context and persona don't belong in a vanilla dialogue
# module.
context_message = f"Context: {episode.context[0]}\n"
persona_message = ""
for agent in episode.agents:
persona_message += f"{title_case(agent.name)}'s Description: {agent.persona}\n"
episode_messages: t.List[str] = [context_message, persona_message]
turn_count = len(episode.speech)
for idx in range(turn_count):
character = title_case(episode.character[idx])
speech = normalize_string(episode.speech[idx])
# Start off with just the actual speech dialogue.
message = speech
# If there was an action performed in that turn, add it to the
# string.
action = episode.action[idx]
if action is not None:
message += f" *{action}*"
# If there was an emote in that turn, add it to the string.
emote = episode.emote[idx]
if emote is not None:
message = f"*{emote}* {message}"
# Finally, prepend the turn character's name.
message = f"{character}: {message}"
episode_messages.append(message)
yield "\n".join(episode_messages)

19
waifu/utils/dataset.py Normal file
View File

@ -0,0 +1,19 @@
import os
import typing as t
HERE = os.path.realpath(os.path.dirname(__file__))
def get_data_path(dataset_name: t.Optional[str] = None) -> str:
'''
Returns an absolute path to either the data folder, or a specific dataset if
`dataset_name` is supplied.
'''
if 'WAIFU_DATA_PATH' in os.environ:
return os.environ['WAIFU_DATA_PATH']
components = [HERE, "..", "..", "data"]
if dataset_name:
components.append(dataset_name)
return os.path.join(*components)

62
waifu/utils/strings.py Normal file
View File

@ -0,0 +1,62 @@
'''Utility functions to clean up text strings.'''
# Some of this is pasta from Meta's ParlAI. See:
# https://github.com/facebookresearch/ParlAI/blob/main/parlai/utils/strings.py
def normalize_string(text: str, version: int = 1) -> str:
'''
Standardize the capitalization and punctuation spacing of the input text.
- Version 1: Fix sentence start casing and punctuation.
- Version 2: Add trailing period, if missing.
'''
switch_list = [(' .', '.'), (' ,', ','), (' ?', '?'), (' !', '!'),
(" ' ", "'")]
# add spaces so that words and punctuation can be seaprated
new_text = text.lower()
# normalize in case of human:
for new, old in switch_list:
new_text = new_text.replace(old, new).replace(' ', ' ')
# split on punctuation to find sentence boundaries
# capitalize stuff
tokens = new_text.split(' ')
for i in range(len(tokens)):
if i == 0:
tokens[i] = uppercase(tokens[i])
elif tokens[i] in ('i', "i'm", "i've", "i'll", "i'd"):
tokens[i] = uppercase(tokens[i])
elif tokens[i] in '?.!' and i < len(tokens) - 1:
tokens[i + 1] = uppercase(tokens[i + 1])
new_text = ' '.join(tokens)
new_text = ' ' + new_text + ' '
for tup in switch_list:
new_text = new_text.replace(tup[0], tup[1])
# get rid of surrounding whitespace
new_text = new_text.strip()
new_text = new_text.replace(' ', ' ')
if version > 1 and new_text and new_text[-1] not in '!.?)"\'':
new_text += '.'
return new_text
def title_case(string: str) -> str:
'''Converts a string into Title Case.'''
return " ".join([uppercase(word) for word in string.split(" ")])
def uppercase(string: str) -> str:
'''
Make the first character of the string uppercase, if the string is non-empty.
'''
if len(string) == 0:
return string
else:
return string[0].upper() + string[1:]