feat: super early CAI dataset and module for testing

This commit is contained in:
11b 2022-12-18 17:26:16 -03:00
parent 657cbe1d61
commit f82b4ea913
3 changed files with 102 additions and 1 deletions

View File

@ -0,0 +1,75 @@
import json
import os
import typing as t
from dataclasses import dataclass
import mashumaro
from waifu.datasets import BaseDataset
from waifu.utils.dataset import get_data_path
@dataclass(frozen=True)
class CaiBotInfo(mashumaro.DataClassDictMixin):
    '''Immutable record of a bot's profile, buildable from a plain dict.'''

    # Name of the bot.
    name: str

    # Title of the bot.
    title: str

    # Free-form description of the bot.
    description: str

    # Greeting text of the bot.
    greeting: str
@dataclass(frozen=True)
class CaiChat:
    '''A single chat history together with the bot it was held with.'''

    # Message texts in order. The first one is the bot's greeting, and the one
    # right after it comes from the user.
    messages: t.List[str]

    # Profile of the bot that took part in this chat.
    bot_info: CaiBotInfo
class CharacterAiDataset(BaseDataset[CaiChat]):
    '''Dataset for CharacterAI dumps.'''

    def generator(self) -> t.Generator[CaiChat, None, None]:
        '''
        Yields one `CaiChat` per entry found under the `histories` key of each
        bot folder's `histories.json`, paired with the bot info parsed from
        the `character` key of the same folder's `info.json`.
        '''
        for bot_folder in _enumerate_bot_folders():
            info_file_path = os.path.join(bot_folder, "info.json")
            histories_file_path = os.path.join(bot_folder, "histories.json")

            with open(info_file_path, "r", encoding="utf-8") as info_fp, \
                    open(histories_file_path, "r", encoding="utf-8") as histories_fp:
                info_payload = json.load(info_fp)
                histories_payload = json.load(histories_fp)

                # Every chat coming from this folder shares the same bot info.
                bot_info = CaiBotInfo.from_dict(info_payload["character"])

                for history in histories_payload["histories"]:
                    yield CaiChat(
                        bot_info=bot_info,
                        messages=_messages_from_dict(history["msgs"]),
                    )
#
# Private helpers.
#
def _enumerate_bot_folders() -> list[str]:
    '''
    Returns the absolute paths of the folders available in the CAI data folder.

    Only entries that are directories are returned. They are sorted by name so
    the result does not depend on the arbitrary order of `os.listdir`.
    '''
    dataset_path = get_data_path(dataset_name="test_characterai_dumps")

    folders: list[str] = []
    for item in sorted(os.listdir(dataset_path)):
        item_path = os.path.join(dataset_path, item)
        # Test for a directory directly: skipping only regular files would let
        # through entries that are neither (e.g. broken symlinks).
        if os.path.isdir(item_path):
            folders.append(os.path.abspath(item_path))
    return folders
def _messages_from_dict(msgs_dict: list[dict[str, t.Any]]) -> list[str]:
'''Builds an array of messages from an entry from the `histories` JSON.'''
messages: list[str] = []
for raw_message in msgs_dict:
messages.append(raw_message["text"])
return messages

View File

@ -0,0 +1,25 @@
import typing as t
from waifu.datasets.characterai import CharacterAiDataset
from waifu.modules import BaseModule
# Speaker label placed in front of the user's messages.
USER_PREFIX = "You"
class CharacterAiPDM(BaseModule):
    '''A Persona Dialogue Module powered by CharacterAI data.'''

    def generator(self) -> t.Generator[str, None, None]:
        '''
        Yields one newline-joined string per chat: the bot's description, an
        empty line, then every message prefixed with its speaker.
        '''
        for chat in CharacterAiDataset():
            bot_name = chat.bot_info.name
            header = f"{bot_name}'s Description: {chat.bot_info.description}"

            # Even positions are the bot speaking, odd positions are the user.
            labelled_messages = [
                f"{bot_name if position % 2 == 0 else USER_PREFIX}: {text}"
                for position, text in enumerate(chat.messages)
            ]

            # The empty string separates the description from the messages.
            yield "\n".join([header, "", *labelled_messages])

View File

@ -14,8 +14,9 @@ from waifu.modules import BaseModule
# TODO(11b): Needs manual maintenance to keep up-to-date. Consider doing some
# metaprogramming trickery to build this list out instead.
# Modules used by default. Each entry presumably has the form
# `<module name>:<class name>` — TODO confirm against the code that parses it.
DEFAULT_MODULE_LIST = [
"characterai_pdm:CharacterAiPDM",
"kajiwoto_pdm:KajiwotoPDM",
# NOTE(review): the entry below is listed both enabled and, right after it,
# commented out — confirm which of the two is intended.
"kajiwoto_vdm:KajiwotoVDM",
# "kajiwoto_vdm:KajiwotoVDM",
"light_dialogue_vdm:LightDialogueVDM",
]
# Comma-separated form of DEFAULT_MODULE_LIST.
DEFAULT_MODULES_STRING = ",".join(DEFAULT_MODULE_LIST)