feat: improve handling of special tokens in the Kajiwoto dataset

This commit is contained in:
11b 2022-12-27 12:46:57 -03:00
parent b95b30cf88
commit 96b41dee60
3 changed files with 16 additions and 9 deletions

View File

@ -7,6 +7,9 @@ class PromptConstants:
# Token to be replaced with the user's display name within bot messages.
USER_TOKEN = "<USER>"
# Token to be replaced by the bot's name.
BOT_TOKEN = "<BOT>"
# Global target word count. The word count is chosen in such a way that we
# can fit all the required prompt trickery into the model's input, but still
# leave enough space for the user's input message and the inference result.

View File

@ -4,6 +4,7 @@ import os
import re
import typing as t
from dataclasses import dataclass
from waifu.core.consts import PromptConstants
from waifu.datasets import BaseDataset
from waifu.utils.dataset import get_data_path
@ -156,10 +157,13 @@ def replace_special_tokens_in(string: str) -> str:
Replaces known special tokens (e.g.: `%{name}`) with their expected
equivalents.
'''
string = string.replace("%{name}", PromptConstants.USER_TOKEN)
string = string.replace("%{kajiname}", PromptConstants.BOT_TOKEN)
if (match := re.search(KAJIWOTO_VARIANT_REGEX, string)) is not None:
special_token = match.groups()[0]
if '|' not in special_token and special_token not in seen_special_tokens:
logger.debug("Unhandled Kajiwoto token: %s", special_token)
logger.warning("Unhandled Kajiwoto token: %s", special_token)
seen_special_tokens.add(special_token)
if (scene_match := re.search(r"#scene=(.+?)\b", string)) is not None:
@ -168,9 +172,12 @@ def replace_special_tokens_in(string: str) -> str:
logger.debug("Unhandled Kajiwoto scene: %s", seen_scene)
seen_scenes.add(seen_scene)
# TODO: There's lots of these which I haven't handled at all. E.g.:
# %{pronoun} (before and after a dot, so careful with caps), %{name},
# %{kajiname}, #scene=SOMETHING, ...
# Drop the scene marker. Maybe we can use it for something useful, but
# I can't think of anything at the moment.
string = string.replace(f"#scene={seen_scene}", "").strip()
# TODO: There are a few of these which I haven't handled yet. E.g.:
# %{pronoun} (before and after a dot, so careful with caps).
return string

View File

@ -5,10 +5,6 @@ from waifu.datasets.kajiwoto import (KajiwotoDataset, generate_variants_for,
replace_special_tokens_in)
from waifu.modules import BaseModule
# TODO(11b): Figure out if we can do something better instead of hardcoding a
# fake name.
BOT_PREFIX = "Person 2"
class KajiwotoVDM(BaseModule):
'''A Vanilla Dialogue Module powered by the Kajiwoto dataset.'''
@ -20,7 +16,8 @@ class KajiwotoVDM(BaseModule):
for turn in episode:
turns.append(
f"{PromptConstants.USER_PREFIX}: {turn.user_message}")
turns.append(f"{BOT_PREFIX}: {turn.bot_response}")
turns.append(
f"{PromptConstants.BOT_TOKEN}: {turn.bot_response}")
string = "\n".join(turns)
processed_string = replace_special_tokens_in(string)