#!/usr/bin/env python3
import argparse
import hashlib
import importlib
import json
import logging
import os
import random
import subprocess
import sys
import typing as t

from toolbox.core.consts import PromptConstants
from toolbox.modules import BaseModule
from toolbox.utils.strings import contains_suspect_unicode
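
# Each entry is a `<module file>:<class name>` spec that is resolved against
# the `toolbox.modules` package by `_import_modules_from_string` below, e.g.
# "characterai_pdm:CharacterAiPDM" imports
# `toolbox.modules.characterai_pdm.CharacterAiPDM`.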
# TODO(11b): Needs manual maintenance to keep up-to-date. Consider doing some
# metaprogramming trickery to build this list out instead.
DEFAULT_MODULE_LIST = [
    "characterai_pdm:CharacterAiPDM",
    # "discord_vdm:DiscordVDM",
    # KajiwotoPDM has a bunch of garbage I need to filter, disabling in favor
    # of the vanilla dialogue module for now.
    # "kajiwoto_pdm:KajiwotoPDM",
    # "kajiwoto_vdm:KajiwotoVDM",
    # "light_dialogue_pdm:LightDialoguePDM",
]
DEFAULT_MODULES_STRING = ",".join(DEFAULT_MODULE_LIST)


def main() -> None:
    random.seed(42)

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-o",
        "--output-name",
        help="Path to write to. Should not include a file extension.")
    parser.add_argument("-m",
                        "--modules",
                        default=DEFAULT_MODULES_STRING,
                        help="List of modules to use, comma-separated.")
    parser.add_argument(
        "-p",
        "--print",
        type=int,
        help="If given, print this many episodes instead of writing to a file.")
    parser.add_argument(
        "-s",
        "--skip",
        type=int,
        help="If given, skip over this many episodes before printing.")
    parser.add_argument("-v",
                        "--verbose",
                        action="store_true",
                        help="Enable verbose logging.")
    args = parser.parse_args()

    logging.basicConfig(
        format='[%(asctime)s] [%(levelname)s] %(message)s',
        level=logging.DEBUG if args.verbose else logging.INFO,
    )
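
    # Illustrative usage (the output path here is hypothetical):
    #   ./build_dataset.py --print 3 --skip 10
    #   ./build_dataset.py --modules "characterai_pdm:CharacterAiPDM" \
    #       --output-name /tmp/training-data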

    # Sanity checks.
    if args.output_name and args.print:
        raise Exception("--output-name and --print are mutually exclusive.")
    if args.skip and not args.print:
        raise Exception("--skip can only be used in conjunction with --print.")

    modules = _import_modules_from_string(args.modules)

    #
    # If the print argument was specified, print and exit.
    #
    if args.print:
        idx = 0
        episodes_to_skip = args.skip if args.skip is not None else None
        for module in modules:
            for episode in module():
                if episodes_to_skip:
                    episodes_to_skip -= 1
                    continue

                idx += 1
                if idx > args.print:
                    sys.exit()

                # Print a newline to visually separate different episodes.
                if idx != 1:
                    print()

                for ep in _episode_augmentations(episode):
                    print("---| New Episode |---")
                    print("---------------------")
                    print("\n---\n".join(ep + [PromptConstants.EOS_TOKEN]))
        sys.exit()

    #
    # Otherwise, proceed with the writing logic.
    #

    # If no output name is given, we build one from the current git revision
    # plus a hash of the given arguments. That way, the same dataset should
    # theoretically always have the same output name, which is helpful for
    # reproducibility and bailing out early (e.g. if the file already exists).
    if args.output_name is None:
        args_hash = hashlib.sha256(str(args).encode("utf-8")).hexdigest()[:7]
        output_name = f"rev-{_get_git_revision_short_hash()}-args-{args_hash}"
    else:
        output_name = args.output_name
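
    # `output_name` then looks something like `rev-1a2b3c4-args-5d6e7f8`
    # (hash values here are purely illustrative).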

    # Open the output file.
    output_filename = f"{output_name}.jsonl"
    if os.path.exists(output_filename):
        raise Exception(f"{output_filename} already exists, aborting.")

    with open(output_filename, "w", encoding="utf-8") as output_file:
        # Iterate over each module sequentially, and write the data out into
        # the file.
        for module in modules:
            for episode in module():
                text = "\n".join(episode)
                if contains_suspect_unicode(text):
                    print(
                        f"Skipping. Found suspect unicode contents in `{text}`")
                    continue

                for augmented_episode in _episode_augmentations(episode):
                    text = "\n".join(augmented_episode +
                                     [PromptConstants.EOS_TOKEN])
                    json_line = json.dumps({"text": text})
                    output_file.write(f"{json_line}\n")


#
# Helpers and CLI entrypoint.
#


def _episode_augmentations(
        episode: list[str]) -> t.Generator[list[str], None, None]:
    '''
    Generates augmented data for the given episode.

    The first 1.3B model had wildly unpredictable performance at the start of
    conversations, which I attributed to the fact that originally we always fed
    the model entire episodes to train on, so there were no examples of freshly
    started conversations, in a sense.

    This function takes a complete episode and yields different permutations of
    it in an attempt to provide that data (e.g. with/without persona, with only
    X messages in the history, X+2, X+4 and so on).
    '''
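    # NOTE: `permutated_episode` below is yielded by reference while it keeps
    # growing, so consumers should use each yielded list before advancing the
    # generator. As an illustrative sketch, for a hypothetical episode
    # [persona, scenario, m1, m2, m3, m4] (and assuming every 25% sample
    # hits), the yields would be:
    #   [persona, scenario, m1, m2, m3, m4]  (the untouched original)
    #   [persona, scenario, m1, m2]
    #   [m1, m2]                             (dialogue-only variant)
    #   [persona, scenario, m1, m2, m3, m4]
    #   [m1, m2, m3, m4]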
    permutated_episode = []
    offset_idx = 0

    # Don't discard the original episode.
    yield episode

    for turn in episode:
        if ("'s Persona: " in turn or "Scenario: " in turn
                or PromptConstants.CHAT_START_TOKEN in turn):
            permutated_episode.append(turn.strip())
            offset_idx += 1
            continue

        while len(episode) > 1 + offset_idx:
            permutated_episode.append(episode.pop(offset_idx))
            permutated_episode.append(episode.pop(offset_idx))

            # Yielding every single instance results in too much data
            # repetition, so instead we take a random sample.
            should_yield = random.randint(0, 100) < 25
            if should_yield:
                yield permutated_episode

            # Also, yield a version with _just_ dialogue if we've been
            # yielding with persona/scenario data this entire time.
            if offset_idx == 0:
                continue

            should_yield = random.randint(0, 100) < 25
            if should_yield:
                yield permutated_episode[offset_idx:]


def _get_git_revision_short_hash() -> str:
    '''Returns the project's short git revision hash.'''
    return subprocess.check_output(
        ["git", "rev-parse", "--short", "HEAD"],
        cwd=os.path.join(os.path.dirname(os.path.realpath(__file__)), "..",
                         "..")).decode("ascii").strip()


def _import_modules_from_string(string: str) -> t.List[t.Type[BaseModule]]:
    '''Imports all the module classes from the given, comma-separated string.'''
    modules: t.List[t.Type[BaseModule]] = []
    for module_and_class_name in string.split(","):
        qualified_module_name = "toolbox.modules"
        try:
            module_name, class_name = module_and_class_name.split(":")
            qualified_module_name = f"toolbox.modules.{module_name}"
        except ValueError:
            class_name = module_and_class_name

        module = importlib.import_module(qualified_module_name)
        modules.append(getattr(module, class_name))
    return modules


if __name__ == "__main__":
    main()