refactor: move non-data related stuff to other repositories in the org
parent 7d385ec13c
commit 23eb4a6ab2
README.md | 57
@@ -1,38 +1,27 @@
-# 11b's /wAIfu/ Toolbox
+# data-toolbox
 
-**Note**: This is a _very_ early work-in-progress. Expect the unexpected.
+This repository contains the implementation of our data munging code.
 
-If you're interested in the project's current status, please take a look at the [ROADMAP](./ROADMAP.md) instead, or join the Matrix server for more frequent updates.
+**Note:** Not very well documented at the moment. Still need to implement automatic downloading of data files and document how to install the project with PDM.
 
-## Summary
+## How does it work?
 
-As of the moment I'm writing this, the roadmap for the project's prototype model is basically:
+In short, it takes raw data from several different sources and parses it. From there, we can quickly experiment with different ways of formatting or augmenting the parsed data to generate a final representation, ready to be used as training data for our models.
 
-- Build a dataset
+The general data flow goes something like this:
-- Fine-tune a pre-trained language model on that dataset
-- Play around, observe behavior and identify what's subpar
-- Adjust dataset accordingly as to try and address the relevant shortcomings
-- Repeat.
 
-This repository is where I'm versioning all the code I've written to accomplish the above.
+- We start off with raw datasets (see [./toolbox/datasets/](./toolbox/datasets/))
-
-In short, here's how it works:
-
-- We start off with raw datasets (see [/waifu/datasets/](./waifu/datasets/)).
   - These are basically classes responsible for giving us raw data. They might, for example, download a `.zip` off the internet, unzip it, read a `.json` file from in there and then return its contents.
-- Modules then make use of these datasets ([/waifu/modules/](./waifu/modules/)).
+- Modules then make use of these datasets ([./toolbox/modules/](./toolbox/modules/))
-  - These are heavily inspired by the papers that introduced LaMDA and BlenderBot3 (and their relevant supporting papers as well).
+  - These are heavily inspired by the papers that introduced LaMDA and BlenderBot3 (and their relevant supporting papers)
-  - In general, each module is responsible for using a dataset as an input, and processing that data down into text that will be used in the fine-tuning process.
+  - In general, each module is responsible for using a dataset as an input, and processing that data down into episodes, which will then be formatted into a proper dataset to be used in the fine-tuning process.
-- A final data file is produced by concatenating the outputs of all the modules. This file is used as an input for the fine-tuning process.
 
-Here's how I do that:
+## Building a training dataset
 
-## Building the data file(s)
+The final data file is created with the [build_dataset.py](./toolbox/scripts/build_dataset.py) script:
 
-The final data file is created with the [build_dataset.py](./waifu/scripts/build_dataset.py) script:
-
 ```
-$ ./waifu/scripts/build_dataset.py --help
+$ ./toolbox/scripts/build_dataset.py --help
 usage: build_dataset.py [-h] [-o OUTPUT_NAME] [-m MODULES] [-p PRINT] [-v]
 
 options:
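Aside: the new README describes the dataset and module split only in prose, so here is a minimal runnable sketch of that pattern. The class names `GreetingsDataset` and `GreetingsModule` are hypothetical; only the dataset/module split and the generator-of-episodes idea come from the repository itself, and the real `BaseDataset`/`BaseModule` interfaces may differ.

```python
import typing as t

class GreetingsDataset:
    '''A "dataset": knows how to fetch and parse one raw data source.'''

    def generator(self) -> t.Generator[dict, None, None]:
        # A real dataset might download a .zip, unzip it and read a
        # .json file from inside; here we just yield hardcoded rows.
        yield {"speaker": "Guard", "message": "Halt! Who goes there?"}
        yield {"speaker": "Traveler", "message": "Just a humble merchant."}

class GreetingsModule:
    '''A "module": turns a dataset's raw rows into text episodes.'''

    def generator(self) -> t.Generator[t.List[str], None, None]:
        episode = [f"{row['speaker']}: {row['message']}"
                   for row in GreetingsDataset().generator()]
        yield episode

# The build script then concatenates every module's episodes into the
# final training data file.
for episode in GreetingsModule().generator():
    print("\n".join(episode))
```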
@@ -51,14 +40,14 @@ The default behavior is to write a file called `rev-{GIT_REVISION_HASH}-args{HAS
 The script also has an option to print some examples instead of writing to a file, for debugging/dev purposes. Example usage:
 
 ```bash
-$ ./waifu/scripts/build_dataset.py --print 1 --modules 'light_dialogue_pdm:LightDialoguePDM' # or -p 1 and -m ...
+$ ./toolbox/scripts/build_dataset.py --print 1 --modules 'light_dialogue_pdm:LightDialoguePDM' # or -p 1 and -m ...
 ```
 
 Example output:
 
 ```
 --- new episode ---
-Context: You are in the Watchtower.
+Scenario: You are in the Watchtower.
 The tower is the largest section of the castle. It contains an observatory for nighttime scouting, but is also used by the wise men to study the stars. Armed
 guardsmen are always to be found keeping watch.
 There's an alarm horn here.
@@ -79,19 +68,3 @@ Soldier: Good idea, I agree, go do that *hug court wizard*
 Court Wizard: Get off of me, you fool! Who gave you permission to touch me! *hit soldier*
 Soldier: To the jail with you *hit court wizard*
 ```
-
-## Fine-tuning a model
-
-Due to hardware limitations (read: lack of GPUs with massive amounts of VRAM), I need to make use of ColossalAI's optimizations to be able to fine-tune models. However, their example code for fine-tuning OPT lacks some important stuff. Notably: metric logging (so we can know what is going on) and checkpoint saving/loading.
-
-I've gone ahead and, using [their example scripts](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/opt) as a starting point, made a slightly adjusted version that's actually usable for real-world scenarios. All that stuff is inside the [training folder](./training/).
-
-If you don't want to mess with anything, all you need to do is put the built data file at `/training/data/train.json` and invoke [finetune.bash](./training/finetune.bash). To see metrics, you can use Tensorboard by visiting http://localhost:6006 after starting the server like this:
-
-```bash
-tensorboard serve --port 6006 --logdir training/checkpoints/runs
-```
-
-## Running inference on the fine-tuned model
-
-To-do: write this up.
pyproject.toml

@@ -1,11 +1,11 @@
 [tool.pdm]
 
 [project]
-name = "wAIfu"
+name = "toolbox"
 version = "0.1.0"
-description = "Code for the /wAIfu/ Collective project."
+description = "Code for ingesting data from several sources, formatting it and creating a training dataset."
 authors = [
-    {name = "0x000011b", email = "0x000011b@waifu.club"},
+    {name = "0x000011b", email = "0x000011b@proton.me"},
 ]
 requires-python = ">=3.10"
 license = {text = "AGPL-3.0-only"}
@@ -31,13 +31,13 @@ debugging = [
 ]
 
 [tool.setuptools]
-py-modules = ["waifu"]
+py-modules = ["toolbox"]
 
 [tool.pdm.scripts]
-lint = {shell = "pylint --jobs 0 ./waifu/**/*.py"}
+lint = {shell = "pylint --jobs 0 ./toolbox/**/*.py"}
-importcheck = "isort --check --diff waifu"
+importcheck = "isort --check --diff toolbox"
-stylecheck = "yapf --parallel --diff --recursive waifu"
+stylecheck = "yapf --parallel --diff --recursive toolbox"
-typecheck = "mypy --strict waifu"
+typecheck = "mypy --strict toolbox"
 
 [tool.yapf]
 based_on_style = "google"
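Since these entries live under `[tool.pdm.scripts]`, they should be invocable as `pdm run lint`, `pdm run importcheck`, `pdm run stylecheck` and `pdm run typecheck` after the rename; that is standard PDM behavior, though how the repository actually wires them into CI is not shown in this diff.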
@@ -4,8 +4,8 @@ import os
 import typing as t
 from dataclasses import dataclass
 
-from waifu.datasets import BaseDataset
+from toolbox.datasets import BaseDataset
-from waifu.utils.dataset import get_data_path
+from toolbox.utils.dataset import get_data_path
 
 logger = logging.getLogger(__name__)
@@ -5,9 +5,9 @@ import re
 import typing as t
 from dataclasses import dataclass
 
-from waifu.core.consts import PromptConstants
+from toolbox.core.consts import PromptConstants
-from waifu.datasets import BaseDataset
+from toolbox.datasets import BaseDataset
-from waifu.utils.dataset import get_data_path
+from toolbox.utils.dataset import get_data_path
 
 # The regex used to find message variants (e.g.: `%{Hi|Hello} there!`)
 KAJIWOTO_VARIANT_REGEX = re.compile(r'%{(.+?)}')
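The hunk above only shows the variant regex. As a sketch of what variant expansion might look like, here is a hypothetical re-implementation: `expand_variants` is a stand-in for the `generate_variants_for` helper imported elsewhere in this diff, whose real behavior may differ.

```python
import itertools
import re

# Same pattern as in the diff: finds `%{Hi|Hello}`-style variant groups.
KAJIWOTO_VARIANT_REGEX = re.compile(r'%{(.+?)}')

def expand_variants(message: str) -> list[str]:
    '''Expands every `%{a|b}` group into all concrete messages.'''
    groups = KAJIWOTO_VARIANT_REGEX.findall(message)
    if not groups:
        return [message]

    # One list of options per variant group, in order of appearance.
    options = [group.split("|") for group in groups]

    variants = []
    for combo in itertools.product(*options):
        variant = message
        for choice in combo:
            # Replace the first remaining variant group with this choice.
            variant = KAJIWOTO_VARIANT_REGEX.sub(choice, variant, count=1)
        variants.append(variant)
    return variants

print(expand_variants("%{Hi|Hello} there, %{friend|pal}!"))
# ['Hi there, friend!', 'Hi there, pal!',
#  'Hello there, friend!', 'Hello there, pal!']
```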
@@ -5,8 +5,8 @@ from dataclasses import dataclass
 
 import mashumaro
 
-from waifu.datasets import BaseDataset
+from toolbox.datasets import BaseDataset
-from waifu.utils.dataset import get_data_path
+from toolbox.utils.dataset import get_data_path
 
 
 @dataclass(frozen=True)
@@ -6,8 +6,9 @@ from dataclasses import dataclass
 import mashumaro
 import pandas as pd
 
-from waifu.datasets import BaseDataset
+from toolbox.datasets import BaseDataset
-from waifu.utils.dataset import get_data_path
+from toolbox.utils.dataset import get_data_path
 
 
 @dataclass(frozen=True)
 class SodaEpisode(mashumaro.DataClassDictMixin):
@@ -39,5 +40,3 @@ class SodaDataset(BaseDataset[SodaEpisode]):
                 relation=df['relation'][i],
                 literal=df['literal'][i]
             )
-
-
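To make the SodaDataset hunks concrete, here is a trimmed, hypothetical version of the pattern they imply: a frozen mashumaro dataclass built row by row from a pandas dataframe. The field names match those visible in the diff; everything else is an assumption.

```python
from dataclasses import dataclass

import mashumaro
import pandas as pd

@dataclass(frozen=True)
class Episode(mashumaro.DataClassDictMixin):
    narrative: str
    relation: str
    literal: str

def generate_episodes(df: pd.DataFrame):
    '''Yields one Episode per dataframe row, mirroring the indexing
    style (df['relation'][i]) visible in the diff.'''
    for i in range(len(df)):
        yield Episode(
            narrative=df['narrative'][i],
            relation=df['relation'][i],
            literal=df['literal'][i],
        )

df = pd.DataFrame({
    'narrative': ['You are in the Watchtower.'],
    'relation': ['xAttr'],
    'literal': ['brave'],
})
episode = next(generate_episodes(df))
print(episode.to_dict())  # DataClassDictMixin provides to_dict()/from_dict()
```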
@@ -2,9 +2,9 @@ import logging
 import re
 import typing as t
 
-from waifu.core.consts import PromptConstants
+from toolbox.core.consts import PromptConstants
-from waifu.datasets.characterai import CharacterAiDataset
+from toolbox.datasets.characterai import CharacterAiDataset
-from waifu.modules import BaseModule
+from toolbox.modules import BaseModule
 
 logger = logging.getLogger(__name__)
@@ -17,8 +17,8 @@ import re
 import sqlite3
 import typing as t
 
-from waifu.modules import BaseModule
+from toolbox.modules import BaseModule
-from waifu.utils.dataset import get_data_path
+from toolbox.utils.dataset import get_data_path
 
 # Matches user mentions, channel links, emotes and maybe other stuff.
 SPECIAL_TOKENS_REGEX = re.compile(r"<[@:#].+?>")
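For context, here is a quick illustration of what that regex matches. The cleanup step shown (plain removal plus whitespace collapsing) is an assumption for the sketch, not necessarily what the module actually does.

```python
import re

# Same pattern as in the diff: Discord-style special tokens.
SPECIAL_TOKENS_REGEX = re.compile(r"<[@:#].+?>")

message = "hey <@123456789> check <#announcements> <:pogchamp:987654321>"

# Strip the tokens, then collapse the leftover whitespace.
cleaned = SPECIAL_TOKENS_REGEX.sub("", message)
cleaned = re.sub(r"\s+", " ", cleaned).strip()

print(cleaned)  # "hey check"
```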
@@ -1,10 +1,10 @@
 import typing as t
 
-from waifu.core.consts import PromptConstants
+from toolbox.core.consts import PromptConstants
-from waifu.datasets.kajiwoto import (KajiwotoDataset, generate_variants_for,
+from toolbox.datasets.kajiwoto import (KajiwotoDataset, generate_variants_for,
                                         replace_special_tokens_in)
-from waifu.modules import BaseModule
+from toolbox.modules import BaseModule
-from waifu.utils.strings import uppercase
+from toolbox.utils.strings import uppercase
 
 
 class KajiwotoPDM(BaseModule):
@@ -1,9 +1,9 @@
 import typing as t
 
-from waifu.core.consts import PromptConstants
+from toolbox.core.consts import PromptConstants
-from waifu.datasets.kajiwoto import (KajiwotoDataset, generate_variants_for,
+from toolbox.datasets.kajiwoto import (KajiwotoDataset, generate_variants_for,
                                         replace_special_tokens_in)
-from waifu.modules import BaseModule
+from toolbox.modules import BaseModule
 
 
 class KajiwotoVDM(BaseModule):
@@ -1,9 +1,9 @@
 import typing as t
 
-from waifu.core.consts import PromptConstants
+from toolbox.core.consts import PromptConstants
-from waifu.datasets.light_dialogue import LightDialogueDataset
+from toolbox.datasets.light_dialogue import LightDialogueDataset
-from waifu.modules import BaseModule
+from toolbox.modules import BaseModule
-from waifu.utils.strings import normalize_string, title_case
+from toolbox.utils.strings import normalize_string, title_case
 
 
 class LightDialoguePDM(BaseModule):
@@ -1,8 +1,8 @@
 import typing as t
 
-from waifu.core.consts import PromptConstants
+from toolbox.core.consts import PromptConstants
-from waifu.datasets.soda import SodaDataset
+from toolbox.datasets.soda import SodaDataset
-from waifu.modules import BaseModule
+from toolbox.modules import BaseModule
 
 
 class SodaPDM(BaseModule):
@@ -13,20 +13,20 @@ class SodaPDM(BaseModule):
         episode_messages = []
         # NOTE(TG): We determine which order the speakers go on based on whether the relation is xAttr or not.
         # This is because some speakers are more abstract concepts rather than concrete names,
         # which would make them much more suitable as a bot
         if episode.relation == "xAttr":
             bot_name = episode.speakers[0]
             user_name = episode.speakers[1]
         else:
             user_name = episode.speakers[0]
             bot_name = episode.speakers[1]
 
         # First, we would want to set the persona.
         # However, the only acceptable description of a persona would be when episode.relation is "xAttr", since that directly describes
         # a person in the conversation.
         if episode.relation == "xAttr":
             episode_messages.append(f"{PromptConstants.pdm_prefix_for(bot_name)}: {episode.literal}")
 
         # Next, set the scenario.
         # Make sure to replace any instance of the person representing the user in the conversation with the user token
         replaced_narrative = episode.narrative.replace(user_name, PromptConstants.USER_TOKEN)
@@ -34,7 +34,7 @@ class SodaPDM(BaseModule):
         episode_messages.append(scenario)
         # Next, the start token
         episode_messages.append(PromptConstants.CHAT_START_TOKEN)
 
         # I am going to assume that the length of episode.speakers is the same as the length of episode.dialogue
         # Looked pretty clean to me in the data. Fuck it, TODO: account for the possibility of that happening
         for i, utterance in enumerate(episode.dialogue):
@@ -44,5 +44,5 @@ class SodaPDM(BaseModule):
             else:
                 name = bot_name
             episode_messages.append(f"{name}: {utterance.replace(user_name, PromptConstants.USER_TOKEN)}")
 
         yield episode_messages
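Putting the SodaPDM pieces together, a toy episode might format like this. All concrete token values here are assumptions (the real ones live in `PromptConstants`), and the `Scenario:` prefix is inferred from the example output in the README diff above.

```python
# Assumed PromptConstants values, for illustration only:
#   USER_TOKEN = "<USER>", CHAT_START_TOKEN = "<START>",
#   pdm_prefix_for(name) -> f"{name}'s Persona"
#
# Input episode (relation == "xAttr", so speakers[0] is treated as the bot):
#   speakers  = ["Cornelia", "Jaxson"]
#   literal   = "Cornelia is brave."
#   narrative = "Jaxson was cornered by the guards."
#   dialogue  = ["Stay behind me, Jaxson!", "Thank you, Cornelia."]
#
# Plausible yielded episode_messages:
#   "Cornelia's Persona: Cornelia is brave."
#   "Scenario: <USER> was cornered by the guards."
#   "<START>"
#   "Cornelia: Stay behind me, <USER>!"
#   "Jaxson: Thank you, Cornelia."
```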
toolbox/scripts/build_dataset.py

@@ -10,9 +10,9 @@ import subprocess
 import sys
 import typing as t
 
-from waifu.core.consts import PromptConstants
+from toolbox.core.consts import PromptConstants
-from waifu.modules import BaseModule
+from toolbox.modules import BaseModule
-from waifu.utils.strings import contains_suspect_unicode
+from toolbox.utils.strings import contains_suspect_unicode
 
 # TODO(11b): Needs manual maintenance to keep up-to-date. Consider doing some
 # metaprogramming trickery to build this list out instead.
@@ -66,7 +66,7 @@ def main() -> None:
         level=logging.DEBUG if args.verbose else logging.INFO,
     )
 
-    # Sanity check.
+    # Sanity checks.
     if args.output_name and args.print:
         raise Exception("--output-name and --print are mutually exclusive.")
     if args.skip and not args.print:
@@ -200,10 +200,10 @@ def _import_modules_from_string(string: str) -> t.List[t.Type[BaseModule]]:
     '''Imports all the module classes from the given, comma-separated string.'''
     modules: t.List[t.Type[BaseModule]] = []
     for module_and_class_name in string.split(","):
-        qualified_module_name = "waifu.modules"
+        qualified_module_name = "toolbox.modules"
         try:
             module_name, class_name = module_and_class_name.split(":")
-            qualified_module_name = f"waifu.modules.{module_name}"
+            qualified_module_name = f"toolbox.modules.{module_name}"
         except ValueError:
             class_name = module_and_class_name
 
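The hunk ends before the actual import happens. As a sketch, the remainder of this helper plausibly resolves the class with importlib, along these lines (a hypothetical completion; the real script's import and error handling are not shown in the diff):

```python
import importlib
import typing as t

def import_module_class(module_and_class_name: str) -> t.Type:
    '''Resolves strings like "light_dialogue_pdm:LightDialoguePDM" into a
    class, mirroring the parsing visible in _import_modules_from_string.'''
    qualified_module_name = "toolbox.modules"
    try:
        module_name, class_name = module_and_class_name.split(":")
        qualified_module_name = f"toolbox.modules.{module_name}"
    except ValueError:
        # No ":" present, so treat the whole string as a class name
        # exported from toolbox.modules itself.
        class_name = module_and_class_name

    module = importlib.import_module(qualified_module_name)
    return getattr(module, class_name)
```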