feat: add the LIGHT dataset and VDM
This commit is contained in:
parent
cb1d3dd68e
commit
f5552cde74
|
@ -0,0 +1,61 @@
|
|||
import os
|
||||
import pickle
|
||||
import typing as t
|
||||
from dataclasses import dataclass
|
||||
|
||||
import mashumaro
|
||||
|
||||
from waifu.datasets import BaseDataset
|
||||
from waifu.utils.dataset import get_data_path
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class LightDialogueAgent(mashumaro.DataClassDictMixin):
    '''A character taking part in a LIGHT dialogue episode.

    `DataClassDictMixin` supplies `from_dict`/`to_dict`, so instances can be
    built directly from the raw pickled episode dictionaries.
    '''

    # The character's name.
    name: str
    # Free-text description of the character (surfaced downstream as the
    # character's "Description" line).
    persona: str
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class LightDialogueSetting(mashumaro.DataClassDictMixin):
    '''The location a LIGHT dialogue episode takes place in.'''

    # Short name of the location.
    name: str
    # Category the location belongs to — the raw dataset defines the exact
    # vocabulary; verify against the pickled data.
    category: str
    # Human-readable description of the location.
    description: str
    # Additional background text for the location.
    background: str
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class LightDialogueEpisode(mashumaro.DataClassDictMixin):
    '''A single dialogue episode from the LIGHT dataset.

    The per-turn fields (`character`, `speech`, `emote`, `action`, ...) are
    parallel lists indexed by turn number.
    '''

    # The characters participating in the conversation.
    agents: t.List[LightDialogueAgent]
    # Where the conversation takes place.
    setting: LightDialogueSetting
    # character[i] is the name of the character speaking at turn i.
    character: t.List[str]
    # Contextual text for the episode (context[0] is used as the episode's
    # opening "Context:" message downstream).
    context: t.List[str]
    # Per-turn object/agent state as reported by the game engine.
    room_objects: t.List[t.List[str]]
    room_agents: t.List[t.List[str]]
    all_descriptions: t.Dict[str, str]
    available_actions: t.List[t.List[str]]
    carrying: t.List[t.List[str]]
    wielding: t.List[t.List[str]]
    # speech[i] is what the turn's character says at turn i.
    speech: t.List[str]
    # BUGFIX: turns may lack an emote or an action, so entries can be None —
    # downstream code gates on `is not None`. Annotate as Optional so
    # mashumaro's annotation-driven deserialization accepts those values.
    emote: t.List[t.Optional[str]]
    action: t.List[t.Optional[str]]
|
||||
|
||||
|
||||
class LightDialogueDataset(BaseDataset[LightDialogueEpisode]):
    '''
    LIGHT: Learning in Interactive Games with Humans and Text

    The LIGHT project is a large-scale fantasy text adventure game research
    platform for training agents that can both talk and act, interacting
    either with other models or with humans.

    https://parl.ai/projects/light/
    '''

    def generator(self) -> t.Generator[LightDialogueEpisode, None, None]:
        '''Yields one parsed episode at a time from the pickled dataset.'''
        pickle_path = os.path.join(
            get_data_path("light_dialogue"), "light_data.pkl")

        # NOTE: the pickle is a locally-bundled dataset file; do not point
        # this at untrusted data.
        with open(pickle_path, "rb") as pickle_file:
            raw_episodes = pickle.load(pickle_file)

        for raw_episode in raw_episodes:
            yield LightDialogueEpisode.from_dict(raw_episode)
|
|
@ -0,0 +1,47 @@
|
|||
import typing as t
|
||||
|
||||
from waifu.datasets.light_dialogue import LightDialogueDataset
|
||||
from waifu.modules import BaseModule
|
||||
from waifu.utils.strings import normalize_string, title_case
|
||||
|
||||
|
||||
class LightDialogueVDM(BaseModule):
    '''Vanilla Dialogue Module built on top of the LIGHT dialogue dataset.'''

    def generator(self) -> t.Generator[str, None, None]:
        '''Yields one formatted conversation string per LIGHT episode.'''
        for episode in LightDialogueDataset():
            # TODO(11b): Context and persona don't belong in a vanilla
            # dialogue module.
            context_message = f"Context: {episode.context[0]}\n"
            persona_message = "".join(
                f"{title_case(agent.name)}'s Description: {agent.persona}\n"
                for agent in episode.agents)

            messages: t.List[str] = [context_message, persona_message]

            # One formatted message per dialogue turn.
            for turn_idx, raw_speech in enumerate(episode.speech):
                speaker = title_case(episode.character[turn_idx])
                line = normalize_string(raw_speech)

                # Append the turn's action (if any) after the speech...
                action = episode.action[turn_idx]
                if action is not None:
                    line = f"{line} *{action}*"

                # ...and put the turn's emote (if any) in front of it.
                emote = episode.emote[turn_idx]
                if emote is not None:
                    line = f"*{emote}* {line}"

                # Finally, prefix the speaking character's name.
                messages.append(f"{speaker}: {line}")

            yield "\n".join(messages)
|
|
@ -0,0 +1,19 @@
|
|||
import os
|
||||
import typing as t
|
||||
|
||||
# Absolute path of the directory containing this file (symlinks resolved).
HERE = os.path.realpath(os.path.dirname(__file__))


def get_data_path(dataset_name: t.Optional[str] = None) -> str:
    '''
    Returns an absolute path to either the data folder, or a specific dataset
    if `dataset_name` is supplied.

    The base data folder defaults to `<repo>/data` (two levels above this
    file) and can be overridden with the `WAIFU_DATA_PATH` environment
    variable.
    '''
    # Allow overriding the data location via the environment.
    base_path = os.environ.get(
        'WAIFU_DATA_PATH', os.path.join(HERE, "..", "..", "data"))

    # BUGFIX: previously a `dataset_name` supplied together with
    # WAIFU_DATA_PATH was silently ignored and the bare base path was
    # returned; join it onto the base path in both cases.
    if dataset_name:
        return os.path.join(base_path, dataset_name)
    return base_path
|
|
@ -0,0 +1,62 @@
|
|||
'''Utility functions to clean up text strings.'''
|
||||
|
||||
# Some of this is pasta from Meta's ParlAI. See:
|
||||
# https://github.com/facebookresearch/ParlAI/blob/main/parlai/utils/strings.py
|
||||
|
||||
|
||||
def normalize_string(text: str, version: int = 1) -> str:
    '''
    Standardize the capitalization and punctuation spacing of the input text.

    - Version 1: Fix sentence start casing and punctuation.
    - Version 2: Add trailing period, if missing.
    '''

    # Pairs of (spaced, unspaced) punctuation forms. The first loop below
    # rewrites the right element into the left (inserting spaces); the later
    # loop maps the left back to the right (removing them again).
    switch_list = [(' .', '.'), (' ,', ','), (' ?', '?'), (' !', '!'),
                   (" ' ", "'")]

    # Lowercase everything first; sentence starts are re-capitalized below.
    new_text = text.lower()

    # Add spaces so that words and punctuation can be separated, collapsing
    # any double spaces this creates.
    for new, old in switch_list:
        new_text = new_text.replace(old, new).replace('  ', ' ')

    # Split on whitespace to find sentence boundaries, then capitalize the
    # first word, the pronoun "i" (and its contractions), and any word that
    # follows terminating punctuation.
    tokens = new_text.split(' ')
    for i in range(len(tokens)):
        if i == 0:
            tokens[i] = uppercase(tokens[i])
        elif tokens[i] in ('i', "i'm", "i've", "i'll", "i'd"):
            tokens[i] = uppercase(tokens[i])
        elif tokens[i] in '?.!' and i < len(tokens) - 1:
            tokens[i + 1] = uppercase(tokens[i + 1])
    new_text = ' '.join(tokens)
    # Pad with spaces so the (" ' ", "'") pair can match at both ends.
    new_text = ' ' + new_text + ' '

    # Undo the space insertion from above, re-attaching punctuation to the
    # preceding word.
    for tup in switch_list:
        new_text = new_text.replace(tup[0], tup[1])

    # Get rid of surrounding whitespace and any leftover double spaces.
    new_text = new_text.strip()
    new_text = new_text.replace('  ', ' ')

    # Version 2: ensure the string ends with terminating punctuation.
    if version > 1 and new_text and new_text[-1] not in '!.?)"\'':
        new_text += '.'

    return new_text
|
||||
|
||||
|
||||
def title_case(string: str) -> str:
    '''Capitalize the first letter of every space-separated word.'''
    words = string.split(" ")
    return " ".join(map(uppercase, words))
|
||||
|
||||
|
||||
def uppercase(string: str) -> str:
    '''
    Make the first character of the string uppercase, if the string is non-empty.
    '''
    # Guard clause: nothing to capitalize in an empty string.
    if not string:
        return string
    return string[0].upper() + string[1:]
|
Loading…
Reference in New Issue