refactor: adjust Kajiwoto modules to use the proper prompt constants
This commit is contained in:
parent
60e0a21a3c
commit
a16673ebe0
|
@ -1,13 +1,12 @@
|
||||||
import typing as t
|
import typing as t
|
||||||
|
|
||||||
|
from waifu.core.consts import PromptConstants
|
||||||
from waifu.datasets.kajiwoto import (KajiwotoDataset, generate_variants_for,
|
from waifu.datasets.kajiwoto import (KajiwotoDataset, generate_variants_for,
|
||||||
replace_special_tokens_in)
|
replace_special_tokens_in)
|
||||||
from waifu.modules import BaseModule
|
from waifu.modules import BaseModule
|
||||||
|
from waifu.modules.kajiwoto_vdm import BOT_PREFIX
|
||||||
from waifu.utils.strings import uppercase
|
from waifu.utils.strings import uppercase
|
||||||
|
|
||||||
USER_PREFIX = "Person 1"
|
|
||||||
BOT_PREFIX = "Person 2"
|
|
||||||
|
|
||||||
|
|
||||||
class KajiwotoPDM(BaseModule):
|
class KajiwotoPDM(BaseModule):
|
||||||
'''A Persona Dialogue Module powered by the Kajiwoto dataset.'''
|
'''A Persona Dialogue Module powered by the Kajiwoto dataset.'''
|
||||||
|
@ -29,15 +28,17 @@ class KajiwotoPDM(BaseModule):
|
||||||
description_string = metadata.description.replace("\n",
|
description_string = metadata.description.replace("\n",
|
||||||
" ").replace(
|
" ").replace(
|
||||||
" ", " ")
|
" ", " ")
|
||||||
turns.append(f"{BOT_PREFIX}'s Description: {description_string}")
|
turns.append(
|
||||||
turns.append(f"{BOT_PREFIX}'s Persona: {persona_string}")
|
f"{PromptConstants.pdm_prefix_for(BOT_PREFIX)}: {description_string}\n{persona_string}"
|
||||||
|
)
|
||||||
|
|
||||||
# Empty turn to have a line break separating description/persona
|
# Empty turn to have a line break separating description/persona
|
||||||
# and the actual messages.
|
# and the actual messages.
|
||||||
turns.append("")
|
turns.append("")
|
||||||
|
|
||||||
for turn in episode:
|
for turn in episode:
|
||||||
turns.append(f"{USER_PREFIX}: {turn.user_message}")
|
turns.append(
|
||||||
|
f"{PromptConstants.USER_PREFIX}: {turn.user_message}")
|
||||||
turns.append(f"{BOT_PREFIX}: {turn.bot_response}")
|
turns.append(f"{BOT_PREFIX}: {turn.bot_response}")
|
||||||
|
|
||||||
string = "\n".join(turns)
|
string = "\n".join(turns)
|
||||||
|
|
|
@ -1,10 +1,12 @@
|
||||||
import typing as t
|
import typing as t
|
||||||
|
|
||||||
|
from waifu.core.consts import PromptConstants
|
||||||
from waifu.datasets.kajiwoto import (KajiwotoDataset, generate_variants_for,
|
from waifu.datasets.kajiwoto import (KajiwotoDataset, generate_variants_for,
|
||||||
replace_special_tokens_in)
|
replace_special_tokens_in)
|
||||||
from waifu.modules import BaseModule
|
from waifu.modules import BaseModule
|
||||||
|
|
||||||
USER_PREFIX = "Person 1"
|
# TODO(11b): Figure out if we can do something better instead of hardcoding a
|
||||||
|
# fake name.
|
||||||
BOT_PREFIX = "Person 2"
|
BOT_PREFIX = "Person 2"
|
||||||
|
|
||||||
|
|
||||||
|
@ -16,7 +18,8 @@ class KajiwotoVDM(BaseModule):
|
||||||
for episode in dataset:
|
for episode in dataset:
|
||||||
turns: t.List[str] = []
|
turns: t.List[str] = []
|
||||||
for turn in episode:
|
for turn in episode:
|
||||||
turns.append(f"{USER_PREFIX}: {turn.user_message}")
|
turns.append(
|
||||||
|
f"{PromptConstants.USER_PREFIX}: {turn.user_message}")
|
||||||
turns.append(f"{BOT_PREFIX}: {turn.bot_response}")
|
turns.append(f"{BOT_PREFIX}: {turn.bot_response}")
|
||||||
|
|
||||||
string = "\n".join(turns)
|
string = "\n".join(turns)
|
||||||
|
|
Loading…
Reference in New Issue