refactor: move non-data related stuff to other repositories in the org
parent 7d385ec13c
commit 23eb4a6ab2

README.md | 57
@@ -1,38 +1,27 @@
-# 11b's /wAIfu/ Toolbox
+# data-toolbox

 **Note**: This is a _very_ early work-in-progress. Expect the unexpected.

+This repository contains the implementation of our data munging code.
-If you're interested in the project's current status, please take a look at the [ROADMAP](./ROADMAP.md) instead, or join the Matrix server for more frequent updates.
+**Note:** Not very well documented at the moment. Still need to implement automatic downloading of data files and document how to install the project with PDM.

-## Summary
+## How does it work?

-As of the moment I'm writing this, the roadmap for the project's prototype model is basically:
+In short, it takes raw data from several different sources and parses it. From there, we can quickly experiment with different ways of formatting or augmenting the parsed data to generate a final representation, ready to be used as training data for our models.

-- Build a dataset
-- Fine-tune a pre-trained language model on that dataset
-- Play around, observe behavior and identify what's subpar
-- Adjust dataset accordingly as to try and address the relevant shortcomings
-- Repeat.
+The general data flow goes something like this:

-This repository is where I'm versioning all the code I've written to accomplish the above.
-In short, here's how it works:

-- We start off with raw datasets (see [/waifu/datasets/](./waifu/datasets/)).
+- We start off with raw datasets (see [./toolbox/datasets/](./toolbox/datasets/))
   - These are basically classes reponsible for giving us raw data. They might, for example, download a `.zip` off the internet, unzip it, read a `.json` file from in there and then return its contents.
-- Modules then make use of these datasets ([/waifu/modules/](./waifu/modules/)).
-  - These are heavily inspired by the papers that introduced LaMDA and BlenderBot3 (and their relevant supporting papers as well).
-  - In general, each module is responsible for using a dataset as an input, and processing that data down into text that will be used in the fine-tuning process.
-- A final data file is produced by concatenating the outputs of all the modules. This file is used as an input for the fine-tuning process.
+- Modules then make use of these datasets ([./toolbox/modules/](./toolbox/modules/))
+  - These are heavily inspired by the papers that introduced LaMDA and BlenderBot3 (and their relevant supporting papers)
+  - In general, each module is responsible for using a dataset as an input, and processing that data down into episodes, which will then be formatted into a proper dataset to be used in the fine-tuning process.

-Here's how I do that:
+## Building a training dataset
-## Building the data file(s)

-The final data file is created with the [build_dataset.py](./waifu/scripts/build_dataset.py) script:
+The final data file is created with the [build_dataset.py](./toolbox/scripts/build_dataset.py) script:

 ```
-$ ./waifu/scripts/build_dataset.py --help
+$ ./toolbox/scripts/build_dataset.py --help
 usage: build_dataset.py [-h] [-o OUTPUT_NAME] [-m MODULES] [-p PRINT] [-v]

 options:
@@ -51,14 +40,14 @@ The default behavior is to write a file called `rev-{GIT_REVISION_HASH}-args{HAS
 The script also has an option to print some examples instead of writing to a file, for debugging/dev purposes. Example usage:

 ```bash
-$ ./waifu/scripts/build_dataset.py --print 1 --modules 'light_dialogue_pdm:LightDialoguePDM' # or -p 1 and -m ...
+$ ./toolbox/scripts/build_dataset.py --print 1 --modules 'light_dialogue_pdm:LightDialoguePDM' # or -p 1 and -m ...
 ```

 Example output:

 ```
 --- new episode ---
-Context: You are in the Watchtower.
+Scenario: You are in the Watchtower.
 The tower is the largest section of the castle. It contains an observatory for nighttime scouting, but is also used by the wise men to study the stars. Armed
 guardsmen are always to be found keeping watch.
 There's an alarm horn here.
@@ -79,19 +68,3 @@ Soldier: Good idea, I agree, go do that *hug court wizard*
 Court Wizard: Get off of me, you fool! Who gave you permission to touch me! *hit soldier*
 Soldier: To the jail with you *hit court wizard*
 ```
-
-## Fine-tuning a model
-
-Due to hardware limitations (read: lack of GPUs with massive amounts of VRAM), I need to make use of ColossalAI's optimizations to be able to fine-tune models. However, their example code for fine-tuning OPT lacks some important stuff. Notably: metric logging (so we can know what is going on) and checkpoint saving/loading.
-
-I've gone ahead and, using [their example scripts](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/opt) as a starting point, made a slightly adjusted version that's actually usable for real-world scenarios. All that stuff is inside the [training folder](./training/).
-
-If you don't want to mess with anything, all you need to do is put the built data file at `/training/data/train.json` and invoke [finetune.bash](./training/finetune.bash). To see metrics, you can use Tensorboard by visiting http://localhost:6006 after starting the server like this:
-
-```bash
-tensorboard serve --port 6006 --logdir training/checkpoints/runs
-```
-
-## Running inference on the fine-tuned model
-
-To-do: write this up.
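The README's new "How does it work?" section describes a two-layer pipeline: datasets hand out raw records, modules turn them into episodes, and a build script concatenates everything into one training file. Below is a rough sketch of how those two layers might fit together; the real `BaseDataset`/`BaseModule` interfaces appear in the Python diffs further down, but their exact signatures aren't visible here, so the class and method names in this sketch are illustrative only.

```python
import json
import typing as t
from pathlib import Path


# Hypothetical dataset: hands out raw records, one per line of a .jsonl file.
class JsonLinesDataset:
    def __init__(self, path: str) -> None:
        self.path = Path(path)

    def __iter__(self) -> t.Iterator[dict]:
        with self.path.open(encoding="utf-8") as file:
            for line in file:
                yield json.loads(line)


# Hypothetical module: turns raw records into "episodes" (lists of strings),
# which a build script could then concatenate into the final training file.
class ExampleDialogueModule:
    def __init__(self, dataset: JsonLinesDataset) -> None:
        self.dataset = dataset

    def generate(self) -> t.Iterator[t.List[str]]:
        for record in self.dataset:
            episode = [f"{turn['speaker']}: {turn['text']}" for turn in record["turns"]]
            yield episode
```

A build script would then simply chain `generate()` over every configured module and write the resulting episodes out to a single file.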
@@ -1,11 +1,11 @@
 [tool.pdm]

 [project]
-name = "wAIfu"
+name = "toolbox"
 version = "0.1.0"
-description = "Code for the /wAIfu/ Collective project."
+description = "Code for ingesting data from several sources, formatting it and creating a training dataset."
 authors = [
-    {name = "0x000011b", email = "0x000011b@waifu.club"},
+    {name = "0x000011b", email = "0x000011b@proton.me"},
 ]
 requires-python = ">=3.10"
 license = {text = "AGPL-3.0-only"}
@@ -31,13 +31,13 @@ debugging = [
 ]

 [tool.setuptools]
-py-modules = ["waifu"]
+py-modules = ["toolbox"]

 [tool.pdm.scripts]
-lint = {shell = "pylint --jobs 0 ./waifu/**/*.py"}
-importcheck = "isort --check --diff waifu"
-stylecheck = "yapf --parallel --diff --recursive waifu"
-typecheck = "mypy --strict waifu"
+lint = {shell = "pylint --jobs 0 ./toolbox/**/*.py"}
+importcheck = "isort --check --diff toolbox"
+stylecheck = "yapf --parallel --diff --recursive toolbox"
+typecheck = "mypy --strict toolbox"

 [tool.yapf]
 based_on_style = "google"
@@ -4,8 +4,8 @@ import os
 import typing as t
 from dataclasses import dataclass

-from waifu.datasets import BaseDataset
-from waifu.utils.dataset import get_data_path
+from toolbox.datasets import BaseDataset
+from toolbox.utils.dataset import get_data_path

 logger = logging.getLogger(__name__)
@@ -5,9 +5,9 @@ import re
 import typing as t
 from dataclasses import dataclass

-from waifu.core.consts import PromptConstants
-from waifu.datasets import BaseDataset
-from waifu.utils.dataset import get_data_path
+from toolbox.core.consts import PromptConstants
+from toolbox.datasets import BaseDataset
+from toolbox.utils.dataset import get_data_path

 # The regex used to find message variants (e.g.: `%{Hi|Hello} there!`)
 KAJIWOTO_VARIANT_REGEX = re.compile(r'%{(.+?)}')
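The `%{...}` variant syntax mentioned in the comment above is easiest to see with a small example. This diff imports a `generate_variants_for` helper elsewhere, but its implementation isn't shown here, so the following is only a sketch of how such an expansion could behave, not the project's actual code.

```python
import itertools
import re

# Same pattern as in the hunk above: captures the contents of `%{...}` groups.
KAJIWOTO_VARIANT_REGEX = re.compile(r'%{(.+?)}')


def expand_variants(message: str) -> list[str]:
    # `re.split` with a capturing group alternates literal text (even indices)
    # and captured group contents (odd indices).
    parts = KAJIWOTO_VARIANT_REGEX.split(message)
    choices = [
        part.split("|") if idx % 2 == 1 else [part]
        for idx, part in enumerate(parts)
    ]
    # Cartesian product of every option in every group.
    return ["".join(combo) for combo in itertools.product(*choices)]


print(expand_variants("%{Hi|Hello} there!"))
# ['Hi there!', 'Hello there!']
```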
@@ -5,8 +5,8 @@ from dataclasses import dataclass
 import mashumaro

-from waifu.datasets import BaseDataset
-from waifu.utils.dataset import get_data_path
+from toolbox.datasets import BaseDataset
+from toolbox.utils.dataset import get_data_path


 @dataclass(frozen=True)
@@ -6,8 +6,9 @@ from dataclasses import dataclass
 import mashumaro
 import pandas as pd

-from waifu.datasets import BaseDataset
-from waifu.utils.dataset import get_data_path
+from toolbox.datasets import BaseDataset
+from toolbox.utils.dataset import get_data_path


 @dataclass(frozen=True)
 class SodaEpisode(mashumaro.DataClassDictMixin):
@@ -39,5 +40,3 @@ class SodaDataset(BaseDataset[SodaEpisode]):
                 relation=df['relation'][i],
                 literal=df['literal'][i]
             )
@@ -2,9 +2,9 @@ import logging
 import re
 import typing as t

-from waifu.core.consts import PromptConstants
-from waifu.datasets.characterai import CharacterAiDataset
-from waifu.modules import BaseModule
+from toolbox.core.consts import PromptConstants
+from toolbox.datasets.characterai import CharacterAiDataset
+from toolbox.modules import BaseModule

 logger = logging.getLogger(__name__)
@@ -17,8 +17,8 @@ import re
 import sqlite3
 import typing as t

-from waifu.modules import BaseModule
-from waifu.utils.dataset import get_data_path
+from toolbox.modules import BaseModule
+from toolbox.utils.dataset import get_data_path

 # Matches user mentions, channel links, emotes and maybe other stuff.
 SPECIAL_TOKENS_REGEX = re.compile(r"<[@:#].+?>")
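To show what that pattern actually strips out, here is a small illustration with a made-up message (not taken from the repository's data):

```python
import re

SPECIAL_TOKENS_REGEX = re.compile(r"<[@:#].+?>")

# Mentions (<@id>), channel links (<#name>) and emotes (<:name:id>) all match.
message = "hey <@123456789> meet me in <#town-square> <:wave:55555>"
print(SPECIAL_TOKENS_REGEX.sub("", message))
# 'hey  meet me in  ' (the leftover double spaces would still need separate cleanup)
```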
@@ -1,10 +1,10 @@
 import typing as t

-from waifu.core.consts import PromptConstants
-from waifu.datasets.kajiwoto import (KajiwotoDataset, generate_variants_for,
+from toolbox.core.consts import PromptConstants
+from toolbox.datasets.kajiwoto import (KajiwotoDataset, generate_variants_for,
                                        replace_special_tokens_in)
-from waifu.modules import BaseModule
-from waifu.utils.strings import uppercase
+from toolbox.modules import BaseModule
+from toolbox.utils.strings import uppercase


 class KajiwotoPDM(BaseModule):
@@ -1,9 +1,9 @@
 import typing as t

-from waifu.core.consts import PromptConstants
-from waifu.datasets.kajiwoto import (KajiwotoDataset, generate_variants_for,
+from toolbox.core.consts import PromptConstants
+from toolbox.datasets.kajiwoto import (KajiwotoDataset, generate_variants_for,
                                        replace_special_tokens_in)
-from waifu.modules import BaseModule
+from toolbox.modules import BaseModule


 class KajiwotoVDM(BaseModule):
@@ -1,9 +1,9 @@
 import typing as t

-from waifu.core.consts import PromptConstants
-from waifu.datasets.light_dialogue import LightDialogueDataset
-from waifu.modules import BaseModule
-from waifu.utils.strings import normalize_string, title_case
+from toolbox.core.consts import PromptConstants
+from toolbox.datasets.light_dialogue import LightDialogueDataset
+from toolbox.modules import BaseModule
+from toolbox.utils.strings import normalize_string, title_case


 class LightDialoguePDM(BaseModule):
@@ -1,8 +1,8 @@
 import typing as t

-from waifu.core.consts import PromptConstants
-from waifu.datasets.soda import SodaDataset
-from waifu.modules import BaseModule
+from toolbox.core.consts import PromptConstants
+from toolbox.datasets.soda import SodaDataset
+from toolbox.modules import BaseModule


 class SodaPDM(BaseModule):
@@ -13,20 +13,20 @@ class SodaPDM(BaseModule):
         episode_messages = []
         # NOTE(TG): We determine which order the speakers go on based on whether the relation is xAttr or not.
         # This is because some speakers are more abstract concepts rather than concrete names,
-        # which would make them much more suitable as a bot
+        # which would make them much more suitable as a bot
         if episode.relation == "xAttr":
             bot_name = episode.speakers[0]
             user_name = episode.speakers[1]
         else:
             user_name = episode.speakers[0]
             bot_name = episode.speakers[1]

         # First, we would want to set the persona.
         # However, the only acceptable description of a persona would be when episode.relation is "xAttr", since that directly describes
         # a person in the conversation.
         if episode.relation == "xAttr":
             episode_messages.append(f"{PromptConstants.pdm_prefix_for(bot_name)}: {episode.literal}")

         # Next, set the scenario.
         # Make sure to replace any instance of the person representing the user in the conversation with the user token
         replaced_narrative = episode.narrative.replace(user_name, PromptConstants.USER_TOKEN)
@@ -34,7 +34,7 @@ class SodaPDM(BaseModule):
         episode_messages.append(scenario)
         # Next, the start token
         episode_messages.append(PromptConstants.CHAT_START_TOKEN)

         # I am going to assume that the length of episode.speakers is the same as the length of episode.dialogue
         # Looked pretty clean to me in the data. Fuck it, TODO: account for the possibility of that happening
         for i, utterance in enumerate(episode.dialogue):
@@ -44,5 +44,5 @@ class SodaPDM(BaseModule):
             else:
                 name = bot_name
             episode_messages.append(f"{name}: {utterance.replace(user_name, PromptConstants.USER_TOKEN)}")

-        yield episode_messages
+        yield episode_messages
@@ -10,9 +10,9 @@ import subprocess
 import sys
 import typing as t

-from waifu.core.consts import PromptConstants
-from waifu.modules import BaseModule
-from waifu.utils.strings import contains_suspect_unicode
+from toolbox.core.consts import PromptConstants
+from toolbox.modules import BaseModule
+from toolbox.utils.strings import contains_suspect_unicode

 # TODO(11b): Needs manual maintenance to keep up-to-date. Consider doing some
 # metaprogramming trickery to build this list out instead.
@@ -66,7 +66,7 @@ def main() -> None:
         level=logging.DEBUG if args.verbose else logging.INFO,
     )

-    # Sanity check.
+    # Sanity checks.
    if args.output_name and args.print:
        raise Exception("--output-name and --print are mutually exclusive.")
    if args.skip and not args.print:
@@ -200,10 +200,10 @@ def _import_modules_from_string(string: str) -> t.List[t.Type[BaseModule]]:
     '''Imports all the module classes from the given, comma-separated string.'''
     modules: t.List[t.Type[BaseModule]] = []
     for module_and_class_name in string.split(","):
-        qualified_module_name = "waifu.modules"
+        qualified_module_name = "toolbox.modules"
        try:
            module_name, class_name = module_and_class_name.split(":")
-            qualified_module_name = f"waifu.modules.{module_name}"
+            qualified_module_name = f"toolbox.modules.{module_name}"
        except ValueError:
            class_name = module_and_class_name
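The hunk above only builds the qualified module name for the `--modules` flag; the actual import call sits outside the visible context. A minimal standalone sketch of what that lookup plausibly does (assuming `importlib` underneath, which this diff doesn't confirm) might be:

```python
import importlib

# Hypothetical standalone version of the lookup build_dataset.py's --modules
# flag appears to perform: "light_dialogue_pdm:LightDialoguePDM" resolves to
# the LightDialoguePDM class inside toolbox.modules.light_dialogue_pdm.
def resolve_module_class(spec: str) -> type:
    qualified_module_name = "toolbox.modules"
    try:
        module_name, class_name = spec.split(":")
        qualified_module_name = f"toolbox.modules.{module_name}"
    except ValueError:
        # No ":" in the spec: assume the class lives directly in toolbox.modules.
        class_name = spec

    module = importlib.import_module(qualified_module_name)
    return getattr(module, class_name)
```

Usage would then look like `resolve_module_class("light_dialogue_pdm:LightDialoguePDM")`, matching the `--modules` example shown earlier in the README diff.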