feat: alternative way of handling and augmenting episode data (wip)
This commit is contained in:
parent
46a552ad28
commit
5e34b105dc
|
@ -10,9 +10,16 @@ class PromptConstants:
|
||||||
# Token to be replaced by the bot's name.
|
# Token to be replaced by the bot's name.
|
||||||
BOT_TOKEN = "<BOT>"
|
BOT_TOKEN = "<BOT>"
|
||||||
|
|
||||||
|
# Should be kept in sync with the relevant model that will be trained. This
|
||||||
|
# is taken from EleutherAI's Pythia (so, GPT-NeoX).
|
||||||
|
EOS_TOKEN = "<|endoftext|>"
|
||||||
|
|
||||||
|
# Token to separate prompt trickery from actual dialogue.
|
||||||
|
CHAT_START_TOKEN = "<START>"
|
||||||
|
|
||||||
# Global target word count. The word count is chosen in such a way that we
|
# Global target word count. The word count is chosen in such a way that we
|
||||||
# can fit all the required prompt trickery into the model's input, but still
|
# can fit all the required prompt trickery into the model's input, but still
|
||||||
# leave enough space for the user's input message and the infernce result.
|
# leave enough space for the user's input message and the inference result.
|
||||||
TARGET_WORD_COUNT_PER_EPISODE = 1024
|
TARGET_WORD_COUNT_PER_EPISODE = 1024
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|
|
@ -4,13 +4,13 @@ import typing as t
|
||||||
class BaseModule:
|
class BaseModule:
|
||||||
'''Base module class.'''
|
'''Base module class.'''
|
||||||
|
|
||||||
def __iter__(self) -> t.Generator[str, None, None]:
|
def __iter__(self) -> t.Generator[list[str], None, None]:
|
||||||
'''Implements the basic iterator interface.'''
|
'''Implements the basic iterator interface.'''
|
||||||
return self.generator()
|
return self.generator()
|
||||||
|
|
||||||
def generator(self) -> t.Generator[str, None, None]:
|
def generator(self) -> t.Generator[list[str], None, None]:
|
||||||
'''
|
'''
|
||||||
Should yield strings that will be used in the model's training /
|
Should yield dialogue turns that will be used in the model's training /
|
||||||
validation / test splits.
|
validation / test splits.
|
||||||
'''
|
'''
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
|
@ -39,7 +39,7 @@ EPISODE_SIMILARITY_THRESHOLD = 0.55
|
||||||
class CharacterAiPDM(BaseModule):
|
class CharacterAiPDM(BaseModule):
|
||||||
'''A Persona Dialogue Module powered by CharacterAI data.'''
|
'''A Persona Dialogue Module powered by CharacterAI data.'''
|
||||||
|
|
||||||
def generator(self) -> t.Generator[str, None, None]:
|
def generator(self) -> t.Generator[list[str], None, None]:
|
||||||
for chat in CharacterAiDataset():
|
for chat in CharacterAiDataset():
|
||||||
if len(chat.messages) < MIN_EPISODE_LEN:
|
if len(chat.messages) < MIN_EPISODE_LEN:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
|
@ -58,10 +58,10 @@ class CharacterAiPDM(BaseModule):
|
||||||
chat.bot.name, chat.bot.definitions)
|
chat.bot.name, chat.bot.definitions)
|
||||||
base_turns.append(parsed_definitions)
|
base_turns.append(parsed_definitions)
|
||||||
|
|
||||||
# Add an empty turn to separate persona info from messages, if
|
# Add a turn to separate persona info from messages, if
|
||||||
# necessary.
|
# necessary.
|
||||||
if len(base_turns) > 0:
|
if len(base_turns) > 0:
|
||||||
base_turns.append("")
|
base_turns.append(PromptConstants.CHAT_START_TOKEN)
|
||||||
|
|
||||||
# Now, start adding messages and break episodes apart if they get
|
# Now, start adding messages and break episodes apart if they get
|
||||||
# too big.
|
# too big.
|
||||||
|
@ -101,7 +101,8 @@ class CharacterAiPDM(BaseModule):
|
||||||
# target word count, so we return the episode without it...
|
# target word count, so we return the episode without it...
|
||||||
removed_turn = turns.pop()
|
removed_turn = turns.pop()
|
||||||
if average_similarity_score_for_episode <= EPISODE_SIMILARITY_THRESHOLD:
|
if average_similarity_score_for_episode <= EPISODE_SIMILARITY_THRESHOLD:
|
||||||
yield "\n".join(turns)
|
# yield "\n".join(turns)
|
||||||
|
yield turns
|
||||||
else:
|
else:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"Ignoring episode due to high similarity between messages (%s > %s)",
|
"Ignoring episode due to high similarity between messages (%s > %s)",
|
||||||
|
|
|
@ -31,7 +31,7 @@ logger = logging.getLogger(__name__)
|
||||||
class DiscordVDM(BaseModule):
|
class DiscordVDM(BaseModule):
|
||||||
'''A Vanilla Dialogue Module powered by Discord dumps.'''
|
'''A Vanilla Dialogue Module powered by Discord dumps.'''
|
||||||
|
|
||||||
def generator(self) -> t.Generator[str, None, None]:
|
def generator(self) -> t.Generator[list[str], None, None]:
|
||||||
root_data_path = get_data_path("discord")
|
root_data_path = get_data_path("discord")
|
||||||
db_path = os.path.join(root_data_path, "archive.dht")
|
db_path = os.path.join(root_data_path, "archive.dht")
|
||||||
db = sqlite3.connect(db_path)
|
db = sqlite3.connect(db_path)
|
||||||
|
@ -62,7 +62,7 @@ class DiscordVDM(BaseModule):
|
||||||
avg)
|
avg)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
yield "\n".join(turns)
|
yield turns
|
||||||
|
|
||||||
|
|
||||||
#
|
#
|
||||||
|
|
|
@ -9,7 +9,7 @@ from waifu.modules import BaseModule
|
||||||
class KajiwotoVDM(BaseModule):
|
class KajiwotoVDM(BaseModule):
|
||||||
'''A Vanilla Dialogue Module powered by the Kajiwoto dataset.'''
|
'''A Vanilla Dialogue Module powered by the Kajiwoto dataset.'''
|
||||||
|
|
||||||
def generator(self) -> t.Generator[str, None, None]:
|
def generator(self) -> t.Generator[list[str], None, None]:
|
||||||
dataset = KajiwotoDataset()
|
dataset = KajiwotoDataset()
|
||||||
for episode in dataset:
|
for episode in dataset:
|
||||||
turns: t.List[str] = []
|
turns: t.List[str] = []
|
||||||
|
@ -23,4 +23,4 @@ class KajiwotoVDM(BaseModule):
|
||||||
processed_string = replace_special_tokens_in(string)
|
processed_string = replace_special_tokens_in(string)
|
||||||
|
|
||||||
for generated_string in generate_variants_for(processed_string):
|
for generated_string in generate_variants_for(processed_string):
|
||||||
yield generated_string
|
yield generated_string.split("\n")
|
||||||
|
|
|
@ -9,10 +9,10 @@ from waifu.utils.strings import normalize_string, title_case
|
||||||
class LightDialoguePDM(BaseModule):
|
class LightDialoguePDM(BaseModule):
|
||||||
'''Persona Dialogue Module based on the LIGHT dataset.'''
|
'''Persona Dialogue Module based on the LIGHT dataset.'''
|
||||||
|
|
||||||
def generator(self) -> t.Generator[str, None, None]:
|
def generator(self) -> t.Generator[list[str], None, None]:
|
||||||
for episode in LightDialogueDataset():
|
for episode in LightDialogueDataset():
|
||||||
# TODO(11b): Scenario doesn't belong in a persona dialog module.
|
# TODO(11b): Scenario doesn't belong in a persona dialog module.
|
||||||
context_message = f"Context: {episode.context[0]}\n"
|
context_message = f"Scenario: {episode.context[0]}\n"
|
||||||
|
|
||||||
persona_message = ""
|
persona_message = ""
|
||||||
for agent in episode.agents:
|
for agent in episode.agents:
|
||||||
|
@ -48,4 +48,4 @@ class LightDialoguePDM(BaseModule):
|
||||||
|
|
||||||
episode_messages.append(message)
|
episode_messages.append(message)
|
||||||
|
|
||||||
yield "\n".join(episode_messages)
|
yield episode_messages
|
||||||
|
|
|
@ -5,10 +5,12 @@ import importlib
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import random
|
||||||
import subprocess
|
import subprocess
|
||||||
import sys
|
import sys
|
||||||
import typing as t
|
import typing as t
|
||||||
|
|
||||||
|
from waifu.core.consts import PromptConstants
|
||||||
from waifu.modules import BaseModule
|
from waifu.modules import BaseModule
|
||||||
from waifu.utils.strings import contains_suspect_unicode
|
from waifu.utils.strings import contains_suspect_unicode
|
||||||
|
|
||||||
|
@ -16,17 +18,19 @@ from waifu.utils.strings import contains_suspect_unicode
|
||||||
# metaprogramming trickery to build this list out instead.
|
# metaprogramming trickery to build this list out instead.
|
||||||
DEFAULT_MODULE_LIST = [
|
DEFAULT_MODULE_LIST = [
|
||||||
"characterai_pdm:CharacterAiPDM",
|
"characterai_pdm:CharacterAiPDM",
|
||||||
"discord_vdm:DiscordVDM",
|
# "discord_vdm:DiscordVDM",
|
||||||
# KajiwotoPDM has a bunch of garbage I need to filter, disabling in favor
|
# KajiwotoPDM has a bunch of garbage I need to filter, disabling in favor
|
||||||
# of the vanilla dialogue module for now.
|
# of the vanilla dialogue module for now.
|
||||||
# "kajiwoto_pdm:KajiwotoPDM",
|
# "kajiwoto_pdm:KajiwotoPDM",
|
||||||
"kajiwoto_vdm:KajiwotoVDM",
|
# "kajiwoto_vdm:KajiwotoVDM",
|
||||||
"light_dialogue_pdm:LightDialoguePDM",
|
# "light_dialogue_pdm:LightDialoguePDM",
|
||||||
]
|
]
|
||||||
DEFAULT_MODULES_STRING = ",".join(DEFAULT_MODULE_LIST)
|
DEFAULT_MODULES_STRING = ",".join(DEFAULT_MODULE_LIST)
|
||||||
|
|
||||||
|
|
||||||
def main() -> None:
|
def main() -> None:
|
||||||
|
random.seed(42)
|
||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"-o",
|
"-o",
|
||||||
|
@ -89,9 +93,11 @@ def main() -> None:
|
||||||
# Print a newline to visually separate different episodes.
|
# Print a newline to visually separate different episodes.
|
||||||
if idx != 1:
|
if idx != 1:
|
||||||
print()
|
print()
|
||||||
print("---| New Episode |---")
|
|
||||||
print("---------------------")
|
for ep in _episode_augmentations(episode):
|
||||||
print(episode)
|
print("---| New Episode |---")
|
||||||
|
print("---------------------")
|
||||||
|
print("\n---\n".join(ep + [PromptConstants.EOS_TOKEN]))
|
||||||
sys.exit()
|
sys.exit()
|
||||||
|
|
||||||
#
|
#
|
||||||
|
@ -118,10 +124,17 @@ def main() -> None:
|
||||||
# file.
|
# file.
|
||||||
for module in modules:
|
for module in modules:
|
||||||
for episode in module():
|
for episode in module():
|
||||||
if contains_suspect_unicode(episode):
|
text = "\n".join(episode)
|
||||||
print(f"Found suspect unicode contents in `{episode}`")
|
if contains_suspect_unicode(text):
|
||||||
json_line = json.dumps({"text": episode})
|
print(
|
||||||
output_file.write(f"{json_line}\n")
|
f"Skipping. Found suspect unicode contents in `{text}`")
|
||||||
|
continue
|
||||||
|
|
||||||
|
for augmented_episode in _episode_augmentations(episode):
|
||||||
|
text = "\n".join(augmented_episode +
|
||||||
|
[PromptConstants.EOS_TOKEN])
|
||||||
|
json_line = json.dumps({"text": text})
|
||||||
|
output_file.write(f"{json_line}\n")
|
||||||
|
|
||||||
|
|
||||||
#
|
#
|
||||||
|
@ -129,6 +142,52 @@ def main() -> None:
|
||||||
#
|
#
|
||||||
|
|
||||||
|
|
||||||
|
def _episode_augmentations(
|
||||||
|
episode: list[str]) -> t.Generator[list[str], None, None]:
|
||||||
|
'''
|
||||||
|
Generates augmented data for the given episode.
|
||||||
|
|
||||||
|
The first 1.3B model had wildly unpredictable performance at the start of
|
||||||
|
conversations, which I attributed to the fact that originally we always fed
|
||||||
|
the model entire episodes to train on, so there were no examples of freshly
|
||||||
|
started conversations, in a sense.
|
||||||
|
|
||||||
|
This function takes a complete episode and yields different permutations of
|
||||||
|
it in an attempt to provide that data (e.g. with/without persona, with only
|
||||||
|
X messages in the history, X+2, X+4 and so on).
|
||||||
|
'''
|
||||||
|
permutated_episode = []
|
||||||
|
offset_idx = 0
|
||||||
|
|
||||||
|
# Don't discard the original episode.
|
||||||
|
yield episode
|
||||||
|
|
||||||
|
for turn in episode:
|
||||||
|
if "'s Persona: " in turn or "Scenario: " in turn or PromptConstants.CHAT_START_TOKEN in turn:
|
||||||
|
permutated_episode.append(turn.strip())
|
||||||
|
offset_idx += 1
|
||||||
|
continue
|
||||||
|
|
||||||
|
while len(episode) > 1 + offset_idx:
|
||||||
|
permutated_episode.append(episode.pop(offset_idx))
|
||||||
|
permutated_episode.append(episode.pop(offset_idx))
|
||||||
|
|
||||||
|
# Yielding every single instance results in too much data
|
||||||
|
# repetition, so instead we take a random sample.
|
||||||
|
should_yield = random.randint(0, 100) < 25
|
||||||
|
if should_yield:
|
||||||
|
yield permutated_episode
|
||||||
|
|
||||||
|
# Also, yield a version with _just_ dialogue if we've been yielding
|
||||||
|
# with persona/scenario data this entire time.
|
||||||
|
if offset_idx == 0:
|
||||||
|
continue
|
||||||
|
|
||||||
|
should_yield = random.randint(0, 100) < 25
|
||||||
|
if should_yield:
|
||||||
|
yield permutated_episode[offset_idx:]
|
||||||
|
|
||||||
|
|
||||||
def _get_git_revision_short_hash() -> str:
|
def _get_git_revision_short_hash() -> str:
|
||||||
'''Returns the project's short git revision hash.'''
|
'''Returns the project's short git revision hash.'''
|
||||||
return subprocess.check_output(
|
return subprocess.check_output(
|
||||||
|
|
Loading…
Reference in New Issue