From 6fbd660a6740440b1a6066a7276f2b1278cb69b9 Mon Sep 17 00:00:00 2001 From: 0x000011b <0x000011b@waifu.club> Date: Sat, 17 Dec 2022 21:37:27 -0300 Subject: [PATCH] feat: implement script to build final data file --- README.md | 61 +++++++++++++++ waifu/scripts/build_dataset.py | 136 +++++++++++++++++++++++++++++++++ 2 files changed, 197 insertions(+) create mode 100755 waifu/scripts/build_dataset.py diff --git a/README.md b/README.md index f59cd88..01df31b 100644 --- a/README.md +++ b/README.md @@ -2,6 +2,8 @@ **Note**: This is a _very_ early work-in-progress. Expect the unexpected. +## Summary + As of the moment I'm writing this, the roadmap for the project's prototype model is basically: - Build a dataset @@ -20,3 +22,62 @@ In short, here's how it works: - These are heavily inspired by the papers that introduced LaMDA and BlenderBot3 (and their relevant supporting papers as well). - In general, each module is responsible for using a dataset as an input, and processing that data down into text that will be used in the fine-tuning process. - A final data file is produced by concatenating the outputs of all the modules. This file is used as an input for the fine-tuning process. + +Here's how I do that: + +## Building the data file(s) + +The final data file is created with the [build_dataset.py](/waifu/scripts/build_dataset.py) script: + +``` +$ ./waifu/scripts/build_dataset.py --help +usage: build_dataset.py [-h] [-o OUTPUT_NAME] [-m MODULES] [-p PRINT] [-v] + +options: + -h, --help show this help message and exit + -o OUTPUT_NAME, --output-name OUTPUT_NAME + File to write the dataset to. Should not include a file extension. + -m MODULES, --modules MODULES + List of modules to use, comma-separated. + -p PRINT, --print PRINT + If given, print this many episodes instead of writing out to a file. + -v, --verbose Enable verbose logging. 
+``` + +The default behavior is to write a file called `rev-{GIT_REVISION_HASH}-args-{HASH_OF_USED_ARGS}.jsonl` to the current directory, with all the modules enabled. Behavior can be customized via the flags shown above. + +The script also has an option to print some examples instead of writing to a file, for debugging/dev purposes. Example usage: + +```bash +$ ./waifu/scripts/build_dataset.py --print 1 --modules 'light_dialogue_vdm:LightDialogueVDM' # or -p 1 and -m ... +``` + +Example output: + +``` +--- new episode --- +Context: You are in the Watchtower. +The tower is the largest section of the castle. It contains an observatory for nighttime scouting, but is also used by the wise men to study the stars. Armed +guardsmen are always to be found keeping watch. +There's an alarm horn here. +A soldier is here. You are carrying nothing. + +Court Wizard: A quiet night this evening... +Soldier: Yes it is +Court Wizard: *ponder* Have any else come up this eve? I had hoped for a quiet night to examine the stars +Soldier: *nod* Yes, a few came through, but it is a cold night for me, I am used to warmer weather +Court Wizard: *sigh* Well, you are but a common soldier. No doubt you are used to such a lot. Thankfully I have my spells to keep me warm. +Soldier: *grin* I am a soldier doing my job +Court Wizard: Yes... Well... Very well then. See that you do! No slacking off while your betters are about. +Soldier: No sir +Court Wizard: When, for example, was this horn last tested? It looks dented. How can we be sure it will work? +Soldier: A year ago, test it out or cause a need to use it +Court Wizard: *frown* Mayhap I will speak to the king about such lackness. Or perhaps I can sell him a spell that will serve just as well. +Soldier: Good idea, I agree, go do that *hug court wizard* +Court Wizard: Get off of me, you fool! Who gave you permission to touch me! *hit soldier* +Soldier: To the jail with you *hit court wizard* +``` + +## Fine-tuning a model + +To-do. 
# Patch context preserved from the collapsed diff: this script was preceded by
# the tail of the README.md hunk ("I haven't documented this yet.") and by the
# following new-file header:
#   diff --git a/waifu/scripts/build_dataset.py b/waifu/scripts/build_dataset.py
#   new file mode 100755
#   index 0000000..46a6e99
#   --- /dev/null
#   +++ b/waifu/scripts/build_dataset.py
#   @@ -0,0 +1,136 @@
#!/usr/bin/env python3
import argparse
import hashlib
import importlib
import json
import os
import subprocess
import sys
import typing as t
import logging

from waifu.modules import BaseModule

# TODO(11b): Needs manual maintenance to keep up-to-date. Consider doing some
# metaprogramming trickery to build this list out instead.
DEFAULT_MODULE_LIST = [
    "kajiwoto_pdm:KajiwotoPDM",
    "kajiwoto_vdm:KajiwotoVDM",
    "light_dialogue_vdm:LightDialogueVDM",
]
DEFAULT_MODULES_STRING = ",".join(DEFAULT_MODULE_LIST)


def main() -> None:
    '''Parses the CLI arguments, then either prints or writes out episodes.

    With `--print N`, at most N episodes are printed to stdout and the process
    exits without writing anything to disk. Otherwise, every episode yielded
    by every selected module is written as one JSON object per line to
    `{output_name}.jsonl`.

    Raises:
        Exception: if `--output-name` and `--print` are both given, or if the
            output file already exists.
        SystemExit: always raised at the end of print mode (via `sys.exit()`).
    '''
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-o",
        "--output-name",
        help="Path to write to. Should not include a file extension.")

    parser.add_argument("-m",
                        "--modules",
                        default=DEFAULT_MODULES_STRING,
                        help="List of modules to use, comma-separated.")

    parser.add_argument(
        "-p",
        "--print",
        type=int,
        help="If given, print this many episodes instead of writing to a file.")

    parser.add_argument("-v",
                        "--verbose",
                        action="store_true",
                        help="Enable verbose logging.")

    args = parser.parse_args()

    logging.basicConfig(
        format='[%(asctime)s] [%(levelname)s] %(message)s',
        level=logging.DEBUG if args.verbose else logging.INFO,
    )

    # Compare against None instead of relying on truthiness, so that
    # `--print 0` is still recognized as print mode rather than silently
    # turning into "write a file".
    print_requested = args.print is not None

    # Sanity check.
    if args.output_name is not None and print_requested:
        raise Exception("--output-name and --print are mutually exclusive.")

    modules = _import_modules_from_string(args.modules)

    #
    # If the print argument was specified, print and exit.
    #
    if print_requested:
        _print_episodes(modules, args.print)
        # Exit unconditionally: if the modules yield fewer episodes than were
        # requested, we must still not fall through to the writing logic.
        sys.exit()

    #
    # Otherwise, proceed with the writing logic.
    #

    # If no output name is given, we build one from the current git revision
    # plus a hash of the given arguments. That way, the same dataset should
    # theoretically always have the same output name, which is helpful for
    # reproducibility and bailing out early (e.g. if the file already exists).
    #
    # NOTE: the hash covers the string form of _every_ parsed argument, so
    # flags such as --verbose also influence the generated name.
    if args.output_name is None:
        args_hash = hashlib.sha256(str(args).encode("utf-8")).hexdigest()[:7]
        output_name = f"rev-{_get_git_revision_short_hash()}-args-{args_hash}"
    else:
        output_name = args.output_name

    _write_episodes(modules, f"{output_name}.jsonl")


def _print_episodes(modules: t.List[t.Type[BaseModule]], limit: int) -> None:
    '''Prints at most `limit` episodes from the given modules to stdout.

    Modules are instantiated and iterated in order. Iteration stops as soon as
    `limit` episodes have been printed, so no episode beyond the limit is
    generated. A `limit` of zero or less prints nothing.
    '''
    if limit <= 0:
        return

    printed = 0
    for module in modules:
        for episode in module():
            # Print a newline to visually separate different episodes.
            if printed:
                print()
            print("--- new episode ---")
            print(episode)

            printed += 1
            if printed >= limit:
                return


def _write_episodes(modules: t.List[t.Type[BaseModule]],
                    output_filename: str) -> None:
    '''Writes every episode from the given modules to `output_filename`.

    Each episode becomes one line holding a JSON object of the form
    `{"text": episode}`.

    Raises:
        Exception: if `output_filename` already exists.
    '''
    # Exclusive-creation mode makes "does it exist?" and "create it" a single
    # step, so an existing file can never be overwritten by accident.
    try:
        output_file = open(output_filename, "x", encoding="utf-8")
    except FileExistsError as err:
        raise Exception(f"{output_filename} already exists, aborting.") from err

    with output_file:
        # Iterate over each module sequentially, and write the data out into
        # the file.
        for module in modules:
            for episode in module():
                json_line = json.dumps({"text": episode})
                output_file.write(f"{json_line}\n")


#
# Helpers and CLI entrypoint.
#


def _get_git_revision_short_hash() -> str:
    '''Returns the project's short git revision hash.

    Runs `git rev-parse --short HEAD` from the directory two levels above the
    one containing this script.

    Raises:
        subprocess.CalledProcessError: if the git command exits with a
            non-zero status.
    '''
    script_dir = os.path.dirname(os.path.realpath(__file__))
    # The script lives in `waifu/scripts/`, so two levels up is the project
    # root.
    project_root = os.path.join(script_dir, "..", "..")
    raw_hash = subprocess.check_output(
        ["git", "rev-parse", "--short", "HEAD"], cwd=project_root)
    return raw_hash.decode("ascii").strip()


def _import_modules_from_string(string: str) -> "t.List[t.Type[BaseModule]]":
    '''Imports all the module classes from the given, comma-separated string.

    Each entry is either `module_name:ClassName`, which is looked up as
    attribute `ClassName` of `waifu.modules.{module_name}`, or a bare
    `ClassName`, which is looked up directly on the `waifu.modules` package.
    Whitespace around entries and around their parts is ignored, and empty
    entries (e.g. from a trailing comma) are skipped.

    Raises:
        ValueError: if an entry contains more than one colon, or if the string
            does not name any module at all.
        ImportError: if a referenced module cannot be imported.
        AttributeError: if the imported module has no attribute with the
            requested class name.
    '''
    modules: t.List[t.Type[BaseModule]] = []
    for raw_entry in string.split(","):
        entry = raw_entry.strip()
        if not entry:
            continue

        parts = [part.strip() for part in entry.split(":")]
        if len(parts) == 1:
            # Bare class name: resolve it on the package itself.
            qualified_module_name = "waifu.modules"
            class_name = parts[0]
        elif len(parts) == 2:
            module_name, class_name = parts
            qualified_module_name = f"waifu.modules.{module_name}"
        else:
            raise ValueError(
                f"Malformed module entry {entry!r}: expected"
                " 'module_name:ClassName' or 'ClassName'.")

        module = importlib.import_module(qualified_module_name)
        modules.append(getattr(module, class_name))

    if not modules:
        raise ValueError("No modules were specified.")

    return modules


if __name__ == "__main__":
    main()