From a16673ebe0fff5262860c73de8c441303667f1fc Mon Sep 17 00:00:00 2001 From: 0x000011b <0x000011b@waifu.club> Date: Fri, 23 Dec 2022 16:37:47 -0300 Subject: [PATCH] refactor: adjust Kajiwoto modules to use the proper prompt constants --- waifu/modules/kajiwoto_pdm.py | 13 +++++++------ waifu/modules/kajiwoto_vdm.py | 7 +++++-- 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/waifu/modules/kajiwoto_pdm.py b/waifu/modules/kajiwoto_pdm.py index 3b381b3..99fd6ef 100644 --- a/waifu/modules/kajiwoto_pdm.py +++ b/waifu/modules/kajiwoto_pdm.py @@ -1,13 +1,12 @@ import typing as t +from waifu.core.consts import PromptConstants from waifu.datasets.kajiwoto import (KajiwotoDataset, generate_variants_for, replace_special_tokens_in) from waifu.modules import BaseModule +from waifu.modules.kajiwoto_vdm import BOT_PREFIX from waifu.utils.strings import uppercase -USER_PREFIX = "Person 1" -BOT_PREFIX = "Person 2" - class KajiwotoPDM(BaseModule): '''A Persona Dialogue Module powered by the Kajiwoto dataset.''' @@ -29,15 +28,17 @@ class KajiwotoPDM(BaseModule): description_string = metadata.description.replace("\n", " ").replace( " ", " ") - turns.append(f"{BOT_PREFIX}'s Description: {description_string}") - turns.append(f"{BOT_PREFIX}'s Persona: {persona_string}") + turns.append( + f"{PromptConstants.pdm_prefix_for(BOT_PREFIX)}: {description_string}\n{persona_string}" + ) # Empty turn to have a line break separating description/persona # and the actual messages. turns.append("") for turn in episode: - turns.append(f"{USER_PREFIX}: {turn.user_message}") + turns.append( + f"{PromptConstants.USER_PREFIX}: {turn.user_message}") turns.append(f"{BOT_PREFIX}: {turn.bot_response}") string = "\n".join(turns) diff --git a/waifu/modules/kajiwoto_vdm.py b/waifu/modules/kajiwoto_vdm.py index 9770963..1821cc0 100644 --- a/waifu/modules/kajiwoto_vdm.py +++ b/waifu/modules/kajiwoto_vdm.py @@ -1,10 +1,12 @@ import typing as t +from waifu.core.consts import PromptConstants from waifu.datasets.kajiwoto import (KajiwotoDataset, generate_variants_for, replace_special_tokens_in) from waifu.modules import BaseModule -USER_PREFIX = "Person 1" +# TODO(11b): Figure out if we can do something better instead of hardcoding a +# fake name. BOT_PREFIX = "Person 2" @@ -16,7 +18,8 @@ class KajiwotoVDM(BaseModule): for episode in dataset: turns: t.List[str] = [] for turn in episode: - turns.append(f"{USER_PREFIX}: {turn.user_message}") + turns.append( + f"{PromptConstants.USER_PREFIX}: {turn.user_message}") turns.append(f"{BOT_PREFIX}: {turn.bot_response}") string = "\n".join(turns)