feat: alternative way of handling and augmenting episode data (wip)

11b 2023-01-04 09:05:51 -03:00
parent 46a552ad28
commit 5e34b105dc
7 changed files with 92 additions and 25 deletions

View File

@@ -10,9 +10,16 @@ class PromptConstants:
     # Token to be replaced by the bot's name.
    BOT_TOKEN = "<BOT>"

+    # Should be kept in sync with the relevant model that will be trained. This
+    # is taken from EleutherAI's Pythia (so, GPT-NeoX).
+    EOS_TOKEN = "<|endoftext|>"
+
+    # Token to separate prompt trickery from actual dialogue.
+    CHAT_START_TOKEN = "<START>"
+
     # Global target word count. The word count is chosen in such a way that we
     # can fit all the required prompt trickery into the model's input, but still
-    # leave enough space for the user's input message and the infernce result.
+    # leave enough space for the user's input message and the inference result.
     TARGET_WORD_COUNT_PER_EPISODE = 1024

     @staticmethod
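
For reference, a small sketch (illustrative only; the persona and messages are invented) of how these two new tokens are meant to frame a serialized training episode: persona and other prompt trickery first, then CHAT_START_TOKEN, then the dialogue, terminated by EOS_TOKEN:

    EOS_TOKEN = "<|endoftext|>"
    CHAT_START_TOKEN = "<START>"

    turns = [
        "Mia's Persona: Cheerful, speaks in short sentences.",
        CHAT_START_TOKEN,
        "You: Hi there!",
        "Mia: Hello! Nice day, huh?",
    ]
    print("\n".join(turns + [EOS_TOKEN]))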

View File

@@ -4,13 +4,13 @@ import typing as t

 class BaseModule:
     '''Base module class.'''

-    def __iter__(self) -> t.Generator[str, None, None]:
+    def __iter__(self) -> t.Generator[list[str], None, None]:
         '''Implements the basic iterator interface.'''
         return self.generator()

-    def generator(self) -> t.Generator[str, None, None]:
+    def generator(self) -> t.Generator[list[str], None, None]:
         '''
-        Should yield strings that will be used in the model's training /
+        Should yield dialogue turns that will be used in the model's training /
         validation / test splits.
         '''
         raise NotImplementedError
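
Under the new contract, a module yields each episode as a list of turns and leaves the joining into a single training string to the build script. A minimal hypothetical subclass (not part of this commit; assumes the waifu package is importable):

    import typing as t

    from waifu.modules import BaseModule

    class ToyVDM(BaseModule):
        '''Toy module yielding hard-coded episodes as lists of turns.'''

        def generator(self) -> t.Generator[list[str], None, None]:
            yield ["You: Hello.", "Bot: Hi! How can I help?"]
            yield ["You: Ping.", "Bot: Pong."]

    # Joining is now deferred to the consumer:
    for episode in ToyVDM():
        print("\n".join(episode))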

View File

@@ -39,7 +39,7 @@ EPISODE_SIMILARITY_THRESHOLD = 0.55

 class CharacterAiPDM(BaseModule):
     '''A Persona Dialogue Module powered by CharacterAI data.'''

-    def generator(self) -> t.Generator[str, None, None]:
+    def generator(self) -> t.Generator[list[str], None, None]:
         for chat in CharacterAiDataset():
             if len(chat.messages) < MIN_EPISODE_LEN:
                 logger.debug(
@@ -58,10 +58,10 @@ class CharacterAiPDM(BaseModule):
                     chat.bot.name, chat.bot.definitions)
                 base_turns.append(parsed_definitions)
-            # Add an empty turn to separate persona info from messages, if
+            # Add a turn to separate persona info from messages, if
             # necessary.
             if len(base_turns) > 0:
-                base_turns.append("")
+                base_turns.append(PromptConstants.CHAT_START_TOKEN)

             # Now, start adding messages and break episodes apart if they get
             # too big.
@@ -101,7 +101,8 @@ class CharacterAiPDM(BaseModule):
                 # target word count, so we return the episode without it...
                 removed_turn = turns.pop()
                 if average_similarity_score_for_episode <= EPISODE_SIMILARITY_THRESHOLD:
-                    yield "\n".join(turns)
+                    # yield "\n".join(turns)
+                    yield turns
                 else:
                     logger.debug(
                         "Ignoring episode due to high similarity between messages (%s > %s)",

View File

@@ -31,7 +31,7 @@ logger = logging.getLogger(__name__)

 class DiscordVDM(BaseModule):
     '''A Vanilla Dialogue Module powered by Discord dumps.'''

-    def generator(self) -> t.Generator[str, None, None]:
+    def generator(self) -> t.Generator[list[str], None, None]:
         root_data_path = get_data_path("discord")
         db_path = os.path.join(root_data_path, "archive.dht")
         db = sqlite3.connect(db_path)
@@ -62,7 +62,7 @@ class DiscordVDM(BaseModule):
                     avg)
                 continue

-            yield "\n".join(turns)
+            yield turns

 #

View File

@@ -9,7 +9,7 @@ from waifu.modules import BaseModule

 class KajiwotoVDM(BaseModule):
     '''A Vanilla Dialogue Module powered by the Kajiwoto dataset.'''

-    def generator(self) -> t.Generator[str, None, None]:
+    def generator(self) -> t.Generator[list[str], None, None]:
         dataset = KajiwotoDataset()
         for episode in dataset:
             turns: t.List[str] = []
@@ -23,4 +23,4 @@ class KajiwotoVDM(BaseModule):
             processed_string = replace_special_tokens_in(string)

             for generated_string in generate_variants_for(processed_string):
-                yield generated_string
+                yield generated_string.split("\n")

View File

@@ -9,10 +9,10 @@ from waifu.utils.strings import normalize_string, title_case

 class LightDialoguePDM(BaseModule):
     '''Persona Dialogue Module based on the LIGHT dataset.'''

-    def generator(self) -> t.Generator[str, None, None]:
+    def generator(self) -> t.Generator[list[str], None, None]:
         for episode in LightDialogueDataset():
             # TODO(11b): Scenario doesn't belong in a persona dialog module.
-            context_message = f"Context: {episode.context[0]}\n"
+            context_message = f"Scenario: {episode.context[0]}\n"
             persona_message = ""

             for agent in episode.agents:
@@ -48,4 +48,4 @@ class LightDialoguePDM(BaseModule):
             episode_messages.append(message)

-            yield "\n".join(episode_messages)
+            yield episode_messages

View File

@@ -5,10 +5,12 @@ import importlib
 import json
 import logging
 import os
+import random
 import subprocess
 import sys
 import typing as t

+from waifu.core.consts import PromptConstants
 from waifu.modules import BaseModule
 from waifu.utils.strings import contains_suspect_unicode
@@ -16,17 +18,19 @@ from waifu.utils.strings import contains_suspect_unicode
 # metaprogramming trickery to build this list out instead.
 DEFAULT_MODULE_LIST = [
     "characterai_pdm:CharacterAiPDM",
-    "discord_vdm:DiscordVDM",
+    # "discord_vdm:DiscordVDM",
     # KajiwotoPDM has a bunch of garbage I need to filter, disabling in favor
     # of the vanilla dialogue module for now.
     # "kajiwoto_pdm:KajiwotoPDM",
-    "kajiwoto_vdm:KajiwotoVDM",
-    "light_dialogue_pdm:LightDialoguePDM",
+    # "kajiwoto_vdm:KajiwotoVDM",
+    # "light_dialogue_pdm:LightDialoguePDM",
 ]

 DEFAULT_MODULES_STRING = ",".join(DEFAULT_MODULE_LIST)


 def main() -> None:
+    random.seed(42)
+
     parser = argparse.ArgumentParser()
     parser.add_argument(
         "-o",
@@ -89,9 +93,11 @@ def main() -> None:
             # Print a newline to visually separate different episodes.
             if idx != 1:
                 print()

-            print("---| New Episode |---")
-            print("---------------------")
-            print(episode)
+            for ep in _episode_augmentations(episode):
+                print("---| New Episode |---")
+                print("---------------------")
+                print("\n---\n".join(ep + [PromptConstants.EOS_TOKEN]))

         sys.exit()

 #
@@ -118,10 +124,17 @@ def main() -> None:
     # file.
     for module in modules:
         for episode in module():
-            if contains_suspect_unicode(episode):
-                print(f"Found suspect unicode contents in `{episode}`")
-            json_line = json.dumps({"text": episode})
-            output_file.write(f"{json_line}\n")
+            text = "\n".join(episode)
+            if contains_suspect_unicode(text):
+                print(
+                    f"Skipping. Found suspect unicode contents in `{text}`")
+                continue
+
+            for augmented_episode in _episode_augmentations(episode):
+                text = "\n".join(augmented_episode +
+                                 [PromptConstants.EOS_TOKEN])
+                json_line = json.dumps({"text": text})
+                output_file.write(f"{json_line}\n")

 #
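
For reference, each augmented episode becomes one JSONL record; json.dumps takes care of escaping the newlines between turns (contents invented):

    import json

    text = "Mia's Persona: Cheerful.\n<START>\nYou: Hi!\nMia: Hello!\n<|endoftext|>"
    print(json.dumps({"text": text}))
    # {"text": "Mia's Persona: Cheerful.\n<START>\nYou: Hi!\nMia: Hello!\n<|endoftext|>"}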
@@ -129,6 +142,52 @@ def main() -> None:
 #

+def _episode_augmentations(
+        episode: list[str]) -> t.Generator[list[str], None, None]:
+    '''
+    Generates augmented data for the given episode.
+
+    The first 1.3B model had wildly unpredictable performance at the start of
+    conversations, which I attributed to the fact that originally we always fed
+    the model entire episodes to train on, so there were no examples of freshly
+    started conversations, in a sense.
+
+    This function takes a complete episode and yields different permutations of
+    it in an attempt to provide that data (e.g. with/without persona, with only
+    X messages in the history, X+2, X+4 and so on).
+    '''
+    permutated_episode = []
+    offset_idx = 0
+
+    # Don't discard the original episode.
+    yield episode
+
+    for turn in episode:
+        if "'s Persona: " in turn or "Scenario: " in turn or PromptConstants.CHAT_START_TOKEN in turn:
+            permutated_episode.append(turn.strip())
+            offset_idx += 1
+            continue
+
+        while len(episode) > 1 + offset_idx:
+            permutated_episode.append(episode.pop(offset_idx))
+            permutated_episode.append(episode.pop(offset_idx))
+
+            # Yielding every single instance results in too much data
+            # repetition, so instead we take a random sample.
+            should_yield = random.randint(0, 100) < 25
+            if should_yield:
+                yield permutated_episode
+
+            # Also, yield a version with _just_ dialogue if we've been yielding
+            # with persona/scenario data this entire time.
+            if offset_idx == 0:
+                continue
+
+            should_yield = random.randint(0, 100) < 25
+            if should_yield:
+                yield permutated_episode[offset_idx:]
+

 def _get_git_revision_short_hash() -> str:
     '''Returns the project's short git revision hash.'''
     return subprocess.check_output(
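
To make the augmentation behavior concrete, here is a toy walkthrough of _episode_augmentations (episode contents invented; output varies because of the random sampling). Note the function pops turns out of the list it is given, hence the defensive copy:

    episode = [
        "Mia's Persona: Cheerful.",
        "<START>",
        "You: Hi!",
        "Mia: Hello!",
        "You: How are you?",
        "Mia: Great!",
    ]

    for augmented in _episode_augmentations(list(episode)):
        print(augmented)

    # Always yields the full episode first; then, each with ~25% probability:
    # - persona + <START> + the first two dialogue turns,
    # - the first two dialogue turns with persona/scenario stripped,
    # - the same two variants extended by two more turns, and so on.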