toolbox/waifu/datasets/characterai.py

76 lines
2.2 KiB
Python

import json
import os
import typing as t
from dataclasses import dataclass
import mashumaro
from waifu.datasets import BaseDataset
from waifu.utils.dataset import get_data_path
@dataclass(frozen=True)
class CaiBotInfo(mashumaro.DataClassDictMixin):
name: str
title: str
description: str
greeting: str
@dataclass(frozen=True)
class CaiChat:
# First message is the bot's greeting, the one afterwards is the user.
messages: t.List[str]
bot_info: CaiBotInfo
class CharacterAiDataset(BaseDataset[CaiChat]):
'''Dataset for CharacterAI dumps.'''
def generator(self) -> t.Generator[CaiChat, None, None]:
for folder in _enumerate_bot_folders():
info_path = os.path.join(folder, "info.json")
histories_path = os.path.join(folder, "histories.json")
with open(info_path, "r", encoding="utf-8") as info_file, \
open(histories_path, "r", encoding="utf-8") as histories_file:
info_json = json.load(info_file)
histories_json = json.load(histories_file)
bot_info = CaiBotInfo.from_dict(info_json["character"])
for history_dict in histories_json["histories"]:
messages = _messages_from_dict(history_dict["msgs"])
yield CaiChat(bot_info=bot_info, messages=messages)
#
# Private helpers.
#
def _enumerate_bot_folders() -> list[str]:
'''Returns a list of folders available in the CAI data folder.'''
dataset_path = get_data_path(dataset_name="test_characterai_dumps")
items = os.listdir(dataset_path)
folders: list[str] = []
for item in items:
item_path = os.path.join(dataset_path, item)
if os.path.isfile(item_path):
# We only care about folders.
continue
absolute_folder_path = os.path.abspath(os.path.join(dataset_path, item))
folders.append(absolute_folder_path)
return folders
def _messages_from_dict(msgs_dict: list[dict[str, t.Any]]) -> list[str]:
'''Builds an array of messages from an entry from the `histories` JSON.'''
messages: list[str] = []
for raw_message in msgs_dict:
messages.append(raw_message["text"])
return messages