feat: improve handling of special tokens in the Kajiwoto dataset

This commit is contained in:
11b 2022-12-27 12:46:57 -03:00
parent b95b30cf88
commit 96b41dee60
3 changed files with 16 additions and 9 deletions

View File

@ -7,6 +7,9 @@ class PromptConstants:
# Token to be replaced with the user's display name within bot messages. # Token to be replaced with the user's display name within bot messages.
USER_TOKEN = "<USER>" USER_TOKEN = "<USER>"
# Token to be replaced by the bot's name.
BOT_TOKEN = "<BOT>"
# Global target word count. The word count is chosen in such a way that we # Global target word count. The word count is chosen in such a way that we
# can fit all the required prompt trickery into the model's input, but still # can fit all the required prompt trickery into the model's input, but still
# leave enough space for the user's input message and the inference result. # leave enough space for the user's input message and the inference result.

View File

@ -4,6 +4,7 @@ import os
import re import re
import typing as t import typing as t
from dataclasses import dataclass from dataclasses import dataclass
from waifu.core.consts import PromptConstants
from waifu.datasets import BaseDataset from waifu.datasets import BaseDataset
from waifu.utils.dataset import get_data_path from waifu.utils.dataset import get_data_path
@ -156,10 +157,13 @@ def replace_special_tokens_in(string: str) -> str:
Replaces known special tokens (e.g.: `%{name}`) with their expected Replaces known special tokens (e.g.: `%{name}`) with their expected
equivalents. equivalents.
''' '''
string = string.replace("%{name}", PromptConstants.USER_TOKEN)
string = string.replace("%{kajiname}", PromptConstants.BOT_TOKEN)
if (match := re.search(KAJIWOTO_VARIANT_REGEX, string)) is not None: if (match := re.search(KAJIWOTO_VARIANT_REGEX, string)) is not None:
special_token = match.groups()[0] special_token = match.groups()[0]
if '|' not in special_token and special_token not in seen_special_tokens: if '|' not in special_token and special_token not in seen_special_tokens:
logger.debug("Unhandled Kajiwoto token: %s", special_token) logger.warning("Unhandled Kajiwoto token: %s", special_token)
seen_special_tokens.add(special_token) seen_special_tokens.add(special_token)
if (scene_match := re.search(r"#scene=(.+?)\b", string)) is not None: if (scene_match := re.search(r"#scene=(.+?)\b", string)) is not None:
@ -168,9 +172,12 @@ def replace_special_tokens_in(string: str) -> str:
logger.debug("Unhandled Kajiwoto scene: %s", seen_scene) logger.debug("Unhandled Kajiwoto scene: %s", seen_scene)
seen_scenes.add(seen_scene) seen_scenes.add(seen_scene)
# TODO: There's lots of these which I haven't handled at all. E.g.: # Drop the scene marker. Maybe we can use it for something useful, but
# %{pronoun} (before and after a dot, so careful with caps), %{name}, # I can't think of anything at the moment.
# %{kajiname}, #scene=SOMETHING, ... string = string.replace(f"#scene={seen_scene}", "").strip()
# TODO: There's a few of these which I haven't handled yet. E.g.:
# %{pronoun} (before and after a dot, so careful with caps).
return string return string

View File

@ -5,10 +5,6 @@ from waifu.datasets.kajiwoto import (KajiwotoDataset, generate_variants_for,
replace_special_tokens_in) replace_special_tokens_in)
from waifu.modules import BaseModule from waifu.modules import BaseModule
# TODO(11b): Figure out if we can do something better instead of hardcoding a
# fake name.
BOT_PREFIX = "Person 2"
class KajiwotoVDM(BaseModule): class KajiwotoVDM(BaseModule):
'''A Vanilla Dialogue Module powered by the Kajiwoto dataset.''' '''A Vanilla Dialogue Module powered by the Kajiwoto dataset.'''
@ -20,7 +16,8 @@ class KajiwotoVDM(BaseModule):
for turn in episode: for turn in episode:
turns.append( turns.append(
f"{PromptConstants.USER_PREFIX}: {turn.user_message}") f"{PromptConstants.USER_PREFIX}: {turn.user_message}")
turns.append(f"{BOT_PREFIX}: {turn.bot_response}") turns.append(
f"{PromptConstants.BOT_TOKEN}: {turn.bot_response}")
string = "\n".join(turns) string = "\n".join(turns)
processed_string = replace_special_tokens_in(string) processed_string = replace_special_tokens_in(string)