feat: improve handling of special tokens in the Kajiwoto dataset
This commit is contained in:
parent
b95b30cf88
commit
96b41dee60
|
@ -7,6 +7,9 @@ class PromptConstants:
|
|||
# Token to be replaced with the user's display name within bot messages.
|
||||
USER_TOKEN = "<USER>"
|
||||
|
||||
# Token to be replaced by the bot's name.
|
||||
BOT_TOKEN = "<BOT>"
|
||||
|
||||
# Global target word count. The word count is chosen in such a way that we
|
||||
# can fit all the required prompt trickery into the model's input, but still
|
||||
# leave enough space for the user's input message and the inference result.
|
||||
|
|
|
@ -4,6 +4,7 @@ import os
|
|||
import re
|
||||
import typing as t
|
||||
from dataclasses import dataclass
|
||||
from waifu.core.consts import PromptConstants
|
||||
|
||||
from waifu.datasets import BaseDataset
|
||||
from waifu.utils.dataset import get_data_path
|
||||
|
@ -156,10 +157,13 @@ def replace_special_tokens_in(string: str) -> str:
|
|||
Replaces known special tokens (e.g.: `%{name}`) with their expected
|
||||
equivalents.
|
||||
'''
|
||||
string = string.replace("%{name}", PromptConstants.USER_TOKEN)
|
||||
string = string.replace("%{kajiname}", PromptConstants.BOT_TOKEN)
|
||||
|
||||
if (match := re.search(KAJIWOTO_VARIANT_REGEX, string)) is not None:
|
||||
special_token = match.groups()[0]
|
||||
if '|' not in special_token and special_token not in seen_special_tokens:
|
||||
logger.debug("Unhandled Kajiwoto token: %s", special_token)
|
||||
logger.warning("Unhandled Kajiwoto token: %s", special_token)
|
||||
seen_special_tokens.add(special_token)
|
||||
|
||||
if (scene_match := re.search(r"#scene=(.+?)\b", string)) is not None:
|
||||
|
@ -168,9 +172,12 @@ def replace_special_tokens_in(string: str) -> str:
|
|||
logger.debug("Unhandled Kajiwoto scene: %s", seen_scene)
|
||||
seen_scenes.add(seen_scene)
|
||||
|
||||
# TODO: There's lots of these which I haven't handled at all. E.g.:
|
||||
# %{pronoun} (before and after a dot, so careful with caps), %{name},
|
||||
# %{kajiname}, #scene=SOMETHING, ...
|
||||
# Drop the scene marker. Maybe we can use it for something useful, but
|
||||
# I can't think of anything at the moment.
|
||||
string = string.replace(f"#scene={seen_scene}", "").strip()
|
||||
|
||||
# TODO: There are a few of these which I haven't handled yet. E.g.:
|
||||
# %{pronoun} (before and after a dot, so careful with caps).
|
||||
return string
|
||||
|
||||
|
||||
|
|
|
@ -5,10 +5,6 @@ from waifu.datasets.kajiwoto import (KajiwotoDataset, generate_variants_for,
|
|||
replace_special_tokens_in)
|
||||
from waifu.modules import BaseModule
|
||||
|
||||
# TODO(11b): Figure out if we can do something better instead of hardcoding a
|
||||
# fake name.
|
||||
BOT_PREFIX = "Person 2"
|
||||
|
||||
|
||||
class KajiwotoVDM(BaseModule):
|
||||
'''A Vanilla Dialogue Module powered by the Kajiwoto dataset.'''
|
||||
|
@ -20,7 +16,8 @@ class KajiwotoVDM(BaseModule):
|
|||
for turn in episode:
|
||||
turns.append(
|
||||
f"{PromptConstants.USER_PREFIX}: {turn.user_message}")
|
||||
turns.append(f"{BOT_PREFIX}: {turn.bot_response}")
|
||||
turns.append(
|
||||
f"{PromptConstants.BOT_TOKEN}: {turn.bot_response}")
|
||||
|
||||
string = "\n".join(turns)
|
||||
processed_string = replace_special_tokens_in(string)
|
||||
|
|
Loading…
Reference in New Issue