diff --git a/waifu/core/consts.py b/waifu/core/consts.py index d007e94..1596ae1 100644 --- a/waifu/core/consts.py +++ b/waifu/core/consts.py @@ -10,9 +10,16 @@ class PromptConstants: # Token to be replaced by the bot's name. BOT_TOKEN = "" + # Should be kept in sync with the relevant model that will be trained. This + # is taken from EleutherAI's Pythia (so, GPT-NeoX). + EOS_TOKEN = "<|endoftext|>" + + # Token to separate prompt trickery from actual dialogue. + CHAT_START_TOKEN = "" + # Global target word count. The word count is chosen in such a way that we # can fit all the required prompt trickery into the model's input, but still - # leave enough space for the user's input message and the infernce result. + # leave enough space for the user's input message and the inference result. TARGET_WORD_COUNT_PER_EPISODE = 1024 @staticmethod diff --git a/waifu/modules/__init__.py b/waifu/modules/__init__.py index 540c4a1..a668313 100644 --- a/waifu/modules/__init__.py +++ b/waifu/modules/__init__.py @@ -4,13 +4,13 @@ import typing as t class BaseModule: '''Base module class.''' - def __iter__(self) -> t.Generator[str, None, None]: + def __iter__(self) -> t.Generator[list[str], None, None]: '''Implements the basic iterator interface.''' return self.generator() - def generator(self) -> t.Generator[str, None, None]: + def generator(self) -> t.Generator[list[str], None, None]: ''' - Should yield strings that will be used in the model's training / + Should yield dialogue turns that will be used in the model's training / validation / test splits. ''' raise NotImplementedError diff --git a/waifu/modules/characterai_pdm.py b/waifu/modules/characterai_pdm.py index b0b5b3f..d9ae3ee 100644 --- a/waifu/modules/characterai_pdm.py +++ b/waifu/modules/characterai_pdm.py @@ -39,7 +39,7 @@ EPISODE_SIMILARITY_THRESHOLD = 0.55 class CharacterAiPDM(BaseModule): '''A Persona Dialogue Module powered by CharacterAI data.''' - def generator(self) -> t.Generator[str, None, None]: + def generator(self) -> t.Generator[list[str], None, None]: for chat in CharacterAiDataset(): if len(chat.messages) < MIN_EPISODE_LEN: logger.debug( @@ -58,10 +58,10 @@ class CharacterAiPDM(BaseModule): chat.bot.name, chat.bot.definitions) base_turns.append(parsed_definitions) - # Add an empty turn to separate persona info from messages, if + # Add turn to separate persona info from messages, if # necessary. if len(base_turns) > 0: - base_turns.append("") + base_turns.append(PromptConstants.CHAT_START_TOKEN) # Now, start adding messages and break episodes apart if they get # too big. @@ -101,7 +101,8 @@ class CharacterAiPDM(BaseModule): # target word count, so we return the episode without it... removed_turn = turns.pop() if average_similarity_score_for_episode <= EPISODE_SIMILARITY_THRESHOLD: - yield "\n".join(turns) + # yield "\n".join(turns) + yield turns else: logger.debug( "Ignoring episode due to high similarity between messages (%s > %s)", diff --git a/waifu/modules/discord_vdm.py b/waifu/modules/discord_vdm.py index 3b31051..e0e2a24 100644 --- a/waifu/modules/discord_vdm.py +++ b/waifu/modules/discord_vdm.py @@ -31,7 +31,7 @@ logger = logging.getLogger(__name__) class DiscordVDM(BaseModule): '''A Vanilla Dialogue Module powered by Discord dumps.''' - def generator(self) -> t.Generator[str, None, None]: + def generator(self) -> t.Generator[list[str], None, None]: root_data_path = get_data_path("discord") db_path = os.path.join(root_data_path, "archive.dht") db = sqlite3.connect(db_path) @@ -62,7 +62,7 @@ class DiscordVDM(BaseModule): avg) continue - yield "\n".join(turns) + yield turns # diff --git a/waifu/modules/kajiwoto_vdm.py b/waifu/modules/kajiwoto_vdm.py index 657aef3..66b19ce 100644 --- a/waifu/modules/kajiwoto_vdm.py +++ b/waifu/modules/kajiwoto_vdm.py @@ -9,7 +9,7 @@ from waifu.modules import BaseModule class KajiwotoVDM(BaseModule): '''A Vanilla Dialogue Module powered by the Kajiwoto dataset.''' - def generator(self) -> t.Generator[str, None, None]: + def generator(self) -> t.Generator[list[str], None, None]: dataset = KajiwotoDataset() for episode in dataset: turns: t.List[str] = [] @@ -23,4 +23,4 @@ class KajiwotoVDM(BaseModule): processed_string = replace_special_tokens_in(string) for generated_string in generate_variants_for(processed_string): - yield generated_string + yield generated_string.split("\n") diff --git a/waifu/modules/light_dialogue_pdm.py b/waifu/modules/light_dialogue_pdm.py index f1d5255..d383d46 100644 --- a/waifu/modules/light_dialogue_pdm.py +++ b/waifu/modules/light_dialogue_pdm.py @@ -9,10 +9,10 @@ from waifu.utils.strings import normalize_string, title_case class LightDialoguePDM(BaseModule): '''Persona Dialogue Module based on the LIGHT dataset.''' - def generator(self) -> t.Generator[str, None, None]: + def generator(self) -> t.Generator[list[str], None, None]: for episode in LightDialogueDataset(): # TODO(11b): Scenario doesn't belong in a persona dialog module. - context_message = f"Context: {episode.context[0]}\n" + context_message = f"Scenario: {episode.context[0]}\n" persona_message = "" for agent in episode.agents: @@ -48,4 +48,4 @@ class LightDialoguePDM(BaseModule): episode_messages.append(message) - yield "\n".join(episode_messages) + yield episode_messages diff --git a/waifu/scripts/build_dataset.py b/waifu/scripts/build_dataset.py index 8678da6..5f1d69f 100755 --- a/waifu/scripts/build_dataset.py +++ b/waifu/scripts/build_dataset.py @@ -5,10 +5,12 @@ import importlib import json import logging import os +import random import subprocess import sys import typing as t +from waifu.core.consts import PromptConstants from waifu.modules import BaseModule from waifu.utils.strings import contains_suspect_unicode @@ -16,17 +18,19 @@ from waifu.utils.strings import contains_suspect_unicode # metaprogramming trickery to build this list out instead. DEFAULT_MODULE_LIST = [ "characterai_pdm:CharacterAiPDM", - "discord_vdm:DiscordVDM", + # "discord_vdm:DiscordVDM", # KajiwotoPDM has a bunch of garbage I need to filter, disabling in favor # of the vanilla dialogue module for now. # "kajiwoto_pdm:KajiwotoPDM", - "kajiwoto_vdm:KajiwotoVDM", - "light_dialogue_pdm:LightDialoguePDM", + # "kajiwoto_vdm:KajiwotoVDM", + # "light_dialogue_pdm:LightDialoguePDM", ] DEFAULT_MODULES_STRING = ",".join(DEFAULT_MODULE_LIST) def main() -> None: + random.seed(42) + parser = argparse.ArgumentParser() parser.add_argument( "-o", @@ -89,9 +93,11 @@ def main() -> None: # Print a newline to visually separate different episodes. if idx != 1: print() - print("---| New Episode |---") - print("---------------------") - print(episode) + + for ep in _episode_augmentations(episode): + print("---| New Episode |---") + print("---------------------") + print("\n---\n".join(ep + [PromptConstants.EOS_TOKEN])) sys.exit() # @@ -118,10 +124,17 @@ def main() -> None: # file. for module in modules: for episode in module(): - if contains_suspect_unicode(episode): - print(f"Found suspect unicode contents in `{episode}`") - json_line = json.dumps({"text": episode}) - output_file.write(f"{json_line}\n") + text = "\n".join(episode) + if contains_suspect_unicode(text): + print( + f"Skipping. Found suspect unicode contents in `{text}`") + continue + + for augmented_episode in _episode_augmentations(episode): + text = "\n".join(augmented_episode + + [PromptConstants.EOS_TOKEN]) + json_line = json.dumps({"text": text}) + output_file.write(f"{json_line}\n") # @@ -129,6 +142,52 @@ def main() -> None: # +def _episode_augmentations( + episode: list[str]) -> t.Generator[list[str], None, None]: + ''' + Generates augmented data for the given episode. + + The first 1.3B model had wildly unpredictable performance at the start of + conversations, which I attributed to the fact that originally we always fed + the model entire episodes to train on, so there were no examples of freshly + started conversations, in a sense. + + This function takes a complete episode and yields different permutations of + it in an attempt to provide that data (e.g. with/without persona, with only + X messages in the history, X+2, X+4 and so on). + ''' + permutated_episode = [] + offset_idx = 0 + + # Don't discard the original episode. + yield episode + + for turn in episode: + if "'s Persona: " in turn or "Scenario: " in turn or PromptConstants.CHAT_START_TOKEN in turn: + permutated_episode.append(turn.strip()) + offset_idx += 1 + continue + + while len(episode) > 1 + offset_idx: + permutated_episode.append(episode.pop(offset_idx)) + permutated_episode.append(episode.pop(offset_idx)) + + # Yielding every single instance results in too much data + # repetition, so instead we take a random sample. + should_yield = random.randint(0, 100) < 25 + if should_yield: + yield permutated_episode + + # Also, yield a version with _just_ dialogue if we've been yielding + # with persona/scenario data this entire time. + if offset_idx == 0: + continue + + should_yield = random.randint(0, 100) < 25 + if should_yield: + yield permutated_episode[offset_idx:] + + def _get_git_revision_short_hash() -> str: '''Returns the project's short git revision hash.''' return subprocess.check_output(