'''Dataset loader for CharacterAI chat dumps and Character Editor dumps.'''
import json
|
|
import logging
|
|
import os
|
|
import typing as t
|
|
from dataclasses import dataclass
|
|
|
|
from toolbox.datasets import BaseDataset
|
|
from toolbox.utils.dataset import get_data_path
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
@dataclass(frozen=True)
|
|
class CaiBotInfo:
|
|
name: str
|
|
title: str
|
|
description: str | None
|
|
greeting: str
|
|
|
|
# Optional because it might be private.
|
|
definitions: str | None
|
|
|
|
# Useful for when several bots have the same name - we can tell them apart
|
|
# by their external_id.
|
|
external_id: str
|
|
|
|
# There's also categories, but I'm ignoring them for now since I don't think
|
|
# they'll be of much use.
|
|
|
|
|
|
@dataclass(frozen=True)
class CaiMessage:
    '''A single, immutable message taken from a chat history.'''

    # True when the message was written by the human, False otherwise.
    is_human: bool

    # The message's contents.
    text: str
|
|
|
|
|
|
@dataclass(frozen=True)
class CaiChat:
    '''One chat history, together with the bot it was held with.'''

    # The bot's greeting is always the first entry in this list.
    messages: list[CaiMessage]

    # Information about the bot taking part in the chat.
    bot: CaiBotInfo
|
|
|
|
|
|
class CharacterAiDataset(BaseDataset[CaiChat]):
    '''Dataset for CharacterAI dumps.'''

    def generator(self) -> t.Generator[CaiChat, None, None]:
        '''
        Yields one `CaiChat` per chat history found in the available dumps.

        The data files are walked twice: the first pass collects bot info from
        Character Editor dumps, the second pass builds the chats. Files whose
        top-level keys are not recognized are logged and skipped instead of
        aborting the whole run.
        '''
        bot_id_to_info_dict: dict[str, CaiBotInfo] = {}

        # Do a first run through all the files to load all the definitions and
        # descriptions.
        for data, is_definition in self._iter_classified_dumps():
            if not is_definition:
                continue

            bot_info = _bot_info_from_dict(data["character"])
            bot_id_to_info_dict[bot_info.external_id] = bot_info

        # Now do a second pass, to actually handle chat histories/messages.
        for data, is_definition in self._iter_classified_dumps():
            if is_definition:
                continue

            # Prefer grabbing bot info from a Character Editor dump, if it
            # exists. Fall back to public data otherwise. The fallback is only
            # built when actually needed, so incomplete public metadata can't
            # break a chat whose bot we already know from the first pass.
            character = data["info"]["character"]
            bot_info = bot_id_to_info_dict.get(character["external_id"])
            if bot_info is None:
                bot_info = _bot_info_from_dict(character)

            for history_dict in data["histories"]["histories"]:
                messages = _messages_from_dict(history_dict["msgs"])
                yield CaiChat(bot=bot_info, messages=messages)

    @staticmethod
    def _iter_classified_dumps(
    ) -> t.Generator[tuple[dict[str, t.Any], bool], None, None]:
        '''
        Yields `(data, is_definition)` for every recognized dump.

        `_is_definition_data` raises a `ValueError` for data that is neither a
        regular dump nor a Character Editor dump; such data is discarded here
        with a warning.
        '''
        for data in _available_json_data():
            try:
                is_definition = _is_definition_data(data)
            except ValueError as ex:
                logger.warning("Skipping unrecognized CAI dump: %s", ex)
                continue
            yield data, is_definition
|
|
|
|
|
|
#
|
|
# Private helpers.
|
|
#
|
|
|
|
|
|
def _enumerate_json_files(root_path: str) -> list[str]:
    '''
    Returns the absolute paths of the JSON files directly inside `root_path`.

    Only regular files whose name ends in `.json` are included; subfolders are
    not searched. The result is sorted by file name: `os.listdir` returns
    entries in arbitrary order, and we want the data to be enumerated in the
    same order on every run and on every filesystem.

    Any `OSError` raised by `os.listdir` (e.g. when `root_path` does not exist)
    is propagated to the caller.
    '''
    files: list[str] = []
    for item in sorted(os.listdir(root_path)):
        # Check the cheap name test first so we only hit the filesystem for
        # entries that could actually be JSON files.
        if not item.endswith(".json"):
            continue

        item_path = os.path.join(root_path, item)
        if not os.path.isfile(item_path):
            # A folder that merely happens to be named like a JSON file.
            continue

        files.append(os.path.abspath(item_path))

    return files
|
|
|
|
|
|
def _available_json_data() -> t.Generator[dict[str, t.Any], None, None]:
    '''
    Yields all available JSON data, parsed from the files in the CharacterAI
    data folder.

    Both the `public` and the `private` subfolders are read, in that order.
    Files that can't be decoded or parsed are logged and skipped.
    '''
    dataset_path = get_data_path(dataset_name="characterai")

    for folder in ["public", "private"]:
        folder_path = os.path.join(dataset_path, folder)
        for json_file_path in _enumerate_json_files(folder_path):
            # Parse the file completely and close it *before* yielding, so no
            # file handle is kept open while the consumer works on the data
            # (or forever, if the consumer abandons this generator early).
            try:
                with open(json_file_path, "r", encoding="utf-8") as json_file:
                    data = json.load(json_file)
            except (json.decoder.JSONDecodeError, UnicodeDecodeError) as ex:
                logger.error("Failed to parse %s: %s", json_file_path, ex)
                continue

            yield data
|
|
|
|
|
|
def _bot_info_from_dict(info_dict: dict[str, t.Any]) -> CaiBotInfo:
    '''Creates a CaiBotInfo out of the `character` object of a JSON dump.'''
    name = info_dict["name"]
    title = info_dict["title"]

    # For some reason, an undefined description shows up in the JSON as an
    # empty string rather than `null`. Normalize it to None for clarity.
    description = info_dict["description"]
    if not description:
        description = None

    greeting = info_dict["greeting"]
    # Unlike the other fields, the definition may be missing altogether.
    definitions = info_dict.get("definition")
    external_id = info_dict["external_id"]

    return CaiBotInfo(
        name=name,
        title=title,
        description=description,
        greeting=greeting,
        definitions=definitions,
        external_id=external_id,
    )
|
|
|
|
|
|
def _messages_from_dict(msgs_dict: list[dict[str, t.Any]]) -> list[CaiMessage]:
    '''Converts the `msgs` list of a `histories` JSON entry into messages.'''
    return [
        CaiMessage(text=entry["text"], is_human=entry["src"]["is_human"])
        for entry in msgs_dict
    ]
|
|
|
|
|
|
def _is_definition_data(dict_from_json: dict[str, t.Any]) -> bool:
    '''
    Figures out whether the given dict (parsed from a JSON file) is a regular
    dump, or a dump from the Character Editor (possibly containing definitions).

    Returns True for a Character Editor dump (top-level keys `character`,
    optionally with `user__username`) and False for a regular dump (top-level
    keys `histories` and `info`).

    If it doesn't seem like either - including when the parsed JSON isn't an
    object at all - raises a `ValueError` so we can discard bad data. Nothing
    is printed: the dump's contents may be large, and the exception message
    already names the offending keys.
    '''
    if not isinstance(dict_from_json, dict):
        # Valid JSON isn't necessarily an object (could be a list, string...).
        raise ValueError(
            "Unexpected top-level value found in CAI dump JSON file:"
            f" {type(dict_from_json).__name__}")

    # Some people messed with their files so the order of the keys isn't always
    # the same, so we sort for consistency.
    keys = sorted(dict_from_json)

    if keys in (["character"], ["character", "user__username"]):
        return True
    if keys == ["histories", "info"]:
        return False

    raise ValueError(f"Unexpected keys found in CAI dump JSON file: {keys}")
|