feat: improve handling of special tokens in the Kajiwoto dataset
This commit is contained in:
parent
b95b30cf88
commit
96b41dee60
|
@ -7,6 +7,9 @@ class PromptConstants:
|
||||||
# Token to be replaced with the user's display name within bot messages.
|
# Token to be replaced with the user's display name within bot messages.
|
||||||
USER_TOKEN = "<USER>"
|
USER_TOKEN = "<USER>"
|
||||||
|
|
||||||
|
# Token to be replaced by the bot's name.
|
||||||
|
BOT_TOKEN = "<BOT>"
|
||||||
|
|
||||||
# Global target word count. The word count is chosen in such a way that we
|
# Global target word count. The word count is chosen in such a way that we
|
||||||
# can fit all the required prompt trickery into the model's input, but still
|
# can fit all the required prompt trickery into the model's input, but still
|
||||||
# leave enough space for the user's input message and the infernce result.
|
# leave enough space for the user's input message and the infernce result.
|
||||||
|
|
|
@ -4,6 +4,7 @@ import os
|
||||||
import re
|
import re
|
||||||
import typing as t
|
import typing as t
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from waifu.core.consts import PromptConstants
|
||||||
|
|
||||||
from waifu.datasets import BaseDataset
|
from waifu.datasets import BaseDataset
|
||||||
from waifu.utils.dataset import get_data_path
|
from waifu.utils.dataset import get_data_path
|
||||||
|
@ -156,10 +157,13 @@ def replace_special_tokens_in(string: str) -> str:
|
||||||
Replaces known special tokens (e.g.: `%{name}`) with their expected
|
Replaces known special tokens (e.g.: `%{name}`) with their expected
|
||||||
equivalents.
|
equivalents.
|
||||||
'''
|
'''
|
||||||
|
string = string.replace("%{name}", PromptConstants.USER_TOKEN)
|
||||||
|
string = string.replace("%{kajiname}", PromptConstants.BOT_TOKEN)
|
||||||
|
|
||||||
if (match := re.search(KAJIWOTO_VARIANT_REGEX, string)) is not None:
|
if (match := re.search(KAJIWOTO_VARIANT_REGEX, string)) is not None:
|
||||||
special_token = match.groups()[0]
|
special_token = match.groups()[0]
|
||||||
if '|' not in special_token and special_token not in seen_special_tokens:
|
if '|' not in special_token and special_token not in seen_special_tokens:
|
||||||
logger.debug("Unhandled Kajiwoto token: %s", special_token)
|
logger.warning("Unhandled Kajiwoto token: %s", special_token)
|
||||||
seen_special_tokens.add(special_token)
|
seen_special_tokens.add(special_token)
|
||||||
|
|
||||||
if (scene_match := re.search(r"#scene=(.+?)\b", string)) is not None:
|
if (scene_match := re.search(r"#scene=(.+?)\b", string)) is not None:
|
||||||
|
@ -168,9 +172,12 @@ def replace_special_tokens_in(string: str) -> str:
|
||||||
logger.debug("Unhandled Kajiwoto scene: %s", seen_scene)
|
logger.debug("Unhandled Kajiwoto scene: %s", seen_scene)
|
||||||
seen_scenes.add(seen_scene)
|
seen_scenes.add(seen_scene)
|
||||||
|
|
||||||
# TODO: There's lots of these which I haven't handled at all. E.g.:
|
# Drop the scene marker. Maybe we can use it for something useful, but
|
||||||
# %{pronoun} (before and after a dot, so careful with caps), %{name},
|
# I can't think of anything at the moment.
|
||||||
# %{kajiname}, #scene=SOMETHING, ...
|
string = string.replace(f"#scene={seen_scene}", "").strip()
|
||||||
|
|
||||||
|
# TODO: There's a few of these which I haven't handled yet. E.g.:
|
||||||
|
# %{pronoun} (before and after a dot, so careful with caps).
|
||||||
return string
|
return string
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -5,10 +5,6 @@ from waifu.datasets.kajiwoto import (KajiwotoDataset, generate_variants_for,
|
||||||
replace_special_tokens_in)
|
replace_special_tokens_in)
|
||||||
from waifu.modules import BaseModule
|
from waifu.modules import BaseModule
|
||||||
|
|
||||||
# TODO(11b): Figure out if we can do something better instead of hardcoding a
|
|
||||||
# fake name.
|
|
||||||
BOT_PREFIX = "Person 2"
|
|
||||||
|
|
||||||
|
|
||||||
class KajiwotoVDM(BaseModule):
|
class KajiwotoVDM(BaseModule):
|
||||||
'''A Vanilla Dialogue Module powered by the Kajiwoto dataset.'''
|
'''A Vanilla Dialogue Module powered by the Kajiwoto dataset.'''
|
||||||
|
@ -20,7 +16,8 @@ class KajiwotoVDM(BaseModule):
|
||||||
for turn in episode:
|
for turn in episode:
|
||||||
turns.append(
|
turns.append(
|
||||||
f"{PromptConstants.USER_PREFIX}: {turn.user_message}")
|
f"{PromptConstants.USER_PREFIX}: {turn.user_message}")
|
||||||
turns.append(f"{BOT_PREFIX}: {turn.bot_response}")
|
turns.append(
|
||||||
|
f"{PromptConstants.BOT_TOKEN}: {turn.bot_response}")
|
||||||
|
|
||||||
string = "\n".join(turns)
|
string = "\n".join(turns)
|
||||||
processed_string = replace_special_tokens_in(string)
|
processed_string = replace_special_tokens_in(string)
|
||||||
|
|
Loading…
Reference in New Issue