# toolbox/toolbox/datasets/characterai.py
#
# 162 lines
# 5.1 KiB
# Python

import json
import logging
import os
import typing as t
from dataclasses import dataclass
from toolbox.datasets import BaseDataset
from toolbox.utils.dataset import get_data_path
logger = logging.getLogger(__name__)
@dataclass(frozen=True)
class CaiBotInfo:
name: str
title: str
description: str | None
greeting: str
# Optional because it might be private.
definitions: str | None
# Useful for when several bots have the same name - we can tell them apart
# by their external_id.
external_id: str
# There's also categories, but I'm ignoring them for now since I don't think
# they'll be of much use.
@dataclass(frozen=True)
class CaiMessage:
    '''Immutable record for a single message within a CharacterAI chat.'''

    # Mirrors the `is_human` flag found on the message's source in the dump.
    is_human: bool
    text: str
@dataclass(frozen=True)
class CaiChat:
    '''Immutable record pairing a chat history with the bot involved in it.'''

    # The bot's greeting is always the first entry.
    messages: list[CaiMessage]
    bot: CaiBotInfo
class CharacterAiDataset(BaseDataset[CaiChat]):
    '''Dataset for CharacterAI dumps.'''

    def generator(self) -> t.Generator[CaiChat, None, None]:
        '''
        Yields one `CaiChat` for every chat history found in the available
        CharacterAI dumps.

        The data is walked twice: once to collect bot info from Character
        Editor dumps, and once more to build the actual chats. Any
        `ValueError` raised by `_is_definition_data` for an unrecognized dump
        is not caught here.
        '''
        bot_id_to_info_dict: dict[str, CaiBotInfo] = {}

        # Do a first run through all the files to load all the definitions and
        # descriptions.
        for data in _available_json_data():
            if not _is_definition_data(data):
                continue

            bot_info = _bot_info_from_dict(data["character"])
            bot_id_to_info_dict[bot_info.external_id] = bot_info

        # Now do a second pass, to actually handle chat histories/messages.
        for data in _available_json_data():
            if _is_definition_data(data):
                continue

            # Prefer grabbing bot info from a Character Editor dump, if it
            # exists. Fall back to public data otherwise.
            public_info_dict = data["info"]["character"]
            bot_info = bot_id_to_info_dict.get(public_info_dict["external_id"])
            if bot_info is None:
                # Built lazily on purpose: passing this as the default to
                # `dict.get` would parse the public info even when it's not
                # needed, and blow up on incomplete public data that we'd
                # never actually use.
                bot_info = _bot_info_from_dict(public_info_dict)

            for history_dict in data["histories"]["histories"]:
                messages = _messages_from_dict(history_dict["msgs"])
                yield CaiChat(bot=bot_info, messages=messages)
#
# Private helpers.
#
def _enumerate_json_files(root_path: str) -> list[str]:
    '''
    Lists the absolute paths of the JSON files sitting directly inside
    `root_path`. Subfolders are not descended into.
    '''
    candidate_paths = (
        os.path.join(root_path, entry) for entry in os.listdir(root_path))

    # Anything that isn't a regular file with a `.json` suffix gets skipped.
    return [
        os.path.abspath(candidate_path)
        for candidate_path in candidate_paths
        if os.path.isfile(candidate_path) and candidate_path.endswith(".json")
    ]
def _available_json_data() -> t.Generator[dict[str, t.Any], None, None]:
    '''
    Yields all available JSON data, parsed from the files in the CharacterAI
    data folder.

    The `public` subfolder is read first, followed by `private`. Files which
    fail to parse as JSON are logged and skipped.
    '''
    dataset_path = get_data_path(dataset_name="characterai")

    for folder in ["public", "private"]:
        folder_path = os.path.join(dataset_path, folder)
        for json_file_path in _enumerate_json_files(folder_path):
            try:
                with open(json_file_path, "r", encoding="utf-8") as json_file:
                    data = json.load(json_file)
            except json.decoder.JSONDecodeError as ex:
                logger.error("Failed to parse %s: %s", json_file_path, ex)
                continue

            # Only yield once the file has been closed, so the handle isn't
            # held open for as long as the consumer takes to process the
            # data (or forever, if the generator is abandoned halfway).
            yield data
def _bot_info_from_dict(info_dict: dict[str, t.Any]) -> CaiBotInfo:
    '''Creates a CaiBotInfo out of the `character` field from the JSON.'''
    name = info_dict["name"]
    title = info_dict["title"]

    # For some reason, an undefined description shows up in the JSON as an
    # empty string rather than `null`. Normalize it to None for clarity.
    raw_description = info_dict["description"]
    description = raw_description if raw_description else None

    greeting = info_dict["greeting"]
    # Unlike the other fields, this key is allowed to be missing.
    definitions = info_dict.get("definition")
    external_id = info_dict["external_id"]

    return CaiBotInfo(
        name=name,
        title=title,
        description=description,
        greeting=greeting,
        definitions=definitions,
        external_id=external_id,
    )
def _messages_from_dict(msgs_dict: list[dict[str, t.Any]]) -> list[CaiMessage]:
    '''
    Converts the raw messages of an entry from the `histories` JSON into a
    list of CaiMessage objects, keeping their original order.
    '''
    return [
        CaiMessage(text=entry["text"], is_human=entry["src"]["is_human"])
        for entry in msgs_dict
    ]
def _is_definition_data(dict_from_json: dict[str, t.Any]) -> bool:
    '''
    Figures out whether the given dict (parsed from a JSON file) is a regular
    dump, or a dump from the Character Editor (possibly containing definitions).

    Returns `True` for a Character Editor dump and `False` for a regular dump.

    If it doesn't seem like either, raises a `ValueError` so we can discard bad
    data. This includes parsed JSON whose top-level value isn't an object at
    all.
    '''
    if not isinstance(dict_from_json, dict):
        # Valid JSON isn't necessarily an object at the top level. Without
        # this check, a list or scalar would fail with an AttributeError
        # instead of the documented ValueError.
        raise ValueError(
            "Unexpected top-level type found in CAI dump JSON file:"
            f" {type(dict_from_json).__name__}")

    # Some people messed with their files so the order of the keys isn't always
    # the same, so we sort for consistency.
    keys = sorted(dict_from_json)

    if keys in (["character"], ["character", "user__username"]):
        return True
    if keys == ["histories", "info"]:
        return False

    # Dumps can be huge and contain entire chat logs, so the contents only go
    # to the debug log rather than being printed to stdout.
    logger.debug("Unexpected CAI dump contents: %s", dict_from_json)
    raise ValueError(f"Unexpected keys found in CAI dump JSON file: {keys}")