feat: super early CAI dataset and module for testing
This commit is contained in:
parent
657cbe1d61
commit
f82b4ea913
|
@ -0,0 +1,75 @@
|
|||
import json
|
||||
import os
|
||||
import typing as t
|
||||
from dataclasses import dataclass
|
||||
|
||||
import mashumaro
|
||||
|
||||
from waifu.datasets import BaseDataset
|
||||
from waifu.utils.dataset import get_data_path
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CaiBotInfo(mashumaro.DataClassDictMixin):
|
||||
name: str
|
||||
title: str
|
||||
description: str
|
||||
greeting: str
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CaiChat:
|
||||
# First message is the bot's greeting, the one afterwards is the user.
|
||||
messages: t.List[str]
|
||||
bot_info: CaiBotInfo
|
||||
|
||||
|
||||
class CharacterAiDataset(BaseDataset[CaiChat]):
|
||||
'''Dataset for CharacterAI dumps.'''
|
||||
|
||||
def generator(self) -> t.Generator[CaiChat, None, None]:
|
||||
for folder in _enumerate_bot_folders():
|
||||
info_path = os.path.join(folder, "info.json")
|
||||
histories_path = os.path.join(folder, "histories.json")
|
||||
|
||||
with open(info_path, "r", encoding="utf-8") as info_file, \
|
||||
open(histories_path, "r", encoding="utf-8") as histories_file:
|
||||
info_json = json.load(info_file)
|
||||
histories_json = json.load(histories_file)
|
||||
|
||||
bot_info = CaiBotInfo.from_dict(info_json["character"])
|
||||
|
||||
for history_dict in histories_json["histories"]:
|
||||
messages = _messages_from_dict(history_dict["msgs"])
|
||||
yield CaiChat(bot_info=bot_info, messages=messages)
|
||||
|
||||
|
||||
#
|
||||
# Private helpers.
|
||||
#
|
||||
|
||||
|
||||
def _enumerate_bot_folders() -> list[str]:
|
||||
'''Returns a list of folders available in the CAI data folder.'''
|
||||
dataset_path = get_data_path(dataset_name="test_characterai_dumps")
|
||||
items = os.listdir(dataset_path)
|
||||
|
||||
folders: list[str] = []
|
||||
for item in items:
|
||||
item_path = os.path.join(dataset_path, item)
|
||||
if os.path.isfile(item_path):
|
||||
# We only care about folders.
|
||||
continue
|
||||
|
||||
absolute_folder_path = os.path.abspath(os.path.join(dataset_path, item))
|
||||
folders.append(absolute_folder_path)
|
||||
|
||||
return folders
|
||||
|
||||
|
||||
def _messages_from_dict(msgs_dict: list[dict[str, t.Any]]) -> list[str]:
|
||||
'''Builds an array of messages from an entry from the `histories` JSON.'''
|
||||
messages: list[str] = []
|
||||
for raw_message in msgs_dict:
|
||||
messages.append(raw_message["text"])
|
||||
return messages
|
|
@ -0,0 +1,25 @@
|
|||
import typing as t
|
||||
|
||||
from waifu.datasets.characterai import CharacterAiDataset
|
||||
from waifu.modules import BaseModule
|
||||
|
||||
USER_PREFIX = "You"
|
||||
|
||||
|
||||
class CharacterAiPDM(BaseModule):
|
||||
'''A Persona Dialogue Module powered by CharacterAI data.'''
|
||||
|
||||
def generator(self) -> t.Generator[str, None, None]:
|
||||
for chat in CharacterAiDataset():
|
||||
description_string = f"{chat.bot_info.name}'s Description: {chat.bot_info.description}"
|
||||
# Empty turn to separate description from the messages.
|
||||
turns = [description_string, ""]
|
||||
|
||||
for idx, raw_message in enumerate(chat.messages):
|
||||
if idx % 2 == 0:
|
||||
message = f"{chat.bot_info.name}: {raw_message}"
|
||||
else:
|
||||
message = f"{USER_PREFIX}: {raw_message}"
|
||||
turns.append(message)
|
||||
|
||||
yield "\n".join(turns)
|
|
@ -14,8 +14,9 @@ from waifu.modules import BaseModule
|
|||
# TODO(11b): Needs manual maintenance ot keep up-to-date. Consider doing some
|
||||
# metaprogramming trickery to build this list out instead.
|
||||
DEFAULT_MODULE_LIST = [
|
||||
"characterai_pdm:CharacterAiPDM",
|
||||
"kajiwoto_pdm:KajiwotoPDM",
|
||||
"kajiwoto_vdm:KajiwotoVDM",
|
||||
# "kajiwoto_vdm:KajiwotoVDM",
|
||||
"light_dialogue_vdm:LightDialogueVDM",
|
||||
]
|
||||
DEFAULT_MODULES_STRING = ",".join(DEFAULT_MODULE_LIST)
|
||||
|
|
Loading…
Reference in New Issue