diff --git a/waifu/core/consts.py b/waifu/core/consts.py index 9594a05..2090dfa 100644 --- a/waifu/core/consts.py +++ b/waifu/core/consts.py @@ -4,6 +4,9 @@ class PromptConstants: # Prefix for user messages. USER_PREFIX = "You" + # Token to be replaced with the user's display name within bot messages. + USER_TOKEN = "" + # Global target word count. The word count is chosen in such a way that we # can fit all the required prompt trickery into the model's input, but still # leave enough space for the user's input message and the infernce result. diff --git a/waifu/modules/characterai_pdm.py b/waifu/modules/characterai_pdm.py index 5eae512..272b2c5 100644 --- a/waifu/modules/characterai_pdm.py +++ b/waifu/modules/characterai_pdm.py @@ -63,19 +63,20 @@ class CharacterAiPDM(BaseModule): # too big. turns = base_turns.copy() for raw_message in chat.messages: + message_text = _process_message(raw_message.text) if raw_message.is_human: - message = f"{PromptConstants.USER_PREFIX}: {raw_message.text}" + message = f"{PromptConstants.USER_PREFIX}: {message_text}" else: - # TODO(11b): Handle `[NAME_IN_MESSAGE_REDACTED]`. - message = f"{chat.bot.name}: {raw_message.text}" + message = f"{chat.bot.name}: {message_text}" turns.append(message) # Splitting logic. cur_episode_len = sum([len(x.split()) for x in turns]) if cur_episode_len > PromptConstants.TARGET_WORD_COUNT_PER_EPISODE: logger.debug( - "Episode length went over TARGET_WORD_COUNT_PER_EPISODE, breaking apart." - ) + "Episode length went over TARGET_WORD_COUNT_PER_EPISODE (%s > %s), breaking apart.", + cur_episode_len, + PromptConstants.TARGET_WORD_COUNT_PER_EPISODE) # Adding the last message made the episode go over the # target word count, so we return the episode without it... 
@@ -95,6 +96,18 @@ class CharacterAiPDM(BaseModule):
 EXAMPLE_CHAT_REGEX = re.compile(
     r"({{char}}|{{random_user_\d}}): (.+?)(?:END_OF_DIALOG)", re.DOTALL)
 RELAXED_EXAMPLE_CHAT_REGEX = re.compile(r"{{char}}: .+", re.DOTALL)
+EXCESSIVE_ELLIPSIS_REGEX = re.compile(r"\.{4,}")
+
+
+def _process_message(original_string: str) -> str:
+    '''
+    Cleans up one raw message: collapses runs of 4+ dots into "...", swaps the
+    "[NAME_IN_MESSAGE_REDACTED]" placeholder for `PromptConstants.USER_TOKEN`,
+    and strips surrounding whitespace.
+    '''
+    cleaned = EXCESSIVE_ELLIPSIS_REGEX.sub("...", original_string).replace(
+        "[NAME_IN_MESSAGE_REDACTED]", PromptConstants.USER_TOKEN)
+    return cleaned.strip()
 
 
 def _parse_definitions_for(bot_name: str,