feat: implement script to build final data file

This commit is contained in:
11b 2022-12-17 21:37:27 -03:00
parent 8df2d87355
commit 6fbd660a67
2 changed files with 197 additions and 0 deletions

View File

@ -2,6 +2,8 @@
**Note**: This is a _very_ early work-in-progress. Expect the unexpected.
## Summary
As of the moment I'm writing this, the roadmap for the project's prototype model is basically:
- Build a dataset
@ -20,3 +22,62 @@ In short, here's how it works:
- These are heavily inspired by the papers that introduced LaMDA and BlenderBot3 (and their relevant supporting papers as well).
- In general, each module is responsible for using a dataset as an input, and processing that data down into text that will be used in the fine-tuning process.
- A final data file is produced by concatenating the outputs of all the modules. This file is used as an input for the fine-tuning process.
Here's how I do that:
## Building the data file(s)
The final data file is created with the [build_dataset.py](/waifu/scripts/build_dataset.py) script:
```
$ ./waifu/scripts/build_dataset.py --help
usage: build_dataset.py [-h] [-o OUTPUT_NAME] [-m MODULES] [-p PRINT] [-v]
options:
-h, --help show this help message and exit
-o OUTPUT_NAME, --output-name OUTPUT_NAME
File to write the dataset to. Should not include a file extension.
-m MODULES, --modules MODULES
List of modules to use, comma-separated.
-p PRINT, --print PRINT
If given, print this many episodes instead of writing out to a file.
-v, --verbose Enable verbose logging.
```
The default behavior is to write a file called `rev-{GIT_REVISION_HASH}-args{HASH_OF_USED_ARGS}.jsonl` to the current directory, with all the modules enabled. Behavior can be customized via the flags shown above.
The script also has an option to print some examples instead of writing to a file, for debugging/dev purposes. Example usage:
```bash
$ ./waifu/scripts/build_dataset.py --print 1 --modules 'light_dialogue_vdm:LightDialogueVDM' # or -p 1 and -m ...
```
Example output:
```
--- new episode ---
Context: You are in the Watchtower.
The tower is the largest section of the castle. It contains an observatory for nighttime scouting, but is also used by the wise men to study the stars. Armed
guardsmen are always to be found keeping watch.
There's an alarm horn here.
A soldier is here. You are carrying nothing.
Court Wizard: A quiet night this evening...
Soldier: Yes it is
Court Wizard: *ponder* Have any else come up this eve? I had hoped for a quiet night to examine the stars
Soldier: *nod* Yes, a few came through, but it is a cold night for me, I am used to warmer weather
Court Wizard: *sigh* Well, you are but a common soldier. No doubt you are used to such a lot. Thankfully I have my spells to keep me warm.
Soldier: *grin* I am a soldier doing my job
Court Wizard: Yes... Well... Very well then. See that you do! No slacking off while your betters are about.
Soldier: No sir
Court Wizard: When, for example, was this horn last tested? It looks dented. How can we be sure it will work?
Soldier: A year ago, test it out or cause a need to use it
Court Wizard: *frown* Mayhap I will speak to the king about such lackness. Or perhaps I can sell him a spell that will serve just as well.
Soldier: Good idea, I agree, go do that *hug court wizard*
Court Wizard: Get off of me, you fool! Who gave you permission to touch me! *hit soldier*
Soldier: To the jail with you *hit court wizard*
```
## Fine-tuning a model
To-do. I haven't documented this yet.

136
waifu/scripts/build_dataset.py Executable file
View File

@ -0,0 +1,136 @@
#!/usr/bin/env python3
import argparse
import hashlib
import importlib
import json
import os
import subprocess
import sys
import typing as t
import logging
from waifu.modules import BaseModule
# TODO(11b): Needs manual maintenance ot keep up-to-date. Consider doing some
# metaprogramming trickery to build this list out instead.
DEFAULT_MODULE_LIST = [
"kajiwoto_pdm:KajiwotoPDM",
"kajiwoto_vdm:KajiwotoVDM",
"light_dialogue_vdm:LightDialogueVDM",
]
DEFAULT_MODULES_STRING = ",".join(DEFAULT_MODULE_LIST)
def main() -> None:
parser = argparse.ArgumentParser()
parser.add_argument(
"-o",
"--output-name",
help="Path to write to. Should not include a file extension.")
parser.add_argument("-m",
"--modules",
default=DEFAULT_MODULES_STRING,
help="List of modules to use, comma-separated.")
parser.add_argument(
"-p",
"--print",
type=int,
help="If given, print this many episodes instead of writing to a file.")
parser.add_argument("-v",
"--verbose",
action="store_true",
help="Enable verbose logging.")
args = parser.parse_args()
logging.basicConfig(
format='[%(asctime)s] [%(levelname)s] %(message)s',
level=logging.DEBUG if args.verbose else logging.INFO,
)
# Sanity check.
if args.output_name and args.print:
raise Exception("--output-name and --print are mutually exclusive.")
modules = _import_modules_from_string(args.modules)
#
# If the print argument was specified, print and exit.
#
if args.print:
idx = 0
for module in modules:
for episode in module():
idx += 1
if idx > args.print:
sys.exit()
# Print a newline to visually separate different episodes.
if idx != 1:
print()
print("--- new episode ---")
print(episode)
#
# Otherwise, proceed with the writing logic.
#
# If no output name is given, we build one from the current git revision
# plus a hash of the given arguments. That way, the same dataset should
# theoretically always have the same output name, which is helpful for
# reproducibility and bailing out early (e.g. if the file already exists).
if args.output_name is None:
args_hash = hashlib.sha256(str(args).encode("utf-8")).hexdigest()[:7]
output_name = f"rev-{_get_git_revision_short_hash()}-args-{args_hash}"
else:
output_name = args.output_name
# Open the output file.
output_filename = f"{output_name}.jsonl"
if os.path.exists(output_filename):
raise Exception(f"{output_filename} already exists, aborting.")
with open(output_filename, "w", encoding="utf-8") as output_file:
# Iterate over each module sequentially, and write the data out into the
# file.
for module in modules:
for episode in module():
json_line = json.dumps({"text": episode})
output_file.write(f"{json_line}\n")
#
# Helpers and CLI entrypoint.
#
def _get_git_revision_short_hash() -> str:
'''Returns the project's short git revision hash.'''
return subprocess.check_output(
["git", "rev-parse", "--short", "HEAD"],
cwd=os.path.join(os.path.dirname(os.path.realpath(__file__)), "..",
"..")).decode("ascii").strip()
def _import_modules_from_string(string: str) -> t.List[t.Type[BaseModule]]:
'''Imports all the module classes from the given, comma-separated string.'''
modules: t.List[t.Type[BaseModule]] = []
for module_and_class_name in string.split(","):
qualified_module_name = "waifu.modules"
try:
module_name, class_name = module_and_class_name.split(":")
qualified_module_name = f"waifu.modules.{module_name}"
except ValueError:
class_name = module_and_class_name
module = importlib.import_module(qualified_module_name)
modules.append(getattr(module, class_name))
return modules
if __name__ == "__main__":
main()