2022-12-18 01:36:33 +01:00
|
|
|
import typing as t
|
|
|
|
|
2022-12-23 20:37:47 +01:00
|
|
|
from waifu.core.consts import PromptConstants
|
2022-12-18 01:36:33 +01:00
|
|
|
from waifu.datasets.kajiwoto import (KajiwotoDataset, generate_variants_for,
|
|
|
|
replace_special_tokens_in)
|
|
|
|
from waifu.modules import BaseModule
|
2022-12-23 20:37:47 +01:00
|
|
|
from waifu.modules.kajiwoto_vdm import BOT_PREFIX
|
2022-12-18 01:36:33 +01:00
|
|
|
from waifu.utils.strings import uppercase
|
|
|
|
|
|
|
|
|
|
|
|
class KajiwotoPDM(BaseModule):
    '''A Persona Dialogue Module powered by the Kajiwoto dataset.

    Each Kajiwoto episode is rendered as one prompt string made of:
    a persona header (bot description, then its personality descriptors),
    an empty line, and then alternating user/bot message lines. Special
    tokens are replaced via `replace_special_tokens_in`, and every string
    produced by `generate_variants_for` is yielded.
    '''

    def generator(self) -> t.Generator[str, None, None]:
        '''Yield the prompt variants for every episode in the Kajiwoto dataset.'''
        dataset = KajiwotoDataset()

        for episode in dataset:
            # An episode with no turns has no dialogue to emit, and
            # `episode[0]` below would raise IndexError on it.
            if not episode:
                continue

            metadata = dataset.get_metadata_for_bot(episode[0].bot_id)

            turns: list[str] = [
                self._persona_header(metadata.description,
                                     metadata.personalities),
                # Empty turn to have a line break separating
                # description/persona and the actual messages.
                "",
            ]

            for turn in episode:
                turns.append(
                    f"{PromptConstants.USER_PREFIX}: {turn.user_message}")
                turns.append(f"{BOT_PREFIX}: {turn.bot_response}")

            processed_string = replace_special_tokens_in("\n".join(turns))
            yield from generate_variants_for(processed_string)

    @staticmethod
    def _persona_header(
            description: str,
            personalities: t.Iterable[t.Sequence[str]]) -> str:
        '''Build the persona header that opens every episode.

        :param description: Free-text bot description. Every run of
            whitespace (newlines, tabs, repeated spaces) is collapsed into a
            single space, and leading/trailing whitespace is dropped.
        :param personalities: Entries in a format like
            `[["friendly", "20.32"]]`. Only the first element (the
            descriptor) of each entry is used. Descriptors are passed through
            `uppercase` and joined as ". "-separated, "."-terminated
            fragments so they read closer to natural language.
        :return: The prefixed description line, followed by a newline and the
            persona line. The persona line is omitted when there are no
            descriptors, rather than emitting a lone ".".
        '''
        description_string = " ".join(description.split())
        header = (f"{PromptConstants.pdm_prefix_for(BOT_PREFIX)}: "
                  f"{description_string}")

        descriptors = [uppercase(entry[0]) for entry in personalities]
        if descriptors:
            header += "\n" + ". ".join(descriptors) + "."
        return header
|