refactor: adjust Kajiwoto modules to use the proper prompt constants

This commit is contained in:
11b 2022-12-23 16:37:47 -03:00
parent 60e0a21a3c
commit a16673ebe0
2 changed files with 12 additions and 8 deletions

View File

@ -1,13 +1,12 @@
import typing as t import typing as t
from waifu.core.consts import PromptConstants
from waifu.datasets.kajiwoto import (KajiwotoDataset, generate_variants_for, from waifu.datasets.kajiwoto import (KajiwotoDataset, generate_variants_for,
replace_special_tokens_in) replace_special_tokens_in)
from waifu.modules import BaseModule from waifu.modules import BaseModule
from waifu.modules.kajiwoto_vdm import BOT_PREFIX
from waifu.utils.strings import uppercase from waifu.utils.strings import uppercase
USER_PREFIX = "Person 1"
BOT_PREFIX = "Person 2"
class KajiwotoPDM(BaseModule): class KajiwotoPDM(BaseModule):
'''A Persona Dialogue Module powered by the Kajiwoto dataset.''' '''A Persona Dialogue Module powered by the Kajiwoto dataset.'''
@ -29,15 +28,17 @@ class KajiwotoPDM(BaseModule):
description_string = metadata.description.replace("\n", description_string = metadata.description.replace("\n",
" ").replace( " ").replace(
" ", " ") " ", " ")
turns.append(f"{BOT_PREFIX}'s Description: {description_string}") turns.append(
turns.append(f"{BOT_PREFIX}'s Persona: {persona_string}") f"{PromptConstants.pdm_prefix_for(BOT_PREFIX)}: {description_string}\n{persona_string}"
)
# Empty turn to have a line break separating description/persona # Empty turn to have a line break separating description/persona
# and the actual messages. # and the actual messages.
turns.append("") turns.append("")
for turn in episode: for turn in episode:
turns.append(f"{USER_PREFIX}: {turn.user_message}") turns.append(
f"{PromptConstants.USER_PREFIX}: {turn.user_message}")
turns.append(f"{BOT_PREFIX}: {turn.bot_response}") turns.append(f"{BOT_PREFIX}: {turn.bot_response}")
string = "\n".join(turns) string = "\n".join(turns)

View File

@ -1,10 +1,12 @@
import typing as t import typing as t
from waifu.core.consts import PromptConstants
from waifu.datasets.kajiwoto import (KajiwotoDataset, generate_variants_for, from waifu.datasets.kajiwoto import (KajiwotoDataset, generate_variants_for,
replace_special_tokens_in) replace_special_tokens_in)
from waifu.modules import BaseModule from waifu.modules import BaseModule
USER_PREFIX = "Person 1" # TODO(11b): Figure out if we can do something better instead of hardcoding a
# fake name.
BOT_PREFIX = "Person 2" BOT_PREFIX = "Person 2"
@ -16,7 +18,8 @@ class KajiwotoVDM(BaseModule):
for episode in dataset: for episode in dataset:
turns: t.List[str] = [] turns: t.List[str] = []
for turn in episode: for turn in episode:
turns.append(f"{USER_PREFIX}: {turn.user_message}") turns.append(
f"{PromptConstants.USER_PREFIX}: {turn.user_message}")
turns.append(f"{BOT_PREFIX}: {turn.bot_response}") turns.append(f"{BOT_PREFIX}: {turn.bot_response}")
string = "\n".join(turns) string = "\n".join(turns)