diff --git a/waifu/modules/kajiwoto_pdm.py b/waifu/modules/kajiwoto_pdm.py index 3b381b3..99fd6ef 100644 --- a/waifu/modules/kajiwoto_pdm.py +++ b/waifu/modules/kajiwoto_pdm.py @@ -1,13 +1,12 @@ import typing as t +from waifu.core.consts import PromptConstants from waifu.datasets.kajiwoto import (KajiwotoDataset, generate_variants_for, replace_special_tokens_in) from waifu.modules import BaseModule +from waifu.modules.kajiwoto_vdm import BOT_PREFIX from waifu.utils.strings import uppercase -USER_PREFIX = "Person 1" -BOT_PREFIX = "Person 2" - class KajiwotoPDM(BaseModule): '''A Persona Dialogue Module powered by the Kajiwoto dataset.''' @@ -29,15 +28,17 @@ class KajiwotoPDM(BaseModule): description_string = metadata.description.replace("\n", " ").replace( " ", " ") - turns.append(f"{BOT_PREFIX}'s Description: {description_string}") - turns.append(f"{BOT_PREFIX}'s Persona: {persona_string}") + turns.append( + f"{PromptConstants.pdm_prefix_for(BOT_PREFIX)}: {description_string}\n{persona_string}" + ) # Empty turn to have a line break separating description/persona # and the actual messages. turns.append("") for turn in episode: - turns.append(f"{USER_PREFIX}: {turn.user_message}") + turns.append( + f"{PromptConstants.USER_PREFIX}: {turn.user_message}") turns.append(f"{BOT_PREFIX}: {turn.bot_response}") string = "\n".join(turns) diff --git a/waifu/modules/kajiwoto_vdm.py b/waifu/modules/kajiwoto_vdm.py index 9770963..1821cc0 100644 --- a/waifu/modules/kajiwoto_vdm.py +++ b/waifu/modules/kajiwoto_vdm.py @@ -1,10 +1,12 @@ import typing as t +from waifu.core.consts import PromptConstants from waifu.datasets.kajiwoto import (KajiwotoDataset, generate_variants_for, replace_special_tokens_in) from waifu.modules import BaseModule -USER_PREFIX = "Person 1" +# TODO(11b): Figure out if we can do something better instead of hardcoding a +# fake name. BOT_PREFIX = "Person 2" @@ -16,7 +18,8 @@ class KajiwotoVDM(BaseModule): for episode in dataset: turns: t.List[str] = [] for turn in episode: - turns.append(f"{USER_PREFIX}: {turn.user_message}") + turns.append( + f"{PromptConstants.USER_PREFIX}: {turn.user_message}") turns.append(f"{BOT_PREFIX}: {turn.bot_response}") string = "\n".join(turns)