diff --git a/waifu/core/consts.py b/waifu/core/consts.py index 9f9c185..d007e94 100644 --- a/waifu/core/consts.py +++ b/waifu/core/consts.py @@ -7,6 +7,9 @@ class PromptConstants: # Token to be replaced with the user's display name within bot messages. USER_TOKEN = "" + # Token to be replaced by the bot's name. + BOT_TOKEN = "" + # Global target word count. The word count is chosen in such a way that we # can fit all the required prompt trickery into the model's input, but still # leave enough space for the user's input message and the infernce result. diff --git a/waifu/datasets/kajiwoto.py b/waifu/datasets/kajiwoto.py index 3bf09f9..78fb0c1 100644 --- a/waifu/datasets/kajiwoto.py +++ b/waifu/datasets/kajiwoto.py @@ -4,6 +4,7 @@ import os import re import typing as t from dataclasses import dataclass +from waifu.core.consts import PromptConstants from waifu.datasets import BaseDataset from waifu.utils.dataset import get_data_path @@ -156,10 +157,13 @@ def replace_special_tokens_in(string: str) -> str: Replaces known special tokens (e.g.: `%{name}`) with their expected equivalents. ''' + string = string.replace("%{name}", PromptConstants.USER_TOKEN) + string = string.replace("%{kajiname}", PromptConstants.BOT_TOKEN) + if (match := re.search(KAJIWOTO_VARIANT_REGEX, string)) is not None: special_token = match.groups()[0] if '|' not in special_token and special_token not in seen_special_tokens: - logger.debug("Unhandled Kajiwoto token: %s", special_token) + logger.warning("Unhandled Kajiwoto token: %s", special_token) seen_special_tokens.add(special_token) if (scene_match := re.search(r"#scene=(.+?)\b", string)) is not None: @@ -168,9 +172,12 @@ def replace_special_tokens_in(string: str) -> str: logger.debug("Unhandled Kajiwoto scene: %s", seen_scene) seen_scenes.add(seen_scene) - # TODO: There's lots of these which I haven't handled at all. E.g.: - # %{pronoun} (before and after a dot, so careful with caps), %{name}, - # %{kajiname}, #scene=SOMETHING, ... + # Drop the scene marker. Maybe we can use it for something useful, but + # I can't think of anything at the moment. + string = string.replace(f"#scene={seen_scene}", "").strip() + + # TODO: There's a few of these which I haven't handled yet. E.g.: + # %{pronoun} (before and after a dot, so careful with caps). return string diff --git a/waifu/modules/kajiwoto_vdm.py b/waifu/modules/kajiwoto_vdm.py index 1821cc0..657aef3 100644 --- a/waifu/modules/kajiwoto_vdm.py +++ b/waifu/modules/kajiwoto_vdm.py @@ -5,10 +5,6 @@ from waifu.datasets.kajiwoto import (KajiwotoDataset, generate_variants_for, replace_special_tokens_in) from waifu.modules import BaseModule -# TODO(11b): Figure out if we can do something better instead of hardcoding a -# fake name. -BOT_PREFIX = "Person 2" - class KajiwotoVDM(BaseModule): '''A Vanilla Dialogue Module powered by the Kajiwoto dataset.''' @@ -20,7 +16,8 @@ class KajiwotoVDM(BaseModule): for turn in episode: turns.append( f"{PromptConstants.USER_PREFIX}: {turn.user_message}") - turns.append(f"{BOT_PREFIX}: {turn.bot_response}") + turns.append( + f"{PromptConstants.BOT_TOKEN}: {turn.bot_response}") string = "\n".join(turns) processed_string = replace_special_tokens_in(string)