feat: update CAI dataset/module to handle userscript dumps and use definitions
parent aef9289678 · commit e0552639fa
@@ -0,0 +1,15 @@
+class PromptConstants:
+    '''String constants related to prompt engineering.'''
+
+    # Prefix for user messages.
+    USER_PREFIX = "You"
+
+    # Global target word count. The word count is chosen in such a way that we
+    # can fit all the required prompt trickery into the model's input, but still
+    # leave enough space for the user's input message and the inference result.
+    TARGET_WORD_COUNT_PER_EPISODE = 1536
+
+    @staticmethod
+    def pdm_prefix_for(name: str) -> str:
+        '''Builds the Persona Dialogue Module prefix for a given `name`.'''
+        return f"{name}'s Persona"
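For context, a minimal sketch of how these constants are meant to be consumed, mirroring how the module below uses them (the bot name "Miku" and the persona text are made up for illustration):

from waifu.core.consts import PromptConstants

# Hypothetical values for illustration only.
persona_line = f"{PromptConstants.pdm_prefix_for('Miku')}: A cheerful virtual singer."
user_line = f"{PromptConstants.USER_PREFIX}: Hi there!"

print(persona_line)  # Miku's Persona: A cheerful virtual singer.
print(user_line)     # You: Hi there!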
@@ -3,45 +3,64 @@ import os
 import typing as t
 from dataclasses import dataclass
 
-import mashumaro
-
 from waifu.datasets import BaseDataset
 from waifu.utils.dataset import get_data_path
 
 
 @dataclass(frozen=True)
-class CaiBotInfo(mashumaro.DataClassDictMixin):
+class CaiBotInfo:
     name: str
     title: str
-    description: str
+    description: str | None
     greeting: str
 
+    # Optional because it might be private.
+    definitions: str | None
+
+    # Useful for when several bots have the same name - we can tell them apart
+    # by their external_id.
+    external_id: str
+
+    # There's also categories, but I'm ignoring them for now since I don't think
+    # they'll be of much use.
+
 
 @dataclass(frozen=True)
 class CaiChat:
     # First message is the bot's greeting, the one afterwards is the user.
     messages: t.List[str]
-    bot_info: CaiBotInfo
+    bot: CaiBotInfo
 
 
 class CharacterAiDataset(BaseDataset[CaiChat]):
     '''Dataset for CharacterAI dumps.'''
 
     def generator(self) -> t.Generator[CaiChat, None, None]:
-        for folder in _enumerate_bot_folders():
-            info_path = os.path.join(folder, "info.json")
-            histories_path = os.path.join(folder, "histories.json")
-
-            with open(info_path, "r", encoding="utf-8") as info_file, \
-                    open(histories_path, "r", encoding="utf-8") as histories_file:
-                info_json = json.load(info_file)
-                histories_json = json.load(histories_file)
-
-            bot_info = CaiBotInfo.from_dict(info_json["character"])
-
-            for history_dict in histories_json["histories"]:
+        bot_id_to_info_dict = {}
+
+        # Do a first run through all the files to load all the definitions and
+        # descriptions.
+        for data in _available_json_data():
+            if not _is_definition_data(data):
+                continue
+
+            bot_info = _bot_info_from_dict(data["character"])
+            bot_id_to_info_dict[bot_info.external_id] = bot_info
+
+        # Now do a second pass, to actually handle chat histories/messages.
+        for data in _available_json_data():
+            if _is_definition_data(data):
+                continue
+
+            # Prefer grabbing bot info from a Character Editor dump, if it
+            # exists. Fall back to public data otherwise.
+            bot_id = data["info"]["character"]["external_id"]
+            bot_info = bot_id_to_info_dict.get(
+                bot_id, _bot_info_from_dict(data["info"]["character"]))
+
+            for history_dict in data["histories"]["histories"]:
                 messages = _messages_from_dict(history_dict["msgs"])
-                yield CaiChat(bot_info=bot_info, messages=messages)
+                yield CaiChat(bot=bot_info, messages=messages)
 
 
 #
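A minimal usage sketch for the reworked dataset, assuming the dumps are already in place under the dataset's data folder (output shown is hypothetical):

from waifu.datasets.characterai import CharacterAiDataset

# Iterate over parsed chats; bot info from Character Editor dumps is
# merged in automatically by the two-pass generator above.
for chat in CharacterAiDataset():
    print(chat.bot.name, chat.bot.external_id)
    print(chat.messages[0])  # The bot's greeting always comes first.
    break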
@@ -49,22 +68,49 @@ class CharacterAiDataset(BaseDataset[CaiChat]):
 #
 
 
-def _enumerate_bot_folders() -> list[str]:
-    '''Returns a list of folders available in the CAI data folder.'''
-    dataset_path = get_data_path(dataset_name="test_characterai_dumps")
-    items = os.listdir(dataset_path)
+def _enumerate_json_files(root_path: str) -> list[str]:
+    '''Returns a list of files available in the given `root_path`.'''
+    items = os.listdir(root_path)
 
-    folders: list[str] = []
+    files: list[str] = []
     for item in items:
-        item_path = os.path.join(dataset_path, item)
-        if os.path.isfile(item_path):
-            # We only care about folders.
+        item_path = os.path.join(root_path, item)
+        if not os.path.isfile(item_path) or not item_path.endswith(".json"):
+            # We only care about JSON files.
             continue
 
-        absolute_folder_path = os.path.abspath(os.path.join(dataset_path, item))
-        folders.append(absolute_folder_path)
+        absolute_file_path = os.path.abspath(os.path.join(root_path, item))
+        files.append(absolute_file_path)
 
-    return folders
+    return files
+
+
+def _available_json_data() -> t.Generator[dict[str, t.Any], None, None]:
+    '''
+    Yields all available JSON data, parsed from the files in the CharacterAI
+    data folder.
+    '''
+    dataset_path = get_data_path(dataset_name="characterai")
+
+    for folder in ["public", "private"]:
+        folder_path = os.path.join(dataset_path, folder)
+        for json_file_path in _enumerate_json_files(folder_path):
+            with open(json_file_path, "r", encoding="utf-8") as json_file:
+                yield json.load(json_file)
+
+
+def _bot_info_from_dict(info_dict: dict[str, t.Any]) -> CaiBotInfo:
+    '''Builds a CaiBotInfo object from the `character` field in the JSON.'''
+    return CaiBotInfo(
+        name=info_dict["name"],
+        title=info_dict["title"],
+        # This comes in as an empty string instead of `null` in the JSON when
+        # it's not defined for some reason, so we cast to None here for clarity.
+        description=info_dict["description"] or None,
+        greeting=info_dict["greeting"],
+        definitions=info_dict.get("definition"),
+        external_id=info_dict["external_id"],
+    )
 
 
 def _messages_from_dict(msgs_dict: list[dict[str, t.Any]]) -> list[str]:
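To illustrate the empty-string coercion in `_bot_info_from_dict`, a sketch with a made-up `character` payload (field values are hypothetical; the shape mirrors the fields the helper reads, and assumes the helper and `CaiBotInfo` above are in scope):

character = {
    "name": "Miku",             # hypothetical
    "title": "Virtual singer",  # hypothetical
    "description": "",          # CAI emits "" here instead of null
    "greeting": "Hello!",
    "external_id": "abc123",    # hypothetical
    # No "definition" key: private bots omit it, so .get() returns None.
}

info = _bot_info_from_dict(character)
assert info.description is None  # "" was coerced to None
assert info.definitions is None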
@@ -73,3 +119,27 @@ def _messages_from_dict(msgs_dict: list[dict[str, t.Any]]) -> list[str]:
     for raw_message in msgs_dict:
         messages.append(raw_message["text"])
     return messages
+
+
+def _is_definition_data(dict_from_json: dict[str, t.Any]) -> bool:
+    '''
+    Figures out whether the given dict (parsed from a JSON file) is a regular
+    dump, or a dump from the Character Editor (possibly containing definitions).
+
+    If it doesn't seem like either, raises a `ValueError` so we can discard bad
+    data.
+    '''
+    keys = list(dict_from_json.keys())
+
+    # Some people messed with their files so the order of the keys isn't always
+    # the same, so we sort for consistency.
+    keys.sort()
+    if keys == ["character"]:
+        return True
+    elif keys == ["character", "user__username"]:
+        return True
+    elif keys == ["histories", "info"]:
+        return False
+    else:
+        print(dict_from_json)
+        raise ValueError(f"Unexpected keys found in CAI dump JSON file: {keys}")
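The key shapes `_is_definition_data` dispatches on, shown with empty placeholder values (assumes the helper above is in scope):

_is_definition_data({"character": {}})                         # True: Character Editor dump
_is_definition_data({"character": {}, "user__username": "x"})  # True
_is_definition_data({"histories": {}, "info": {}})             # False: regular chat dump
# _is_definition_data({"something": "else"})                   # would raise ValueError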
@@ -1,9 +1,35 @@
+import logging
+import re
 import typing as t
 
+from waifu.core.consts import PromptConstants
 from waifu.datasets.characterai import CharacterAiDataset
 from waifu.modules import BaseModule
 
-USER_PREFIX = "You"
+logger = logging.getLogger(__name__)
+
+# Discard episodes shorter than 3 turns. These are likely not very useful for
+# the model to learn to converse properly, since they only really contain one
+# dialogue response (the first turn is the hardcoded greeting, and the second is
+# the user's input).
+MIN_EPISODE_LEN = 3
+
+#
+# So here's a quick rundown of what needs to happen. We have a limited context
+# window (of 2048 tokens, ATM) and for the Persona Dialogue Module (PDM), we
+# need to fit all of the following things in there:
+#
+# - The bot's description/definitions/persona/whatever you want to call it
+# - Last X messages of chat history/context (the more the merrier, usually)
+# - The user's input message, e.g. `You: [user text here]`
+# - The bot's response, e.g. `[Bot name]: [space for the bot's response]`
+#
+# As such, most of the code here is about taking globs of text and
+# chunking/splitting them up to make the format described above fit into blocks
+# of 2048-ish tokens (not exactly 2048 because the tokenizer depends on the
+# model used, and I don't want to create a dependency on a specific model at the
+# data processing stage at this point).
+#
 
 
 class CharacterAiPDM(BaseModule):
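As a rough sanity check on the 1536-word target: assuming English prose averages about 0.75 words per BPE token (a rule of thumb, not a measured property of any particular tokenizer), 1536 words corresponds to roughly a 2048-token context window:

WORDS_PER_TOKEN = 0.75  # assumption; varies by tokenizer and corpus
context_window_tokens = 2048
approx_word_budget = int(context_window_tokens * WORDS_PER_TOKEN)
print(approx_word_budget)  # 1536, matching TARGET_WORD_COUNT_PER_EPISODE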
@@ -11,15 +37,111 @@ class CharacterAiPDM(BaseModule):
 
     def generator(self) -> t.Generator[str, None, None]:
         for chat in CharacterAiDataset():
-            description_string = f"{chat.bot_info.name}'s Description: {chat.bot_info.description}"
-            # Empty turn to separate description from the messages.
-            turns = [description_string, ""]
-
+            if len(chat.messages) < MIN_EPISODE_LEN:
+                logger.debug(
+                    "Found episode shorter than minimum length (%s < %s), discarding.",
+                    len(chat.messages), MIN_EPISODE_LEN)
+                continue
+
+            base_turns = []
+            if chat.bot.description is not None:
+                pdm_prefix = PromptConstants.pdm_prefix_for(chat.bot.name)
+                pdm_string = f"{pdm_prefix}: {chat.bot.description}"
+                base_turns.append(pdm_string)
+
+            if chat.bot.definitions is not None:
+                parsed_definitions, parsed_examples = _parse_definitions_for(
+                    chat.bot.name, chat.bot.definitions)
+                base_turns.append(parsed_definitions)
+
+            # Add an empty turn to separate persona info from messages, if
+            # necessary.
+            if len(base_turns) > 0:
+                base_turns.append("")
+
+            # Now, start adding messages and break episodes apart if they get
+            # too big.
+            turns = base_turns.copy()
             for idx, raw_message in enumerate(chat.messages):
+                # First message is always the bot (since it must send a
+                # greeting), and next up is always the user.
                 if idx % 2 == 0:
-                    message = f"{chat.bot_info.name}: {raw_message}"
+                    # TODO(11b): Handle `[NAME_IN_MESSAGE_REDACTED]`.
+                    message = f"{chat.bot.name}: {raw_message}"
                 else:
-                    message = f"{USER_PREFIX}: {raw_message}"
+                    message = f"{PromptConstants.USER_PREFIX}: {raw_message}"
                 turns.append(message)
 
-            yield "\n".join(turns)
+                # Splitting logic.
+                cur_episode_len = sum([len(x.split()) for x in turns])
+                if cur_episode_len > PromptConstants.TARGET_WORD_COUNT_PER_EPISODE:
+                    logger.debug(
+                        "Episode length went over TARGET_WORD_COUNT_PER_EPISODE, breaking apart."
+                    )
+
+                    # Adding the last message made the episode go over the
+                    # target word count, so we return the episode without it...
+                    removed_turn = turns.pop()
+                    yield "\n".join(turns)
+
+                    # ...and start the next episode with the message we had to
+                    # trim out from this one.
+                    turns = base_turns.copy()
+                    turns.append(removed_turn)
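To make the splitting behavior concrete, here's a toy rerun of the same rule with a made-up budget of 8 words instead of TARGET_WORD_COUNT_PER_EPISODE (note that, as in the generator above, a trailing under-budget remainder is never yielded):

base_turns = ["Bot's Persona: friendly", ""]
turns = base_turns.copy()
episodes = []
for message in ["Bot: hi there", "You: hello friend", "Bot: I am doing great"]:
    turns.append(message)
    if sum(len(t.split()) for t in turns) > 8:
        removed_turn = turns.pop()        # over budget: close this episode...
        episodes.append("\n".join(turns))
        turns = base_turns.copy()         # ...and carry the message over.
        turns.append(removed_turn)

print(len(episodes))  # 2: each episode repeats the persona header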
+
+
+#
+# Private helpers.
+#
+
+EXAMPLE_CHAT_REGEX = re.compile(
+    r"({{char}}|{{random_user_\d}}): (.+?)(?:END_OF_DIALOG)", re.DOTALL)
+RELAXED_EXAMPLE_CHAT_REGEX = re.compile(r"{{char}}: .+", re.DOTALL)
+
+
+def _parse_definitions_for(bot_name: str,
+                           raw_definitions: str) -> t.Tuple[str, list[str]]:
+    '''
+    Parses bot definitions.
+
+    This function attempts to find example messages within the input string,
+    parses them accordingly and returns them separately from the rest of the
+    text in the original `definitions` string.
+    '''
+    definitions, examples = _parse_definitions_strict(raw_definitions)
+    if len(examples) == 0:
+        definitions, examples = _parse_definitions_relaxed(raw_definitions)
+
+    parsed_definitions = definitions.replace("{{char}}", bot_name)
+    parsed_examples = [x.replace("{{char}}", bot_name) for x in examples]
+
+    return parsed_definitions, parsed_examples
+
+
+def _parse_definitions_strict(definitions: str) -> t.Tuple[str, list[str]]:
+    '''
+    Strict parsing of a bot's definitions string, assumes END_OF_DIALOG was used
+    correctly by the bot's creator.
+    '''
+    matched_example_chats = EXAMPLE_CHAT_REGEX.finditer(definitions)
+    examples = [
+        x.group().replace("END_OF_DIALOG", "").strip()
+        for x in matched_example_chats
+    ]
+    definitions_without_examples = re.sub(EXAMPLE_CHAT_REGEX, "", definitions)
+
+    return definitions_without_examples, examples
+
+
+def _parse_definitions_relaxed(definitions: str) -> t.Tuple[str, list[str]]:
+    '''
+    Same as `_parse_definitions_strict`, but much more relaxed; should be used
+    when the bot creator didn't properly use END_OF_DIALOG to delineate example
+    chats.
+    '''
+    matched_example_chats = RELAXED_EXAMPLE_CHAT_REGEX.finditer(definitions)
+    examples = [x.group().strip() for x in matched_example_chats]
+    definitions_without_examples = re.sub(RELAXED_EXAMPLE_CHAT_REGEX, "",
+                                          definitions)
+
+    return definitions_without_examples, examples
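A quick sketch of the strict parsing path on a made-up definitions string (the regex is the one defined above; all text is hypothetical):

import re

EXAMPLE_CHAT_REGEX = re.compile(
    r"({{char}}|{{random_user_\d}}): (.+?)(?:END_OF_DIALOG)", re.DOTALL)

raw_definitions = (
    "{{char}} is a friendly librarian.\n\n"
    "{{char}}: Welcome to the library!\n"
    "{{random_user_1}}: Do you have any sci-fi?\n"
    "END_OF_DIALOG\n"
)

examples = [
    m.group().replace("END_OF_DIALOG", "").strip()
    for m in EXAMPLE_CHAT_REGEX.finditer(raw_definitions)
]
rest = re.sub(EXAMPLE_CHAT_REGEX, "", raw_definitions).strip()

print(rest)      # {{char}} is a friendly librarian.
print(examples)  # ['{{char}}: Welcome to the library!\n{{random_user_1}}: Do you have any sci-fi?']

`_parse_definitions_for` would then substitute the actual bot name for `{{char}}` in both pieces; if no END_OF_DIALOG markers are found, it falls back to the relaxed regex.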