refactor: move non-data related stuff to other repositories in the org

This commit is contained in:
11b 2023-01-08 16:31:37 -03:00
parent 7d385ec13c
commit 23eb4a6ab2
19 changed files with 64 additions and 92 deletions

View File

@ -1,38 +1,27 @@
# 11b's /wAIfu/ Toolbox
# data-toolbox
**Note**: This is a _very_ early work-in-progress. Expect the unexpected.
This repository contains the implementation of our data munging code.
If you're interested in the project's current status, please take a look at the [ROADMAP](./ROADMAP.md) instead, or join the Matrix server for more frequent updates.
**Note:** Not very well documented at the moment. Still need to implement automatic downloading of data files and document how to install the project with PDM.
## Summary
## How does it work?
As of the moment I'm writing this, the roadmap for the project's prototype model is basically:
In short, it takes raw data from several different sources and parses it. From there, we can quickly experiment with different ways of formatting or augmenting the parsed data to generate a final representation, ready to be used as training data for our models.
- Build a dataset
- Fine-tune a pre-trained language model on that dataset
- Play around, observe behavior and identify what's subpar
- Adjust the dataset accordingly to try and address the relevant shortcomings
- Repeat.
The general data flow goes something like this:
This repository is where I'm versioning all the code I've written to accomplish the above.
In short, here's how it works:
- We start off with raw datasets (see [/waifu/datasets/](./waifu/datasets/)).
- We start off with raw datasets (see [./toolbox/datasets/](./toolbox/datasets/))
- These are basically classes responsible for giving us raw data. They might, for example, download a `.zip` off the internet, unzip it, read a `.json` file from inside and then return its contents.
- Modules then make use of these datasets ([/waifu/modules/](./waifu/modules/)).
- These are heavily inspired by the papers that introduced LaMDA and BlenderBot3 (and their relevant supporting papers as well).
- In general, each module is responsible for using a dataset as an input, and processing that data down into text that will be used in the fine-tuning process.
- A final data file is produced by concatenating the outputs of all the modules. This file is used as an input for the fine-tuning process.
- Modules then make use of these datasets ([./toolbox/modules/](./toolbox/modules/))
- These are heavily inspired by the papers that introduced LaMDA and BlenderBot3 (and their relevant supporting papers)
- In general, each module is responsible for using a dataset as an input, and processing that data down into episodes, which will then be formatted into a proper dataset to be used in the fine-tuning process.
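The dataset → module → concatenated-output flow described above can be sketched roughly like this (all class and variable names here are hypothetical simplifications; the real interfaces live under `./toolbox/`):

```python
import typing as t


class BaseDataset:
    """Yields raw data items, e.g. parsed from a downloaded `.json` file."""

    def __iter__(self) -> t.Iterator[dict]:
        # Hypothetical stand-in for a real dataset class, which would
        # download and parse files instead of hardcoding items.
        yield {"speaker": "Alice", "utterance": "Hello there!"}
        yield {"speaker": "Bob", "utterance": "Hi, Alice."}


class BaseModule:
    """Processes raw dataset items down into formatted training text."""

    def __init__(self, dataset: BaseDataset) -> None:
        self.dataset = dataset

    def generate(self) -> t.Iterator[str]:
        for item in self.dataset:
            yield f"{item['speaker']}: {item['utterance']}"


# The final training data is the concatenation of all module outputs.
modules = [BaseModule(BaseDataset())]
training_lines = [line for m in modules for line in m.generate()]
print(training_lines)
```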
Here's how I do that:
## Building a training dataset
## Building the data file(s)
The final data file is created with the [build_dataset.py](./waifu/scripts/build_dataset.py) script:
The final data file is created with the [build_dataset.py](./toolbox/scripts/build_dataset.py) script:
```
$ ./waifu/scripts/build_dataset.py --help
$ ./toolbox/scripts/build_dataset.py --help
usage: build_dataset.py [-h] [-o OUTPUT_NAME] [-m MODULES] [-p PRINT] [-v]
options:
@ -51,14 +40,14 @@ The default behavior is to write a file called `rev-{GIT_REVISION_HASH}-args{HAS
The script also has an option to print some examples instead of writing to a file, for debugging/dev purposes. Example usage:
```bash
$ ./waifu/scripts/build_dataset.py --print 1 --modules 'light_dialogue_pdm:LightDialoguePDM' # or -p 1 and -m ...
$ ./toolbox/scripts/build_dataset.py --print 1 --modules 'light_dialogue_pdm:LightDialoguePDM' # or -p 1 and -m ...
```
Example output:
```
--- new episode ---
Context: You are in the Watchtower.
Scenario: You are in the Watchtower.
The tower is the largest section of the castle. It contains an observatory for nighttime scouting, but is also used by the wise men to study the stars. Armed
guardsmen are always to be found keeping watch.
There's an alarm horn here.
@ -79,19 +68,3 @@ Soldier: Good idea, I agree, go do that *hug court wizard*
Court Wizard: Get off of me, you fool! Who gave you permission to touch me! *hit soldier*
Soldier: To the jail with you *hit court wizard*
```
## Fine-tuning a model
Due to hardware limitations (read: lack of GPUs with massive amounts of VRAM), I need to make use of ColossalAI's optimizations to be able to fine-tune models. However, their example code for fine-tuning OPT lacks some important functionality, notably metric logging (so we can know what is going on) and checkpoint saving/loading.
I've gone ahead and, using [their example scripts](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/opt) as a starting point, made a slightly adjusted version that's actually usable for real-world scenarios. All that stuff is inside the [training folder](./training/).
If you don't want to mess with anything, all you need to do is put the built data file at `/training/data/train.json` and invoke [finetune.bash](./training/finetune.bash). To see metrics, you can use Tensorboard by visiting http://localhost:6006 after starting the server like this:
```bash
tensorboard serve --port 6006 --logdir training/checkpoints/runs
```
## Running inference on the fine-tuned model
To-do: write this up.

View File

@ -1,11 +1,11 @@
[tool.pdm]
[project]
name = "wAIfu"
name = "toolbox"
version = "0.1.0"
description = "Code for the /wAIfu/ Collective project."
description = "Code for ingesting data from several sources, formatting it and creating a training dataset."
authors = [
{name = "0x000011b", email = "0x000011b@waifu.club"},
{name = "0x000011b", email = "0x000011b@proton.me"},
]
requires-python = ">=3.10"
license = {text = "AGPL-3.0-only"}
@ -31,13 +31,13 @@ debugging = [
]
[tool.setuptools]
py-modules = ["waifu"]
py-modules = ["toolbox"]
[tool.pdm.scripts]
lint = {shell = "pylint --jobs 0 ./waifu/**/*.py"}
importcheck = "isort --check --diff waifu"
stylecheck = "yapf --parallel --diff --recursive waifu"
typecheck = "mypy --strict waifu"
lint = {shell = "pylint --jobs 0 ./toolbox/**/*.py"}
importcheck = "isort --check --diff toolbox"
stylecheck = "yapf --parallel --diff --recursive toolbox"
typecheck = "mypy --strict toolbox"
[tool.yapf]
based_on_style = "google"

View File

@ -4,8 +4,8 @@ import os
import typing as t
from dataclasses import dataclass
from waifu.datasets import BaseDataset
from waifu.utils.dataset import get_data_path
from toolbox.datasets import BaseDataset
from toolbox.utils.dataset import get_data_path
logger = logging.getLogger(__name__)

View File

@ -5,9 +5,9 @@ import re
import typing as t
from dataclasses import dataclass
from waifu.core.consts import PromptConstants
from waifu.datasets import BaseDataset
from waifu.utils.dataset import get_data_path
from toolbox.core.consts import PromptConstants
from toolbox.datasets import BaseDataset
from toolbox.utils.dataset import get_data_path
# The regex used to find message variants (e.g.: `%{Hi|Hello} there!`)
KAJIWOTO_VARIANT_REGEX = re.compile(r'%{(.+?)}')
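As a hypothetical sketch of how `generate_variants_for` could expand these variant groups (the real implementation lives in `toolbox.datasets.kajiwoto` and may differ), one approach is to split on the regex and take the cartesian product of the options:

```python
import itertools
import re

# Same pattern as in the module: matches variant groups like `%{Hi|Hello}`.
KAJIWOTO_VARIANT_REGEX = re.compile(r'%{(.+?)}')


def generate_variants_for(text: str) -> list[str]:
    # Splitting on a pattern with a capturing group keeps the captured
    # contents at the odd indices of the result.
    parts = KAJIWOTO_VARIANT_REGEX.split(text)
    choices = [p.split("|") if i % 2 else [p] for i, p in enumerate(parts)]
    return ["".join(combo) for combo in itertools.product(*choices)]


print(generate_variants_for("%{Hi|Hello} there!"))
# → ['Hi there!', 'Hello there!']
```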

View File

@ -5,8 +5,8 @@ from dataclasses import dataclass
import mashumaro
from waifu.datasets import BaseDataset
from waifu.utils.dataset import get_data_path
from toolbox.datasets import BaseDataset
from toolbox.utils.dataset import get_data_path
@dataclass(frozen=True)

View File

@ -6,8 +6,9 @@ from dataclasses import dataclass
import mashumaro
import pandas as pd
from waifu.datasets import BaseDataset
from waifu.utils.dataset import get_data_path
from toolbox.datasets import BaseDataset
from toolbox.utils.dataset import get_data_path
@dataclass(frozen=True)
class SodaEpisode(mashumaro.DataClassDictMixin):
@ -39,5 +40,3 @@ class SodaDataset(BaseDataset[SodaEpisode]):
relation=df['relation'][i],
literal=df['literal'][i]
)

View File

@ -2,9 +2,9 @@ import logging
import re
import typing as t
from waifu.core.consts import PromptConstants
from waifu.datasets.characterai import CharacterAiDataset
from waifu.modules import BaseModule
from toolbox.core.consts import PromptConstants
from toolbox.datasets.characterai import CharacterAiDataset
from toolbox.modules import BaseModule
logger = logging.getLogger(__name__)

View File

@ -17,8 +17,8 @@ import re
import sqlite3
import typing as t
from waifu.modules import BaseModule
from waifu.utils.dataset import get_data_path
from toolbox.modules import BaseModule
from toolbox.utils.dataset import get_data_path
# Matches user mentions, channel links, emotes and maybe other stuff.
SPECIAL_TOKENS_REGEX = re.compile(r"<[@:#].+?>")
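As an illustration, the pattern above could be applied to clean a message before it goes into the training data (`strip_special_tokens` is a hypothetical helper, not necessarily how the module uses the regex):

```python
import re

# Same pattern as in the module: matches user mentions, channel links and
# emotes, e.g. `<@123456>`, `<#general>` or `<:smile:789>`.
SPECIAL_TOKENS_REGEX = re.compile(r"<[@:#].+?>")


def strip_special_tokens(text: str) -> str:
    # Remove every special token, then trim leading/trailing whitespace.
    return SPECIAL_TOKENS_REGEX.sub("", text).strip()


print(strip_special_tokens("hey <@123456> check <#general> out <:smile:789>"))
```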

View File

@ -1,10 +1,10 @@
import typing as t
from waifu.core.consts import PromptConstants
from waifu.datasets.kajiwoto import (KajiwotoDataset, generate_variants_for,
from toolbox.core.consts import PromptConstants
from toolbox.datasets.kajiwoto import (KajiwotoDataset, generate_variants_for,
replace_special_tokens_in)
from waifu.modules import BaseModule
from waifu.utils.strings import uppercase
from toolbox.modules import BaseModule
from toolbox.utils.strings import uppercase
class KajiwotoPDM(BaseModule):

View File

@ -1,9 +1,9 @@
import typing as t
from waifu.core.consts import PromptConstants
from waifu.datasets.kajiwoto import (KajiwotoDataset, generate_variants_for,
from toolbox.core.consts import PromptConstants
from toolbox.datasets.kajiwoto import (KajiwotoDataset, generate_variants_for,
replace_special_tokens_in)
from waifu.modules import BaseModule
from toolbox.modules import BaseModule
class KajiwotoVDM(BaseModule):

View File

@ -1,9 +1,9 @@
import typing as t
from waifu.core.consts import PromptConstants
from waifu.datasets.light_dialogue import LightDialogueDataset
from waifu.modules import BaseModule
from waifu.utils.strings import normalize_string, title_case
from toolbox.core.consts import PromptConstants
from toolbox.datasets.light_dialogue import LightDialogueDataset
from toolbox.modules import BaseModule
from toolbox.utils.strings import normalize_string, title_case
class LightDialoguePDM(BaseModule):

View File

@ -1,8 +1,8 @@
import typing as t
from waifu.core.consts import PromptConstants
from waifu.datasets.soda import SodaDataset
from waifu.modules import BaseModule
from toolbox.core.consts import PromptConstants
from toolbox.datasets.soda import SodaDataset
from toolbox.modules import BaseModule
class SodaPDM(BaseModule):
@ -13,20 +13,20 @@ class SodaPDM(BaseModule):
episode_messages = []
# NOTE(TG): We determine which order the speakers go on based on whether the relation is xAttr or not.
# This is because some speakers are more abstract concepts rather than concrete names,
# which would make them much more suitable as a bot
# which would make them much more suitable as a bot
if episode.relation == "xAttr":
bot_name = episode.speakers[0]
user_name = episode.speakers[1]
else:
user_name = episode.speakers[0]
bot_name = episode.speakers[1]
# First, we would want to set the persona.
# However, the only acceptable description of a persona would be when episode.relation is "xAttr", since that directly describes
# a person in the conversation.
if episode.relation == "xAttr":
episode_messages.append(f"{PromptConstants.pdm_prefix_for(bot_name)}: {episode.literal}")
# Next, set the scenario.
# Make sure to replace any instance of the person representing the user in the conversation with the user token
replaced_narrative = episode.narrative.replace(user_name, PromptConstants.USER_TOKEN)
@ -34,7 +34,7 @@ class SodaPDM(BaseModule):
episode_messages.append(scenario)
# Next, the start token
episode_messages.append(PromptConstants.CHAT_START_TOKEN)
# I am going to assume that the length of episode.speakers is the same as the length of episode.dialogue
# Looked pretty clean to me in the data. Fuck it, TODO: account for the possibility of a length mismatch
for i, utterance in enumerate(episode.dialogue):
@ -44,5 +44,5 @@ class SodaPDM(BaseModule):
else:
name = bot_name
episode_messages.append(f"{name}: {utterance.replace(user_name, PromptConstants.USER_TOKEN)}")
yield episode_messages
yield episode_messages

View File

@ -10,9 +10,9 @@ import subprocess
import sys
import typing as t
from waifu.core.consts import PromptConstants
from waifu.modules import BaseModule
from waifu.utils.strings import contains_suspect_unicode
from toolbox.core.consts import PromptConstants
from toolbox.modules import BaseModule
from toolbox.utils.strings import contains_suspect_unicode
# TODO(11b): Needs manual maintenance to keep up-to-date. Consider doing some
# metaprogramming trickery to build this list out instead.
@ -66,7 +66,7 @@ def main() -> None:
level=logging.DEBUG if args.verbose else logging.INFO,
)
# Sanity check.
# Sanity checks.
if args.output_name and args.print:
raise Exception("--output-name and --print are mutually exclusive.")
if args.skip and not args.print:
@ -200,10 +200,10 @@ def _import_modules_from_string(string: str) -> t.List[t.Type[BaseModule]]:
'''Imports all the module classes from the given, comma-separated string.'''
modules: t.List[t.Type[BaseModule]] = []
for module_and_class_name in string.split(","):
qualified_module_name = "waifu.modules"
qualified_module_name = "toolbox.modules"
try:
module_name, class_name = module_and_class_name.split(":")
qualified_module_name = f"waifu.modules.{module_name}"
qualified_module_name = f"toolbox.modules.{module_name}"
except ValueError:
class_name = module_and_class_name