feat: Kajiwoto dataset and modules
This commit is contained in:
parent
a076746f9d
commit
8df2d87355
|
@ -0,0 +1,286 @@
|
|||
import json
|
||||
import os
|
||||
import typing as t
|
||||
from dataclasses import dataclass
|
||||
import re
|
||||
import logging
|
||||
|
||||
from waifu.datasets import BaseDataset
|
||||
from waifu.utils.dataset import get_data_path
|
||||
|
||||
# The regex used to find message variants (e.g.: `%{Hi|Hello} there!`)
|
||||
KAJIWOTO_VARIANT_REGEX = re.compile(r'%{(.+?)}')
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class KajiwotoMessageResponsePair:
|
||||
message_id: str
|
||||
bot_id: str
|
||||
|
||||
user_message: str
|
||||
bot_response: str
|
||||
condition: str
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class BotMetadata:
|
||||
bot_id: str
|
||||
name: str
|
||||
description: str
|
||||
personalities: t.List[t.List[str]]
|
||||
has_nsfw: bool
|
||||
tags: t.List[str]
|
||||
|
||||
|
||||
class KajiwotoDataset(BaseDataset[t.List[KajiwotoMessageResponsePair]]):
|
||||
'''
|
||||
The Kajiwoto dataset.
|
||||
|
||||
Takes care of properly handling chat history/message context.
|
||||
'''
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.filepaths = _enumerate_kajiwoto_json_files()
|
||||
self.cached_metadata: dict[str, BotMetadata] = {}
|
||||
|
||||
def get_metadata_for_bot(self, bot_id: str) -> BotMetadata:
|
||||
'''Returns known medatada for the given bot ID.'''
|
||||
if bot_id in self.cached_metadata:
|
||||
return self.cached_metadata[bot_id]
|
||||
|
||||
dataset_path = get_data_path(dataset_name="kajiwoto")
|
||||
metadata_filepath = os.path.join(dataset_path,
|
||||
f"{bot_id}_metadata.json")
|
||||
|
||||
with open(metadata_filepath, "r", encoding="utf-8") as metadata_file:
|
||||
metadata_dict = json.loads(
|
||||
metadata_file.read())["data"]["aiTrainerGroup"]
|
||||
metadata = _metadata_dict_to_dataclass(metadata_dict)
|
||||
return metadata
|
||||
|
||||
def generator(
|
||||
self
|
||||
) -> t.Generator[t.List[KajiwotoMessageResponsePair], None, None]:
|
||||
for filepath in self.filepaths:
|
||||
with open(filepath, "r", encoding="utf-8") as file:
|
||||
messages = json.loads(file.read())["data"]["aiTrainedList"]
|
||||
|
||||
# So, there's a tricky thing to handle in these datasets which
|
||||
# is the fact that follow-up messages are saved as completely
|
||||
# separate entries in the messages array. For example, if we
|
||||
# have a chat log like:
|
||||
#
|
||||
# Human: 1
|
||||
# Bot: 2
|
||||
# Human: 3
|
||||
# Bot: 4
|
||||
#
|
||||
# We will have, in the messages array, something like:
|
||||
# [
|
||||
# {"userMessage": "3", message: "4", "history": ["1"]},
|
||||
# {"userMessage": "1", message: "2"},
|
||||
# ]
|
||||
#
|
||||
# As far as I could tell, whenever a message has a "history"
|
||||
# field, it usually doesn't make sense by itself. Or even by
|
||||
# appending history. One needs to look up the original message
|
||||
# and reply pair using the history field, then build up the
|
||||
# sequence again manually.
|
||||
#
|
||||
# As such, for each file, we need to load the entire thing into
|
||||
# memory to run over it and build an index to do just that
|
||||
# (lookups via the history field), so here we go:
|
||||
history_contents_to_original_msg_idx: dict[str, int] = {}
|
||||
used_message_indexes: t.Set[int] = set()
|
||||
|
||||
for idx, msg in enumerate(messages):
|
||||
if msg["history"]:
|
||||
# Message already references an earlier message-reply
|
||||
# pair. As far as I could tell, that means _this_
|
||||
# specific message can't be referenced, so no point in
|
||||
# saving an index for it here.
|
||||
continue
|
||||
|
||||
history_contents_to_original_msg_idx[
|
||||
msg["userMessage"]] = idx
|
||||
|
||||
# Now that we have the history index, let's go over _only_ the
|
||||
# messages that need to be concatenated with their history.
|
||||
for idx, msg in enumerate(messages):
|
||||
if not msg.get("history", None):
|
||||
continue
|
||||
history_contents = msg["history"][0]
|
||||
|
||||
# Sometimes, a message seems to reference a previous one
|
||||
# that does not exist. Don't know what's up with that, so
|
||||
# let's just ignore.
|
||||
if not history_contents in history_contents_to_original_msg_idx:
|
||||
continue
|
||||
|
||||
# Fetch the original "history" message to use as context.
|
||||
original_msg_idx = history_contents_to_original_msg_idx[
|
||||
history_contents]
|
||||
original_msg = messages[original_msg_idx]
|
||||
|
||||
# Yield the conversation episode.
|
||||
yield [
|
||||
_dict_to_dataclass(original_msg),
|
||||
_dict_to_dataclass(msg),
|
||||
]
|
||||
|
||||
# Save the indexes of both of these so we don't re-use them
|
||||
# without the proper context.
|
||||
used_message_indexes.add(idx)
|
||||
used_message_indexes.add(original_msg_idx)
|
||||
|
||||
# Now let's go over regular, history-free messages.
|
||||
for idx, msg in enumerate(messages):
|
||||
if idx in used_message_indexes:
|
||||
continue
|
||||
|
||||
yield [_dict_to_dataclass(msg)]
|
||||
|
||||
|
||||
#
|
||||
# Public helpers.
|
||||
#
|
||||
|
||||
seen_special_tokens: set[str] = set()
|
||||
seen_scenes: set[str] = set()
|
||||
|
||||
|
||||
def replace_special_tokens_in(string: str) -> str:
|
||||
'''
|
||||
Replaces known special tokens (e.g.: `%{name}`) with their expected
|
||||
equivalents.
|
||||
'''
|
||||
if (match := re.search(KAJIWOTO_VARIANT_REGEX, string)) is not None:
|
||||
special_token = match.groups()[0]
|
||||
if '|' not in special_token and special_token not in seen_special_tokens:
|
||||
logger.debug("Unhandled Kajiwoto token: %s", special_token)
|
||||
seen_special_tokens.add(special_token)
|
||||
|
||||
if (scene_match := re.search(r"#scene=(.+?)\b", string)) is not None:
|
||||
seen_scene = scene_match.groups()[0]
|
||||
if seen_scene not in seen_scenes:
|
||||
logger.debug("Unhandled Kajiwoto scene: %s", seen_scene)
|
||||
seen_scenes.add(seen_scene)
|
||||
|
||||
# TODO: There's lots of these which I haven't handled at all. E.g.:
|
||||
# %{pronoun} (before and after a dot, so careful with caps), %{name},
|
||||
# %{kajiname}, #scene=SOMETHING, ...
|
||||
return string
|
||||
|
||||
|
||||
def generate_variants_for(
|
||||
string: str,
|
||||
max_generations: int = 16,
|
||||
start_counter_at: int = 0) -> t.Generator[str, None, None]:
|
||||
'''
|
||||
Given a string like "%{Hello|Hi} there{.|!}, this should yield:
|
||||
|
||||
- Hello there.
|
||||
- Hello there!
|
||||
- Hi there.
|
||||
- Hi there!
|
||||
'''
|
||||
|
||||
# Some bot creators went wild with the variants, which causes ridiculous
|
||||
# generations if we try to exhaust all possibilities so we cap that here.
|
||||
# `start_counter_at` is used for keeping track across recursive calls.
|
||||
counter = start_counter_at
|
||||
|
||||
if (match := re.search(KAJIWOTO_VARIANT_REGEX, string)) is not None:
|
||||
# Once we have a "%{X|Y|Z}" matched inside the original string, we:
|
||||
# - Fetch .groups()[0] (which will give us `X|Y|Z`)
|
||||
# - Split by `|` (so we have ["X", "Y", "Z"])
|
||||
# - Filter out empty strings
|
||||
alternatives = filter(lambda x: x.strip(), match.groups()[0].split("|"))
|
||||
|
||||
# Then, we break the string apart into what comes before and after the
|
||||
# alternatives, that way we can re-build with "prefix + choice + sufix".
|
||||
prefix = string[:match.start()]
|
||||
sufix = string[match.end():]
|
||||
|
||||
for alternative in alternatives:
|
||||
variant = f'{prefix}{alternative}{sufix}'
|
||||
|
||||
# However, some strings have multiple variant blocks. In that case,
|
||||
# we operate on them recursively until we have just regular strings
|
||||
# after generating all possible variants.
|
||||
still_have_match = re.search(KAJIWOTO_VARIANT_REGEX,
|
||||
variant) is not None
|
||||
if still_have_match:
|
||||
for inner_variant in generate_variants_for(
|
||||
variant, start_counter_at=counter):
|
||||
yield inner_variant
|
||||
|
||||
# Keep track and break after `max_generations`.
|
||||
counter += 1
|
||||
if max_generations is not None and counter >= max_generations:
|
||||
break
|
||||
else:
|
||||
yield variant
|
||||
|
||||
# Keep track and break after `max_generations`.
|
||||
counter += 1
|
||||
if max_generations is not None and counter >= max_generations:
|
||||
break
|
||||
else:
|
||||
yield string
|
||||
|
||||
|
||||
#
|
||||
# Private helpers.
|
||||
#
|
||||
|
||||
|
||||
def _enumerate_kajiwoto_json_files() -> list[str]:
|
||||
'''
|
||||
Returns a list of paths to all available `.json` files for the `kajiwoto`
|
||||
dataset.
|
||||
'''
|
||||
dataset_path = get_data_path(dataset_name="kajiwoto")
|
||||
items = os.listdir(dataset_path)
|
||||
files: list[str] = []
|
||||
|
||||
for item in items:
|
||||
if not item.endswith(".json"):
|
||||
# Don't care about other file types.
|
||||
continue
|
||||
|
||||
if item.endswith("_metadata.json"):
|
||||
# Don't want to list metadata files here.
|
||||
continue
|
||||
|
||||
item_path = os.path.join(dataset_path, item)
|
||||
if not os.path.isfile(item_path):
|
||||
# Don't care about folders.
|
||||
continue
|
||||
|
||||
absolute_item_path = os.path.abspath(os.path.join(dataset_path, item))
|
||||
files.append(absolute_item_path)
|
||||
return files
|
||||
|
||||
|
||||
def _dict_to_dataclass(obj: dict[str, str]) -> KajiwotoMessageResponsePair:
|
||||
return KajiwotoMessageResponsePair(
|
||||
message_id=obj["id"],
|
||||
bot_id=obj["aiTrainerGroupId"],
|
||||
condition=obj["condition"],
|
||||
user_message=obj["userMessage"],
|
||||
bot_response=obj["message"],
|
||||
)
|
||||
|
||||
|
||||
def _metadata_dict_to_dataclass(obj: dict[str, t.Any]) -> BotMetadata:
|
||||
return BotMetadata(
|
||||
bot_id=obj["id"],
|
||||
name=obj["name"],
|
||||
description=obj["description"],
|
||||
personalities=obj["personalities"],
|
||||
has_nsfw=obj["nsfw"],
|
||||
tags=obj["tags"],
|
||||
)
|
|
@ -0,0 +1,47 @@
|
|||
import typing as t
|
||||
|
||||
from waifu.datasets.kajiwoto import (KajiwotoDataset, generate_variants_for,
|
||||
replace_special_tokens_in)
|
||||
from waifu.modules import BaseModule
|
||||
from waifu.utils.strings import uppercase
|
||||
|
||||
USER_PREFIX = "Person 1"
|
||||
BOT_PREFIX = "Person 2"
|
||||
|
||||
|
||||
class KajiwotoPDM(BaseModule):
|
||||
'''A Persona Dialogue Module powered by the Kajiwoto dataset.'''
|
||||
|
||||
def generator(self) -> t.Generator[str, None, None]:
|
||||
dataset = KajiwotoDataset()
|
||||
for episode in dataset:
|
||||
turns: list[str] = []
|
||||
|
||||
metadata = dataset.get_metadata_for_bot(episode[0].bot_id)
|
||||
|
||||
# `metadata.personalities` is in a format like: `[["friendly", "20.32"]]`
|
||||
# but we want that "phrased" closer to natural language, so we build
|
||||
# `persona_string` to take care of that.
|
||||
personality_descriptors = [x[0] for x in metadata.personalities]
|
||||
persona_string = ". ".join(
|
||||
[uppercase(x) for x in personality_descriptors]) + "."
|
||||
|
||||
description_string = metadata.description.replace("\n",
|
||||
" ").replace(
|
||||
" ", " ")
|
||||
turns.append(f"{BOT_PREFIX}'s Description: {description_string}")
|
||||
turns.append(f"{BOT_PREFIX}'s Persona: {persona_string}")
|
||||
|
||||
# Empty turn to have a line break separating description/persona
|
||||
# and the actual messages.
|
||||
turns.append("")
|
||||
|
||||
for turn in episode:
|
||||
turns.append(f"{USER_PREFIX}: {turn.user_message}")
|
||||
turns.append(f"{BOT_PREFIX}: {turn.bot_response}")
|
||||
|
||||
string = "\n".join(turns)
|
||||
processed_string = replace_special_tokens_in(string)
|
||||
|
||||
for generated_string in generate_variants_for(processed_string):
|
||||
yield generated_string
|
|
@ -0,0 +1,26 @@
|
|||
import typing as t
|
||||
|
||||
from waifu.datasets.kajiwoto import (KajiwotoDataset, generate_variants_for,
|
||||
replace_special_tokens_in)
|
||||
from waifu.modules import BaseModule
|
||||
|
||||
USER_PREFIX = "Person 1"
|
||||
BOT_PREFIX = "Person 2"
|
||||
|
||||
|
||||
class KajiwotoVDM(BaseModule):
|
||||
'''A Vanilla Dialogue Module powered by the Kajiwoto dataset.'''
|
||||
|
||||
def generator(self) -> t.Generator[str, None, None]:
|
||||
dataset = KajiwotoDataset()
|
||||
for episode in dataset:
|
||||
turns: t.List[str] = []
|
||||
for turn in episode:
|
||||
turns.append(f"{USER_PREFIX}: {turn.user_message}")
|
||||
turns.append(f"{BOT_PREFIX}: {turn.bot_response}")
|
||||
|
||||
string = "\n".join(turns)
|
||||
processed_string = replace_special_tokens_in(string)
|
||||
|
||||
for generated_string in generate_variants_for(processed_string):
|
||||
yield generated_string
|
Loading…
Reference in New Issue