feat: alternative way of handling and augmenting episode data (wip)
This commit is contained in:
parent
46a552ad28
commit
5e34b105dc
|
@ -10,9 +10,16 @@ class PromptConstants:
|
|||
# Token to be replaced by the bot's name.
|
||||
BOT_TOKEN = "<BOT>"
|
||||
|
||||
# Should be kept in sync with the relevant model that will be trained. This
|
||||
# is taken from EleutherAI's Pythia (so, GPT-NeoX).
|
||||
EOS_TOKEN = "<|endoftext|>"
|
||||
|
||||
# Token to separate prompt trickery from actual dialogue.
|
||||
CHAT_START_TOKEN = "<START>"
|
||||
|
||||
# Global target word count. The word count is chosen in such a way that we
|
||||
# can fit all the required prompt trickery into the model's input, but still
|
||||
# leave enough space for the user's input message and the infernce result.
|
||||
# leave enough space for the user's input message and the inference result.
|
||||
TARGET_WORD_COUNT_PER_EPISODE = 1024
|
||||
|
||||
@staticmethod
|
||||
|
|
|
@ -4,13 +4,13 @@ import typing as t
|
|||
class BaseModule:
|
||||
'''Base module class.'''
|
||||
|
||||
def __iter__(self) -> t.Generator[str, None, None]:
|
||||
def __iter__(self) -> t.Generator[list[str], None, None]:
|
||||
'''Implements the basic iterator interface.'''
|
||||
return self.generator()
|
||||
|
||||
def generator(self) -> t.Generator[str, None, None]:
|
||||
def generator(self) -> t.Generator[list[str], None, None]:
|
||||
'''
|
||||
Should yield strings that will be used in the model's training /
|
||||
Should yield dialogue turns that will be used in the model's training /
|
||||
validation / test splits.
|
||||
'''
|
||||
raise NotImplementedError
|
||||
|
|
|
@ -39,7 +39,7 @@ EPISODE_SIMILARITY_THRESHOLD = 0.55
|
|||
class CharacterAiPDM(BaseModule):
|
||||
'''A Persona Dialogue Module powered by CharacterAI data.'''
|
||||
|
||||
def generator(self) -> t.Generator[str, None, None]:
|
||||
def generator(self) -> t.Generator[list[str], None, None]:
|
||||
for chat in CharacterAiDataset():
|
||||
if len(chat.messages) < MIN_EPISODE_LEN:
|
||||
logger.debug(
|
||||
|
@ -58,10 +58,10 @@ class CharacterAiPDM(BaseModule):
|
|||
chat.bot.name, chat.bot.definitions)
|
||||
base_turns.append(parsed_definitions)
|
||||
|
||||
# Add an empty turn to separate persona info from messages, if
|
||||
# Add turn to separate persona info from messages, if
|
||||
# necessary.
|
||||
if len(base_turns) > 0:
|
||||
base_turns.append("")
|
||||
base_turns.append(PromptConstants.CHAT_START_TOKEN)
|
||||
|
||||
# Now, start adding messages and break episodes apart if they get
|
||||
# too big.
|
||||
|
@ -101,7 +101,8 @@ class CharacterAiPDM(BaseModule):
|
|||
# target word count, so we return the episode without it...
|
||||
removed_turn = turns.pop()
|
||||
if average_similarity_score_for_episode <= EPISODE_SIMILARITY_THRESHOLD:
|
||||
yield "\n".join(turns)
|
||||
# yield "\n".join(turns)
|
||||
yield turns
|
||||
else:
|
||||
logger.debug(
|
||||
"Ignoring episode due to high similarity between messages (%s > %s)",
|
||||
|
|
|
@ -31,7 +31,7 @@ logger = logging.getLogger(__name__)
|
|||
class DiscordVDM(BaseModule):
|
||||
'''A Vanilla Dialogue Module powered by Discord dumps.'''
|
||||
|
||||
def generator(self) -> t.Generator[str, None, None]:
|
||||
def generator(self) -> t.Generator[list[str], None, None]:
|
||||
root_data_path = get_data_path("discord")
|
||||
db_path = os.path.join(root_data_path, "archive.dht")
|
||||
db = sqlite3.connect(db_path)
|
||||
|
@ -62,7 +62,7 @@ class DiscordVDM(BaseModule):
|
|||
avg)
|
||||
continue
|
||||
|
||||
yield "\n".join(turns)
|
||||
yield turns
|
||||
|
||||
|
||||
#
|
||||
|
|
|
@ -9,7 +9,7 @@ from waifu.modules import BaseModule
|
|||
class KajiwotoVDM(BaseModule):
|
||||
'''A Vanilla Dialogue Module powered by the Kajiwoto dataset.'''
|
||||
|
||||
def generator(self) -> t.Generator[str, None, None]:
|
||||
def generator(self) -> t.Generator[list[str], None, None]:
|
||||
dataset = KajiwotoDataset()
|
||||
for episode in dataset:
|
||||
turns: t.List[str] = []
|
||||
|
@ -23,4 +23,4 @@ class KajiwotoVDM(BaseModule):
|
|||
processed_string = replace_special_tokens_in(string)
|
||||
|
||||
for generated_string in generate_variants_for(processed_string):
|
||||
yield generated_string
|
||||
yield generated_string.split("\n")
|
||||
|
|
|
@ -9,10 +9,10 @@ from waifu.utils.strings import normalize_string, title_case
|
|||
class LightDialoguePDM(BaseModule):
|
||||
'''Persona Dialogue Module based on the LIGHT dataset.'''
|
||||
|
||||
def generator(self) -> t.Generator[str, None, None]:
|
||||
def generator(self) -> t.Generator[list[str], None, None]:
|
||||
for episode in LightDialogueDataset():
|
||||
# TODO(11b): Scenario doesn't belong in a persona dialog module.
|
||||
context_message = f"Context: {episode.context[0]}\n"
|
||||
context_message = f"Scenario: {episode.context[0]}\n"
|
||||
|
||||
persona_message = ""
|
||||
for agent in episode.agents:
|
||||
|
@ -48,4 +48,4 @@ class LightDialoguePDM(BaseModule):
|
|||
|
||||
episode_messages.append(message)
|
||||
|
||||
yield "\n".join(episode_messages)
|
||||
yield episode_messages
|
||||
|
|
|
@ -5,10 +5,12 @@ import importlib
|
|||
import json
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import subprocess
|
||||
import sys
|
||||
import typing as t
|
||||
|
||||
from waifu.core.consts import PromptConstants
|
||||
from waifu.modules import BaseModule
|
||||
from waifu.utils.strings import contains_suspect_unicode
|
||||
|
||||
|
@ -16,17 +18,19 @@ from waifu.utils.strings import contains_suspect_unicode
|
|||
# metaprogramming trickery to build this list out instead.
|
||||
DEFAULT_MODULE_LIST = [
|
||||
"characterai_pdm:CharacterAiPDM",
|
||||
"discord_vdm:DiscordVDM",
|
||||
# "discord_vdm:DiscordVDM",
|
||||
# KajiwotoPDM has a bunch of garbage I need to filter, disabling in favor
|
||||
# of the vanilla dialogue module for now.
|
||||
# "kajiwoto_pdm:KajiwotoPDM",
|
||||
"kajiwoto_vdm:KajiwotoVDM",
|
||||
"light_dialogue_pdm:LightDialoguePDM",
|
||||
# "kajiwoto_vdm:KajiwotoVDM",
|
||||
# "light_dialogue_pdm:LightDialoguePDM",
|
||||
]
|
||||
DEFAULT_MODULES_STRING = ",".join(DEFAULT_MODULE_LIST)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
random.seed(42)
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"-o",
|
||||
|
@ -89,9 +93,11 @@ def main() -> None:
|
|||
# Print a newline to visually separate different episodes.
|
||||
if idx != 1:
|
||||
print()
|
||||
print("---| New Episode |---")
|
||||
print("---------------------")
|
||||
print(episode)
|
||||
|
||||
for ep in _episode_augmentations(episode):
|
||||
print("---| New Episode |---")
|
||||
print("---------------------")
|
||||
print("\n---\n".join(ep + [PromptConstants.EOS_TOKEN]))
|
||||
sys.exit()
|
||||
|
||||
#
|
||||
|
@ -118,10 +124,17 @@ def main() -> None:
|
|||
# file.
|
||||
for module in modules:
|
||||
for episode in module():
|
||||
if contains_suspect_unicode(episode):
|
||||
print(f"Found suspect unicode contents in `{episode}`")
|
||||
json_line = json.dumps({"text": episode})
|
||||
output_file.write(f"{json_line}\n")
|
||||
text = "\n".join(episode)
|
||||
if contains_suspect_unicode(text):
|
||||
print(
|
||||
f"Skipping. Found suspect unicode contents in `{text}`")
|
||||
continue
|
||||
|
||||
for augmented_episode in _episode_augmentations(episode):
|
||||
text = "\n".join(augmented_episode +
|
||||
[PromptConstants.EOS_TOKEN])
|
||||
json_line = json.dumps({"text": text})
|
||||
output_file.write(f"{json_line}\n")
|
||||
|
||||
|
||||
#
|
||||
|
@ -129,6 +142,52 @@ def main() -> None:
|
|||
#
|
||||
|
||||
|
||||
def _episode_augmentations(
|
||||
episode: list[str]) -> t.Generator[list[str], None, None]:
|
||||
'''
|
||||
Generates augmented data for the given episode.
|
||||
|
||||
The first 1.3B model had wildly unpredictable performance at the start of
|
||||
conversations, which I attributed to the fact that originally we always fed
|
||||
the model entire episodes to train on, so there were no examples of freshly
|
||||
started conversations, in a sense.
|
||||
|
||||
This function takes a complete episode and yields different permutations of
|
||||
it in an attempt to provide that data (e.g. with/without persona, with only
|
||||
X messages in the history, X+2, X+4 and so on).
|
||||
'''
|
||||
permutated_episode = []
|
||||
offset_idx = 0
|
||||
|
||||
# Don't discard the original episode.
|
||||
yield episode
|
||||
|
||||
for turn in episode:
|
||||
if "'s Persona: " in turn or "Scenario: " in turn or PromptConstants.CHAT_START_TOKEN in turn:
|
||||
permutated_episode.append(turn.strip())
|
||||
offset_idx += 1
|
||||
continue
|
||||
|
||||
while len(episode) > 1 + offset_idx:
|
||||
permutated_episode.append(episode.pop(offset_idx))
|
||||
permutated_episode.append(episode.pop(offset_idx))
|
||||
|
||||
# Yielding every single instance results in too much data
|
||||
# repetition, so instead we take a random sample.
|
||||
should_yield = random.randint(0, 100) < 25
|
||||
if should_yield:
|
||||
yield permutated_episode
|
||||
|
||||
# Also, yield a version with _just_ dialogue if we've been yielding
|
||||
# with persona/scenario data this entire time.
|
||||
if offset_idx == 0:
|
||||
continue
|
||||
|
||||
should_yield = random.randint(0, 100) < 25
|
||||
if should_yield:
|
||||
yield permutated_episode[offset_idx:]
|
||||
|
||||
|
||||
def _get_git_revision_short_hash() -> str:
|
||||
'''Returns the project's short git revision hash.'''
|
||||
return subprocess.check_output(
|
||||
|
|
Loading…
Reference in New Issue