From 8df2d873556938a96f92ef4732620e37d00183fb Mon Sep 17 00:00:00 2001
From: 0x000011b <0x000011b@waifu.club>
Date: Sat, 17 Dec 2022 21:36:33 -0300
Subject: [PATCH] feat: Kajiwoto dataset and modules

---
 waifu/datasets/kajiwoto.py    | 286 ++++++++++++++++++++++++++++++++++
 waifu/modules/kajiwoto_pdm.py |  47 ++++++
 waifu/modules/kajiwoto_vdm.py |  26 ++++
 3 files changed, 359 insertions(+)
 create mode 100644 waifu/datasets/kajiwoto.py
 create mode 100644 waifu/modules/kajiwoto_pdm.py
 create mode 100644 waifu/modules/kajiwoto_vdm.py

diff --git a/waifu/datasets/kajiwoto.py b/waifu/datasets/kajiwoto.py
new file mode 100644
index 0000000..5d2cb8d
--- /dev/null
+++ b/waifu/datasets/kajiwoto.py
@@ -0,0 +1,286 @@
+import json
+import os
+import typing as t
+from dataclasses import dataclass
+import re
+import logging
+
+from waifu.datasets import BaseDataset
+from waifu.utils.dataset import get_data_path
+
+# The regex used to find message variants (e.g.: `%{Hi|Hello} there!`).
+KAJIWOTO_VARIANT_REGEX = re.compile(r'%{(.+?)}')
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass(frozen=True)
+class KajiwotoMessageResponsePair:
+    message_id: str
+    bot_id: str
+
+    user_message: str
+    bot_response: str
+    condition: str
+
+
+@dataclass(frozen=True)
+class BotMetadata:
+    bot_id: str
+    name: str
+    description: str
+    personalities: t.List[t.List[str]]
+    has_nsfw: bool
+    tags: t.List[str]
+
+
+class KajiwotoDataset(BaseDataset[t.List[KajiwotoMessageResponsePair]]):
+    '''
+    The Kajiwoto dataset.
+
+    Takes care of properly handling chat history/message context.
+    '''
+
+    def __init__(self) -> None:
+        self.filepaths = _enumerate_kajiwoto_json_files()
+        self.cached_metadata: dict[str, BotMetadata] = {}
+
+    def get_metadata_for_bot(self, bot_id: str) -> BotMetadata:
+        '''Returns known metadata for the given bot ID.'''
+        if bot_id in self.cached_metadata:
+            return self.cached_metadata[bot_id]
+
+        dataset_path = get_data_path(dataset_name="kajiwoto")
+        metadata_filepath = os.path.join(dataset_path,
+                                         f"{bot_id}_metadata.json")
+
+        with open(metadata_filepath, "r", encoding="utf-8") as metadata_file:
+            metadata_dict = json.loads(
+                metadata_file.read())["data"]["aiTrainerGroup"]
+            metadata = _metadata_dict_to_dataclass(metadata_dict)
+
+            # Store the parsed metadata so we only hit the disk once per bot.
+            self.cached_metadata[bot_id] = metadata
+            return metadata
+
+    def generator(
+        self
+    ) -> t.Generator[t.List[KajiwotoMessageResponsePair], None, None]:
+        for filepath in self.filepaths:
+            with open(filepath, "r", encoding="utf-8") as file:
+                messages = json.loads(file.read())["data"]["aiTrainedList"]
+
+            # There's a tricky thing to handle in these datasets, which is
+            # the fact that follow-up messages are saved as completely
+            # separate entries in the messages array. For example, if we
+            # have a chat log like:
+            #
+            #     Human: 1
+            #     Bot: 2
+            #     Human: 3
+            #     Bot: 4
+            #
+            # We will have, in the messages array, something like:
+            #
+            #     [
+            #         {"userMessage": "3", "message": "4", "history": ["1"]},
+            #         {"userMessage": "1", "message": "2"},
+            #     ]
+            #
+            # As far as I could tell, whenever a message has a "history"
+            # field, it usually doesn't make sense by itself, or even with
+            # the history simply appended. One needs to look up the original
+            # message and reply pair using the history field, then build up
+            # the sequence again manually.
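+            #
+            # To illustrate with the array above: the entry for "3"/"4"
+            # carries `"history": ["1"]`, so we look up the entry whose
+            # "userMessage" is "1" and emit the 1/2 pair followed by the
+            # 3/4 pair as a single episode.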
+            #
+            # As such, for each file, we need to load the entire thing into
+            # memory, run over it, and build an index to do just that
+            # (lookups via the history field), so here we go:
+            history_contents_to_original_msg_idx: dict[str, int] = {}
+            used_message_indexes: t.Set[int] = set()
+
+            for idx, msg in enumerate(messages):
+                if msg["history"]:
+                    # Message already references an earlier message-reply
+                    # pair. As far as I could tell, that means _this_
+                    # specific message can't be referenced, so no point in
+                    # saving an index for it here.
+                    continue
+
+                history_contents_to_original_msg_idx[
+                    msg["userMessage"]] = idx
+
+            # Now that we have the history index, let's go over _only_ the
+            # messages that need to be concatenated with their history.
+            for idx, msg in enumerate(messages):
+                if not msg.get("history", None):
+                    continue
+                history_contents = msg["history"][0]
+
+                # Sometimes, a message references a previous one that does
+                # not exist. Not sure what's up with that, so let's just
+                # ignore it.
+                if history_contents not in history_contents_to_original_msg_idx:
+                    continue
+
+                # Fetch the original "history" message to use as context.
+                original_msg_idx = history_contents_to_original_msg_idx[
+                    history_contents]
+                original_msg = messages[original_msg_idx]
+
+                # Yield the conversation episode.
+                yield [
+                    _dict_to_dataclass(original_msg),
+                    _dict_to_dataclass(msg),
+                ]
+
+                # Save the indexes of both of these so we don't re-use them
+                # without the proper context.
+                used_message_indexes.add(idx)
+                used_message_indexes.add(original_msg_idx)
+
+            # Now let's go over regular, history-free messages.
+            for idx, msg in enumerate(messages):
+                if idx in used_message_indexes:
+                    continue
+
+                yield [_dict_to_dataclass(msg)]
+
+
+#
+# Public helpers.
+#
+
+seen_special_tokens: set[str] = set()
+seen_scenes: set[str] = set()
+
+
+def replace_special_tokens_in(string: str) -> str:
+    '''
+    Replaces known special tokens (e.g.: `%{name}`) with their expected
+    equivalents.
+
+    For now, this only logs any unhandled tokens or scenes it runs into and
+    returns the string unchanged (see the TODO below).
+    '''
+    if (match := re.search(KAJIWOTO_VARIANT_REGEX, string)) is not None:
+        special_token = match.groups()[0]
+        if '|' not in special_token and special_token not in seen_special_tokens:
+            logger.debug("Unhandled Kajiwoto token: %s", special_token)
+            seen_special_tokens.add(special_token)
+
+    if (scene_match := re.search(r"#scene=(.+?)\b", string)) is not None:
+        seen_scene = scene_match.groups()[0]
+        if seen_scene not in seen_scenes:
+            logger.debug("Unhandled Kajiwoto scene: %s", seen_scene)
+            seen_scenes.add(seen_scene)
+
+    # TODO: There are lots of these which I haven't handled at all. E.g.:
+    # %{pronoun} (before and after a dot, so careful with caps), %{name},
+    # %{kajiname}, #scene=SOMETHING, ...
+    return string
+
+
+def generate_variants_for(
+        string: str,
+        max_generations: int = 16,
+        start_counter_at: int = 0) -> t.Generator[str, None, None]:
+    '''
+    Given a string like "%{Hello|Hi} there%{.|!}", this should yield:
+
+    - Hello there.
+    - Hello there!
+    - Hi there.
+    - Hi there!
+    '''
+
+    # Some bot creators went wild with the variants, which causes a
+    # ridiculous number of generations if we try to exhaust all
+    # possibilities, so we cap that here. `start_counter_at` is used for
+    # keeping track across recursive calls.
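+    # (For illustration: a string with three three-way variant blocks would
+    # otherwise expand into 27 combinations; the cap keeps output bounded.)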
+    counter = start_counter_at
+
+    if (match := re.search(KAJIWOTO_VARIANT_REGEX, string)) is not None:
+        # Once we have a "%{X|Y|Z}" matched inside the original string, we:
+        # - Fetch .groups()[0] (which will give us `X|Y|Z`)
+        # - Split by `|` (so we have ["X", "Y", "Z"])
+        # - Filter out empty strings
+        alternatives = filter(lambda x: x.strip(),
+                              match.groups()[0].split("|"))
+
+        # Then, we break the string apart into what comes before and after
+        # the alternatives, so that we can re-build it as
+        # "prefix + choice + suffix".
+        prefix = string[:match.start()]
+        suffix = string[match.end():]
+
+        for alternative in alternatives:
+            variant = f'{prefix}{alternative}{suffix}'
+
+            # However, some strings have multiple variant blocks. In that
+            # case, we operate on them recursively until we have just
+            # regular strings after generating all possible variants.
+            still_have_match = re.search(KAJIWOTO_VARIANT_REGEX,
+                                         variant) is not None
+            if still_have_match:
+                for inner_variant in generate_variants_for(
+                        variant, start_counter_at=counter):
+                    yield inner_variant
+
+                    # Keep track and break after `max_generations`.
+                    counter += 1
+                    if max_generations is not None and counter >= max_generations:
+                        break
+            else:
+                yield variant
+
+                # Keep track and break after `max_generations`.
+                counter += 1
+                if max_generations is not None and counter >= max_generations:
+                    break
+    else:
+        yield string
+
+
+#
+# Private helpers.
+#
+
+
+def _enumerate_kajiwoto_json_files() -> list[str]:
+    '''
+    Returns a list of paths to all available `.json` files for the `kajiwoto`
+    dataset.
+    '''
+    dataset_path = get_data_path(dataset_name="kajiwoto")
+    items = os.listdir(dataset_path)
+    files: list[str] = []
+
+    for item in items:
+        if not item.endswith(".json"):
+            # Don't care about other file types.
+            continue
+
+        if item.endswith("_metadata.json"):
+            # Don't want to list metadata files here.
+            continue
+
+        item_path = os.path.join(dataset_path, item)
+        if not os.path.isfile(item_path):
+            # Don't care about folders.
+            continue
+
+        absolute_item_path = os.path.abspath(item_path)
+        files.append(absolute_item_path)
+    return files
+
+
+def _dict_to_dataclass(obj: dict[str, str]) -> KajiwotoMessageResponsePair:
+    return KajiwotoMessageResponsePair(
+        message_id=obj["id"],
+        bot_id=obj["aiTrainerGroupId"],
+        condition=obj["condition"],
+        user_message=obj["userMessage"],
+        bot_response=obj["message"],
+    )
+
+
+def _metadata_dict_to_dataclass(obj: dict[str, t.Any]) -> BotMetadata:
+    return BotMetadata(
+        bot_id=obj["id"],
+        name=obj["name"],
+        description=obj["description"],
+        personalities=obj["personalities"],
+        has_nsfw=obj["nsfw"],
+        tags=obj["tags"],
+    )
diff --git a/waifu/modules/kajiwoto_pdm.py b/waifu/modules/kajiwoto_pdm.py
new file mode 100644
index 0000000..3b381b3
--- /dev/null
+++ b/waifu/modules/kajiwoto_pdm.py
@@ -0,0 +1,47 @@
+import typing as t
+
+from waifu.datasets.kajiwoto import (KajiwotoDataset, generate_variants_for,
+                                     replace_special_tokens_in)
+from waifu.modules import BaseModule
+from waifu.utils.strings import uppercase
+
+USER_PREFIX = "Person 1"
+BOT_PREFIX = "Person 2"
+
+
+class KajiwotoPDM(BaseModule):
+    '''A Persona Dialogue Module powered by the Kajiwoto dataset.'''
+
+    def generator(self) -> t.Generator[str, None, None]:
+        dataset = KajiwotoDataset()
+        for episode in dataset:
+            turns: list[str] = []
+
+            metadata = dataset.get_metadata_for_bot(episode[0].bot_id)
+
+            # `metadata.personalities` comes in a format like
+            # `[["friendly", "20.32"]]`, but we want it phrased closer to
+            # natural language, so we build `persona_string` to take care
+            # of that.
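+            # Illustrative example (assuming `uppercase` capitalizes the
+            # first letter): `[["friendly", "20.32"], ["shy", "4.1"]]` would
+            # end up as "Friendly. Shy.".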
+            personality_descriptors = [x[0] for x in metadata.personalities]
+            persona_string = ". ".join(
+                [uppercase(x) for x in personality_descriptors]) + "."
+
+            # Flatten newlines in the description and collapse any doubled-up
+            # spaces that leaves behind.
+            description_string = metadata.description.replace(
+                "\n", " ").replace("  ", " ")
+            turns.append(f"{BOT_PREFIX}'s Description: {description_string}")
+            turns.append(f"{BOT_PREFIX}'s Persona: {persona_string}")
+
+            # Empty turn to have a line break separating description/persona
+            # from the actual messages.
+            turns.append("")
+
+            for turn in episode:
+                turns.append(f"{USER_PREFIX}: {turn.user_message}")
+                turns.append(f"{BOT_PREFIX}: {turn.bot_response}")
+
+            string = "\n".join(turns)
+            processed_string = replace_special_tokens_in(string)
+
+            for generated_string in generate_variants_for(processed_string):
+                yield generated_string
diff --git a/waifu/modules/kajiwoto_vdm.py b/waifu/modules/kajiwoto_vdm.py
new file mode 100644
index 0000000..9770963
--- /dev/null
+++ b/waifu/modules/kajiwoto_vdm.py
@@ -0,0 +1,26 @@
+import typing as t
+
+from waifu.datasets.kajiwoto import (KajiwotoDataset, generate_variants_for,
+                                     replace_special_tokens_in)
+from waifu.modules import BaseModule
+
+USER_PREFIX = "Person 1"
+BOT_PREFIX = "Person 2"
+
+
+class KajiwotoVDM(BaseModule):
+    '''A Vanilla Dialogue Module powered by the Kajiwoto dataset.'''
+
+    def generator(self) -> t.Generator[str, None, None]:
+        dataset = KajiwotoDataset()
+        for episode in dataset:
+            turns: t.List[str] = []
+            for turn in episode:
+                turns.append(f"{USER_PREFIX}: {turn.user_message}")
+                turns.append(f"{BOT_PREFIX}: {turn.bot_response}")
+
+            string = "\n".join(turns)
+            processed_string = replace_special_tokens_in(string)
+
+            for generated_string in generate_variants_for(processed_string):
+                yield generated_string