feat: add the LIGHT dataset and VDM
This commit is contained in:
parent
cb1d3dd68e
commit
f5552cde74
|
@ -0,0 +1,61 @@
|
||||||
|
import os
|
||||||
|
import pickle
|
||||||
|
import typing as t
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
import mashumaro
|
||||||
|
|
||||||
|
from waifu.datasets import BaseDataset
|
||||||
|
from waifu.utils.dataset import get_data_path
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class LightDialogueAgent(mashumaro.DataClassDictMixin):
    '''
    One character participating in a LIGHT dialogue episode.

    Deserialized from the raw LIGHT data dicts via mashumaro's `from_dict`
    (see `LightDialogueDataset.generator`).
    '''

    # The character's name. Stored as it appears in the raw data; consumers
    # title-case it for display.
    name: str
    # Free-form persona / self-description text for the character.
    persona: str
|
@dataclass(frozen=True)
class LightDialogueSetting(mashumaro.DataClassDictMixin):
    '''
    The location in which a LIGHT dialogue episode takes place.

    Deserialized from the raw LIGHT data dicts via mashumaro's `from_dict`.
    '''

    # Short name of the location.
    name: str
    # Coarse grouping for the location — value taken verbatim from the raw
    # LIGHT data; exact taxonomy TODO confirm against the dataset.
    category: str
    # Description text for the location.
    description: str
    # Additional background text for the location.
    background: str
|
@dataclass(frozen=True)
class LightDialogueEpisode(mashumaro.DataClassDictMixin):
    '''
    A single dialogue episode from the LIGHT dataset.

    The per-turn fields (`character`, `speech`, `emote`, `action`) are
    parallel lists indexed by turn number; `len(speech)` is treated as the
    number of turns by consumers.
    '''

    # Characters participating in this episode.
    agents: t.List[LightDialogueAgent]
    # The location where the episode takes place.
    setting: LightDialogueSetting
    # Per-turn: name of the character speaking/acting in that turn.
    character: t.List[str]
    # Episode context strings; consumers use `context[0]` as the episode's
    # scene-setting text.
    context: t.List[str]
    # Presumably per-turn lists of objects present in the room — TODO confirm
    # against the raw data.
    room_objects: t.List[t.List[str]]
    # Presumably per-turn lists of agents present in the room — TODO confirm.
    room_agents: t.List[t.List[str]]
    # Mapping of entity name -> description text.
    all_descriptions: t.Dict[str, str]
    # Per-turn: actions available to the acting character.
    available_actions: t.List[t.List[str]]
    # Per-turn: items being carried.
    carrying: t.List[t.List[str]]
    # Per-turn: items being wielded.
    wielding: t.List[t.List[str]]
    # Per-turn: the spoken utterance.
    speech: t.List[str]
    # Per-turn: emote performed that turn, or None when there was none
    # (consumers check `is not None`).
    emote: t.List[t.Optional[str]]
    # Per-turn: game action performed that turn, or None when there was none.
    action: t.List[t.Optional[str]]
|
class LightDialogueDataset(BaseDataset[LightDialogueEpisode]):
    '''
    LIGHT: Learning in Interactive Games with Humans and Text

    The LIGHT project is a large-scale fantasy text adventure game research
    platform for training agents that can both talk and act, interacting either
    with other models or with humans.

    https://parl.ai/projects/light/
    '''

    def generator(self) -> t.Generator[LightDialogueEpisode, None, None]:
        '''Yields episodes parsed from the pickled LIGHT data file.'''
        pickle_path = os.path.join(get_data_path("light_dialogue"),
                                   "light_data.pkl")

        # NOTE(review): pickle.load can execute arbitrary code from the file,
        # so this must only ever be pointed at a trusted local copy of the
        # LIGHT data.
        with open(pickle_path, "rb") as pickle_file:
            raw_episodes = pickle.load(pickle_file)
            yield from (LightDialogueEpisode.from_dict(raw_episode)
                        for raw_episode in raw_episodes)
|
@ -0,0 +1,47 @@
|
||||||
|
import typing as t
|
||||||
|
|
||||||
|
from waifu.datasets.light_dialogue import LightDialogueDataset
|
||||||
|
from waifu.modules import BaseModule
|
||||||
|
from waifu.utils.strings import normalize_string, title_case
|
||||||
|
|
||||||
|
|
||||||
|
class LightDialogueVDM(BaseModule):
    '''Vanilla Dialogue Module based on the LIGHT dialogue dataset.'''

    def generator(self) -> t.Generator[str, None, None]:
        '''Yields one flattened transcript string per LIGHT episode.'''
        for episode in LightDialogueDataset():
            # TODO(11b): Context and persona don't belong in a vanilla dialogue
            # module.
            context_block = f"Context: {episode.context[0]}\n"
            persona_block = "".join(
                f"{title_case(agent.name)}'s Description: {agent.persona}\n"
                for agent in episode.agents)

            transcript: t.List[str] = [context_block, persona_block]

            # `speech`, `character`, `action` and `emote` are parallel
            # per-turn lists; walk them by turn index.
            for turn_idx, raw_speech in enumerate(episode.speech):
                speaker = title_case(episode.character[turn_idx])

                # Start off with just the normalized speech dialogue.
                message = normalize_string(raw_speech)

                # If an action was performed that turn, append it.
                turn_action = episode.action[turn_idx]
                if turn_action is not None:
                    message += f" *{turn_action}*"

                # If there was an emote that turn, prepend it.
                turn_emote = episode.emote[turn_idx]
                if turn_emote is not None:
                    message = f"*{turn_emote}* {message}"

                # Finally, label the line with the speaker's name.
                transcript.append(f"{speaker}: {message}")

            yield "\n".join(transcript)
|
@ -0,0 +1,19 @@
|
||||||
|
import os
|
||||||
|
import typing as t
|
||||||
|
|
||||||
|
# Absolute path of the directory containing this file; used to resolve the
# repository's `data` folder relative to the source tree.
HERE = os.path.realpath(os.path.dirname(__file__))


def get_data_path(dataset_name: t.Optional[str] = None) -> str:
    '''
    Returns an absolute path to either the data folder, or a specific dataset if
    `dataset_name` is supplied.

    The default location (`<repo>/data`) can be overridden via the
    `WAIFU_DATA_PATH` environment variable.
    '''
    # Honor the environment override, but still append `dataset_name` so the
    # override behaves consistently with the default branch below. (Previously
    # the dataset name was silently dropped whenever the override was set,
    # contradicting this function's documented contract.)
    env_path = os.environ.get('WAIFU_DATA_PATH')
    if env_path is not None:
        return os.path.join(env_path, dataset_name) if dataset_name else env_path

    components = [HERE, "..", "..", "data"]
    if dataset_name:
        components.append(dataset_name)

    return os.path.join(*components)
|
@ -0,0 +1,62 @@
|
||||||
|
'''Utility functions to clean up text strings.'''
|
||||||
|
|
||||||
|
# Some of this is pasta from Meta's ParlAI. See:
|
||||||
|
# https://github.com/facebookresearch/ParlAI/blob/main/parlai/utils/strings.py
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_string(text: str, version: int = 1) -> str:
    '''
    Standardize the capitalization and punctuation spacing of the input text.

    - Version 1: Fix sentence start casing and punctuation.
    - Version 2: Add trailing period, if missing.
    '''
    # (spaced, unspaced) punctuation pairs. Replacing unspaced -> spaced pads
    # punctuation out into its own token; the reverse replacement later
    # removes the padding again.
    switch_list = [(' .', '.'), (' ,', ','), (' ?', '?'), (' !', '!'),
                   (" ' ", "'")]

    # add spaces so that words and punctuation can be separated
    new_text = text.lower()

    # normalize in case of human:
    for new, old in switch_list:
        # FIX: collapse double spaces ('  ' -> ' ') as in upstream ParlAI;
        # the previous single-space-to-single-space replace was a no-op,
        # which let empty tokens through — and '' is a substring of '?.!',
        # tripping the capitalization pass below.
        new_text = new_text.replace(old, new).replace('  ', ' ')

    # split on punctuation to find sentence boundaries
    # capitalize stuff
    tokens = new_text.split(' ')
    for i in range(len(tokens)):
        if i == 0:
            # First token always starts a sentence.
            tokens[i] = uppercase(tokens[i])
        elif tokens[i] in ('i', "i'm", "i've", "i'll", "i'd"):
            # The pronoun "I" and its contractions are always capitalized.
            tokens[i] = uppercase(tokens[i])
        elif tokens[i] in '?.!' and i < len(tokens) - 1:
            # Sentence-ending punctuation: capitalize the following token.
            tokens[i + 1] = uppercase(tokens[i + 1])
    new_text = ' '.join(tokens)
    new_text = ' ' + new_text + ' '

    # undo the punctuation padding added above
    for tup in switch_list:
        new_text = new_text.replace(tup[0], tup[1])

    # get rid of surrounding whitespace
    new_text = new_text.strip()
    # FIX: likewise collapse any remaining double spaces (was a no-op).
    new_text = new_text.replace('  ', ' ')

    if version > 1 and new_text and new_text[-1] not in '!.?)"\'':
        new_text += '.'

    return new_text


def title_case(string: str) -> str:
    '''Converts a string into Title Case.'''
    return " ".join([uppercase(word) for word in string.split(" ")])


def uppercase(string: str) -> str:
    '''
    Make the first character of the string uppercase, if the string is
    non-empty.
    '''
    if len(string) == 0:
        return string
    else:
        return string[0].upper() + string[1:]
Loading…
Reference in New Issue