From e0552639faf91ce26b6679562545d0e0aac3539c Mon Sep 17 00:00:00 2001 From: 0x000011b <0x000011b@waifu.club> Date: Fri, 23 Dec 2022 16:20:53 -0300 Subject: [PATCH] feat: update CAI dataset/module to handle userscript dumps and use definitions --- waifu/core/consts.py | 15 ++++ waifu/datasets/characterai.py | 122 +++++++++++++++++++++------ waifu/modules/characterai_pdm.py | 136 +++++++++++++++++++++++++++++-- 3 files changed, 240 insertions(+), 33 deletions(-) create mode 100644 waifu/core/consts.py diff --git a/waifu/core/consts.py b/waifu/core/consts.py new file mode 100644 index 0000000..9594a05 --- /dev/null +++ b/waifu/core/consts.py @@ -0,0 +1,15 @@ +class PromptConstants: + '''String constants related to prompt engineering.''' + + # Prefix for user messages. + USER_PREFIX = "You" + + # Global target word count. The word count is chosen in such a way that we + # can fit all the required prompt trickery into the model's input, but still + # leave enough space for the user's input message and the inference result. + TARGET_WORD_COUNT_PER_EPISODE = 1536 + + @staticmethod + def pdm_prefix_for(name: str) -> str: + '''Builds the Persona Dialogue Module prefix for a given `name`.''' + return f"{name}'s Persona" diff --git a/waifu/datasets/characterai.py b/waifu/datasets/characterai.py index 38e6ee4..ee76c48 100644 --- a/waifu/datasets/characterai.py +++ b/waifu/datasets/characterai.py @@ -3,45 +3,64 @@ import os import typing as t from dataclasses import dataclass -import mashumaro - from waifu.datasets import BaseDataset from waifu.utils.dataset import get_data_path @dataclass(frozen=True) -class CaiBotInfo(mashumaro.DataClassDictMixin): +class CaiBotInfo: name: str title: str - description: str + description: str | None greeting: str + # Optional because it might be private. + definitions: str | None + + # Useful for when several bots have the same name - we can tell them apart + # by their external_id. 
+ external_id: str + + # There's also categories, but I'm ignoring them for now since I don't think + # they'll be of much use. + @dataclass(frozen=True) class CaiChat: # First message is the bot's greeting, the one afterwards is the user. messages: t.List[str] - bot_info: CaiBotInfo + bot: CaiBotInfo class CharacterAiDataset(BaseDataset[CaiChat]): '''Dataset for CharacterAI dumps.''' def generator(self) -> t.Generator[CaiChat, None, None]: - for folder in _enumerate_bot_folders(): - info_path = os.path.join(folder, "info.json") - histories_path = os.path.join(folder, "histories.json") + bot_id_to_info_dict = {} - with open(info_path, "r", encoding="utf-8") as info_file, \ - open(histories_path, "r", encoding="utf-8") as histories_file: - info_json = json.load(info_file) - histories_json = json.load(histories_file) + # Do a first run through all the files to load all the definitions and + # descriptions. + for data in _available_json_data(): + if not _is_definition_data(data): + continue - bot_info = CaiBotInfo.from_dict(info_json["character"]) + bot_info = _bot_info_from_dict(data["character"]) + bot_id_to_info_dict[bot_info.external_id] = bot_info - for history_dict in histories_json["histories"]: + # Now do a second pass, to actually handle chat histories/messages. + for data in _available_json_data(): + if _is_definition_data(data): + continue + + # Prefer grabbing bot info from a Character Editor dump, if it + # exists. Fall back to public data otherwise. 
+ bot_id = data["info"]["character"]["external_id"] + bot_info = bot_id_to_info_dict.get( + bot_id, _bot_info_from_dict(data["info"]["character"])) + + for history_dict in data["histories"]["histories"]: messages = _messages_from_dict(history_dict["msgs"]) - yield CaiChat(bot_info=bot_info, messages=messages) + yield CaiChat(bot=bot_info, messages=messages) # @@ -49,22 +68,49 @@ class CharacterAiDataset(BaseDataset[CaiChat]): # -def _enumerate_bot_folders() -> list[str]: - '''Returns a list of folders available in the CAI data folder.''' - dataset_path = get_data_path(dataset_name="test_characterai_dumps") - items = os.listdir(dataset_path) +def _enumerate_json_files(root_path: str) -> list[str]: + '''Returns a list of files available in the given `root_path`.''' + items = os.listdir(root_path) - folders: list[str] = [] + files: list[str] = [] for item in items: - item_path = os.path.join(dataset_path, item) - if os.path.isfile(item_path): - # We only care about folders. + item_path = os.path.join(root_path, item) + if not os.path.isfile(item_path) or not item_path.endswith(".json"): + # We only care about JSON files. continue - absolute_folder_path = os.path.abspath(os.path.join(dataset_path, item)) - folders.append(absolute_folder_path) + absolute_file_path = os.path.abspath(os.path.join(root_path, item)) + files.append(absolute_file_path) - return folders + return files + + +def _available_json_data() -> t.Generator[dict[str, t.Any], None, None]: + ''' + Yields all available JSON data, parsed from the files in the CharacterAI + data folder. 
+ ''' + dataset_path = get_data_path(dataset_name="characterai") + + for folder in ["public", "private"]: + folder_path = os.path.join(dataset_path, folder) + for json_file_path in _enumerate_json_files(folder_path): + with open(json_file_path, "r", encoding="utf-8") as json_file: + yield json.load(json_file) + + +def _bot_info_from_dict(info_dict: dict[str, t.Any]) -> CaiBotInfo: + '''Builds a CaiBotInfo object from the `character` field in the JSON.''' + return CaiBotInfo( + name=info_dict["name"], + title=info_dict["title"], + # This comes in as an empty string instead of `null` in the JSON when + # it's not defined for some reason, so we cast to None here for clarity. + description=info_dict["description"] or None, + greeting=info_dict["greeting"], + definitions=info_dict.get("definition"), + external_id=info_dict["external_id"], + ) def _messages_from_dict(msgs_dict: list[dict[str, t.Any]]) -> list[str]: @@ -73,3 +119,27 @@ def _messages_from_dict(msgs_dict: list[dict[str, t.Any]]) -> list[str]: for raw_message in msgs_dict: messages.append(raw_message["text"]) return messages + + +def _is_definition_data(dict_from_json: dict[str, t.Any]) -> bool: + ''' + Figures out whether the given dict (parsed from a JSON file) is a regular + dump, or a dump from the Character Editor (possibly containing definitions). + + If it doesn't seem like either, raises a `ValueError` so we can discard bad + data. + ''' + keys = list(dict_from_json.keys()) + + # Some people messed with their files so the order of the keys isn't always + # the same, so we sort for consistency. 
+ keys.sort() + if keys == ["character"]: + return True + elif keys == ["character", "user__username"]: + return True + elif keys == ["histories", "info"]: + return False + else: + print(dict_from_json) + raise ValueError(f"Unexpected keys found in CAI dump JSON file: {keys}") diff --git a/waifu/modules/characterai_pdm.py b/waifu/modules/characterai_pdm.py index 9f4918b..fb41ee4 100644 --- a/waifu/modules/characterai_pdm.py +++ b/waifu/modules/characterai_pdm.py @@ -1,9 +1,35 @@ +import logging +import re import typing as t +from waifu.core.consts import PromptConstants from waifu.datasets.characterai import CharacterAiDataset from waifu.modules import BaseModule -USER_PREFIX = "You" +logger = logging.getLogger(__name__) + +# Discard episodes shorter than 3 turns. These are likely not very useful for +# the model to learn to converse properly, since they only really contain one +# dialogue response (the first turn is the hardcoded greeting, and the second is +# the user's input). +MIN_EPISODE_LEN = 3 + +# +# So here's a quick rundown of what needs to happen. We have a limited context +# window (of 2048 tokens, ATM) and for the Persona Dialogue Module (PDM), we +# need to fit all of the following things in there: +# +# - The bot's description/definitions/persona/whatever you want to call it +# - Last X messages of chat history/context (the more the merrier, usually) +# - The user's input message, e.g. `You: [user text here]` +# - The bot's response, e.g. `[Bot name]: [space for the bot's response]` +# +# As such, most of the code here is about taking globs of text and +# chunking/splitting them up to make the format described above fit into blocks +# of 2048-ish tokens (not exactly 2048 because the tokenizer depends on the +# model used, and I don't want to create a dependency on a specific model at the +# data processing stage at this point). 
+# class CharacterAiPDM(BaseModule): @@ -11,15 +37,111 @@ class CharacterAiPDM(BaseModule): def generator(self) -> t.Generator[str, None, None]: for chat in CharacterAiDataset(): - description_string = f"{chat.bot_info.name}'s Description: {chat.bot_info.description}" - # Empty turn to separate description from the messages. - turns = [description_string, ""] + if len(chat.messages) < MIN_EPISODE_LEN: + logger.debug( + "Found episode shorter than minimum length (%s < %s), discarding.", + len(chat.messages), MIN_EPISODE_LEN) + continue + base_turns = [] + if chat.bot.description is not None: + pdm_prefix = PromptConstants.pdm_prefix_for(chat.bot.name) + pdm_string = f"{pdm_prefix}: {chat.bot.description}" + base_turns.append(pdm_string) + + if chat.bot.definitions is not None: + parsed_definitions, parsed_examples = _parse_definitions_for( + chat.bot.name, chat.bot.definitions) + base_turns.append(parsed_definitions) + + # Add an empty turn to separate persona info from messages, if + # necessary. + if len(base_turns) > 0: + base_turns.append("") + + # Now, start adding messages and break episodes apart if they get + # too big. + turns = base_turns.copy() for idx, raw_message in enumerate(chat.messages): + # First message is always the bot (since it must send a + # greeting), and next up is always the user. if idx % 2 == 0: - message = f"{chat.bot_info.name}: {raw_message}" + # TODO(11b): Handle `[NAME_IN_MESSAGE_REDACTED]`. + message = f"{chat.bot.name}: {raw_message}" else: - message = f"{USER_PREFIX}: {raw_message}" + message = f"{PromptConstants.USER_PREFIX}: {raw_message}" turns.append(message) - yield "\n".join(turns) + # Splitting logic. + cur_episode_len = sum([len(x.split()) for x in turns]) + if cur_episode_len > PromptConstants.TARGET_WORD_COUNT_PER_EPISODE: + logger.debug( + "Episode length went over TARGET_WORD_COUNT_PER_EPISODE, breaking apart." 
+ ) + + # Adding the last message made the episode go over the + # target word count, so we return the episode without it... + removed_turn = turns.pop() + yield "\n".join(turns) + + # ...and start the next episode with the message we had to + # trim out from this one. + turns = base_turns.copy() + turns.append(removed_turn) + + +# +# Private helpers. +# + +EXAMPLE_CHAT_REGEX = re.compile( + r"({{char}}|{{random_user_\d}}): (.+?)(?:END_OF_DIALOG)", re.DOTALL) +RELAXED_EXAMPLE_CHAT_REGEX = re.compile(r"{{char}}: .+", re.DOTALL) + + +def _parse_definitions_for(bot_name: str, + raw_definitions: str) -> t.Tuple[str, list[str]]: + ''' + Parses bot definitions. + + This function attempts to find example messages within the input string, + parses them accordingly and returns them separately from the rest of the + text in the original `definitions` string. + ''' + definitions, examples = _parse_definitions_strict(raw_definitions) + if len(examples) == 0: + definitions, examples = _parse_definitions_relaxed(raw_definitions) + + parsed_definitions = definitions.replace("{{char}}", bot_name) + parsed_examples = [x.replace("{{char}}", bot_name) for x in examples] + + return parsed_definitions, parsed_examples + + +def _parse_definitions_strict(definitions: str) -> t.Tuple[str, list[str]]: + ''' + Strict parsing of a bot's definitions string, assumes END_OF_DIALOG was used + correctly by the bot's creator. + ''' + matched_example_chats = EXAMPLE_CHAT_REGEX.finditer(definitions) + examples = [ + x.group().replace("END_OF_DIALOG", "").strip() + for x in matched_example_chats + ] + definitions_without_examples = re.sub(EXAMPLE_CHAT_REGEX, "", definitions) + + return definitions_without_examples, examples + + +def _parse_definitions_relaxed(definitions: str) -> t.Tuple[str, list[str]]: + ''' + Same as the `_parse_definitions_strict`, but this one is much more relaxed + and should be used for when the bot creator didn't properly use + END_OF_DIALOG to delineate example chats. 
+ ''' + matched_example_chats = RELAXED_EXAMPLE_CHAT_REGEX.finditer(definitions) + examples = [x.group().strip() for x in matched_example_chats] + definitions_without_examples = re.sub(RELAXED_EXAMPLE_CHAT_REGEX, "", + definitions) + + return definitions_without_examples, examples