toolbox/toolbox/modules/kajiwoto_pdm.py

import typing as t

from toolbox.core.consts import PromptConstants
from toolbox.datasets.kajiwoto import (KajiwotoDataset, generate_variants_for,
                                        replace_special_tokens_in)
from toolbox.modules import BaseModule
from toolbox.utils.strings import uppercase


class KajiwotoPDM(BaseModule):
    '''A Persona Dialogue Module powered by the Kajiwoto dataset.'''

    def generator(self) -> t.Generator[str, None, None]:
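        '''Formats each Kajiwoto episode into a persona-dialogue string and
        yields every generated variant of it.'''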
        dataset = KajiwotoDataset()
        for episode in dataset:
            turns: list[str] = []
            metadata = dataset.get_metadata_for_bot(episode[0].bot_id)

            # `metadata.personalities` is in a format like: `[["friendly", "20.32"]]`
            # but we want that "phrased" closer to natural language, so we build
            # `persona_string` to take care of that.
            personality_descriptors = [x[0] for x in metadata.personalities]
            persona_string = ". ".join(
                [uppercase(x) for x in personality_descriptors]) + "."
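            # For example (assuming `uppercase` capitalizes the first letter),
            # `[["friendly", "20.32"]]` becomes "Friendly."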

            # Flatten the description: newlines become spaces, and any doubled
            # spaces that produces are collapsed back down to one.
            description_string = metadata.description.replace(
                "\n", " ").replace("  ", " ")

            turns.append(
                f"{PromptConstants.pdm_prefix_for(PromptConstants.BOT_TOKEN)}: {description_string}\n{persona_string}"
            )

            # Empty turn to have a line break separating description/persona
            # and the actual messages.
            turns.append("")

            for turn in episode:
                turns.append(
                    f"{PromptConstants.USER_PREFIX}: {turn.user_message}")
                turns.append(
                    f"{PromptConstants.BOT_TOKEN}: {turn.bot_response}")
            string = "\n".join(turns)
            processed_string = replace_special_tokens_in(string)

            for generated_string in generate_variants_for(processed_string):
                yield generated_string
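

# Minimal usage sketch, not part of the module proper: it assumes `KajiwotoPDM`
# can be constructed with no arguments and that the Kajiwoto data files are
# available locally for `KajiwotoDataset` to read.
if __name__ == "__main__":
    module = KajiwotoPDM()
    # Preview a handful of generated training examples.
    for idx, example in enumerate(module.generator()):
        print(example)
        print("-" * 40)
        if idx >= 2:
            break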