# toolbox/toolbox/modules/characterai_pdm.py
import logging
import re
import typing as t
from toolbox.core.consts import PromptConstants
from toolbox.datasets.characterai import CharacterAiDataset
from toolbox.modules import BaseModule
logger = logging.getLogger(__name__)
# Discard episodes shorter than 3 turns. These are likely not very useful for
# the model to learn to converse properly, since they only really contain one
# dialogue response (the first turn is the hardcoded greeting, and the second is
# the user's input).
MIN_EPISODE_LEN: int = 3
# Discard episodes where the average similarity between the bot's messages is
# higher than this value.
EPISODE_SIMILARITY_THRESHOLD: float = 0.55
#
# So here's a quick rundown of what needs to happen. We have a limited context
# window (of 2048 tokens, ATM) and for the Persona Dialogue Module (PDM), we
# need to fit all of the following things in there:
#
# - The bot's description/definitions/persona/whatever you want to call it
# - Last X messages of chat history/context (the more the merrier, usually)
# - The user's input message, e.g. `You: [user text here]`
# - The bot's response, e.g. `[Bot name]: [space for the bot's response]`
#
# As such, most of the code here is about taking globs of text and
# chunking/splitting them up to make the format described above fit into blocks
# of 2048-ish tokens (not exactly 2048 because the tokenizer depends on the
# model used, and I don't want to create a dependency on a specific model at the
# data processing stage at this point).
#
class CharacterAiPDM(BaseModule):
    '''A Persona Dialogue Module powered by CharacterAI data.'''

    def generator(self) -> t.Generator[list[str], None, None]:
        '''
        Yields episodes (lists of turns) assembled from CharacterAI chats.

        Each episode starts with the bot's persona/definitions (when
        available) followed by chat messages, and is broken apart whenever
        its word count exceeds PromptConstants.TARGET_WORD_COUNT_PER_EPISODE.
        Episodes whose bot messages are too similar to each other (a symptom
        of the post-1.1 CAI looping behavior) are discarded.
        '''
        for chat in CharacterAiDataset():
            if len(chat.messages) < MIN_EPISODE_LEN:
                logger.debug(
                    "Found episode shorter than minimum length (%s < %s), discarding.",
                    len(chat.messages), MIN_EPISODE_LEN)
                continue

            # Static prefix (persona + parsed definitions) shared by every
            # episode cut from this chat.
            base_turns: list[str] = []
            if chat.bot.description is not None:
                pdm_prefix = PromptConstants.pdm_prefix_for(chat.bot.name)
                base_turns.append(f"{pdm_prefix}: {chat.bot.description}")
            if chat.bot.definitions is not None:
                # NOTE(review): the parsed example chats are currently
                # unused here.
                parsed_definitions, _parsed_examples = _parse_definitions_for(
                    chat.bot.name, chat.bot.definitions)
                base_turns.append(parsed_definitions)

            # Add turn to separate persona info from messages, if
            # necessary.
            if len(base_turns) > 0:
                base_turns.append(PromptConstants.CHAT_START_TOKEN)

            # Now, start adding messages and break episodes apart if they get
            # too big.
            turns = base_turns.copy()
            bot_messages: list[str] = []
            for raw_message in chat.messages:
                message_text = _process_message(raw_message.text)
                if raw_message.is_human:
                    message = f"{PromptConstants.USER_PREFIX}: {message_text}"
                else:
                    message = f"{chat.bot.name}: {message_text}"
                    # BUGFIX: only _bot_ messages take part in the similarity
                    # check. The original code appended human messages here as
                    # well, which skewed the looping-detection heuristic.
                    bot_messages.append(message_text)
                turns.append(message)

                # Splitting logic.
                cur_episode_len = sum(len(x.split()) for x in turns)
                if cur_episode_len > PromptConstants.TARGET_WORD_COUNT_PER_EPISODE:
                    logger.debug(
                        "Episode length went over TARGET_WORD_COUNT_PER_EPISODE (%s > %s), breaking apart.",
                        cur_episode_len,
                        PromptConstants.TARGET_WORD_COUNT_PER_EPISODE)

                    # Adding the last message made the episode go over the
                    # target word count, so we return the episode without it...
                    removed_turn = turns.pop()
                    removed_bot_text: t.Optional[str] = None
                    if not raw_message.is_human:
                        # ...and it shouldn't count towards this episode's
                        # similarity score either, since it's not part of it.
                        removed_bot_text = bot_messages.pop()

                    avg_similarity = self._average_similarity(bot_messages)
                    if avg_similarity <= EPISODE_SIMILARITY_THRESHOLD:
                        yield turns
                    else:
                        logger.debug(
                            "Ignoring episode due to high similarity between messages (%s > %s)",
                            avg_similarity,
                            EPISODE_SIMILARITY_THRESHOLD)

                    # ...and start the next episode with the message we had to
                    # trim out from this one.
                    turns = base_turns.copy()
                    turns.append(removed_turn)
                    bot_messages = ([removed_bot_text]
                                    if removed_bot_text is not None else [])

    @staticmethod
    def _average_similarity(bot_messages: list[str]) -> float:
        '''
        Returns the average pairwise cosine similarity between the given bot
        messages, or 0.0 when there are fewer than two messages to compare
        (which also avoids crashing CountVectorizer on empty input).

        BUGFIX: the original code summed only the first row of the similarity
        matrix and always divided by 2, which is not a correct average for an
        arbitrary number of bot messages. This averages every unique
        off-diagonal pair instead.
        '''
        if len(bot_messages) < 2:
            return 0.0
        sims = _calculate_similarity_scores(bot_messages)
        total = 0.0
        pair_count = 0
        for i in range(len(bot_messages)):
            for j in range(i + 1, len(bot_messages)):
                total += sims[i][j]
                pair_count += 1
        return total / pair_count
#
# Private helpers.
#
# Matches a single example chat inside a bot's definitions, assuming the
# creator properly terminated it with the END_OF_DIALOG marker. Group 1 is the
# speaker placeholder ({{char}} or {{random_user_N}}), group 2 the dialogue.
EXAMPLE_CHAT_REGEX = re.compile(
    r"({{char}}|{{random_user_\d}}): (.+?)(?:END_OF_DIALOG)", re.DOTALL)
# Relaxed fallback for definitions without END_OF_DIALOG markers: grabs
# everything from the first "{{char}}: " onwards (DOTALL makes .+ span lines).
RELAXED_EXAMPLE_CHAT_REGEX = re.compile(r"{{char}}: .+", re.DOTALL)
# Runs of four or more dots, to be collapsed into a standard "..." ellipsis.
EXCESSIVE_ELLIPSIS_REGEX = re.compile(r"\.{4,}")
def _process_message(original_string: str) -> str:
    '''
    Cleans up a single message: collapses excessive runs of dots into a
    standard ellipsis, swaps the CAI redaction placeholder for our own user
    token, and trims surrounding whitespace.
    '''
    cleaned = EXCESSIVE_ELLIPSIS_REGEX.sub("...", original_string)
    cleaned = cleaned.replace("[NAME_IN_MESSAGE_REDACTED]",
                              PromptConstants.USER_TOKEN)
    return cleaned.strip()
def _calculate_similarity_scores(bot_turns: list[str]) -> t.Any:
    '''
    Returns the pairwise cosine-similarity matrix of the given bot turns,
    vectorized as simple bag-of-words counts.

    This is a roundabout way to try and _possibly_ detect the post-1.1 CAI
    looping behavior so we can handle it during the data preprocessing.
    '''
    # Imported lazily so the module doesn't hard-require scikit-learn just to
    # be imported.
    from sklearn.feature_extraction.text import CountVectorizer
    from sklearn.metrics.pairwise import cosine_similarity

    counts = CountVectorizer().fit_transform(bot_turns)
    return cosine_similarity(counts.toarray())
def _parse_definitions_for(bot_name: str,
                           raw_definitions: str) -> t.Tuple[str, list[str]]:
    '''
    Parses bot definitions.
    This function attempts to find example messages within the input string,
    parses them accordingly and returns them separately from the rest of the
    text in the original `definitions` string.
    '''
    # Try the strict END_OF_DIALOG-based parse first; fall back to the
    # relaxed heuristic when it finds no example chats at all.
    definitions, examples = _parse_definitions_strict(raw_definitions)
    if not examples:
        definitions, examples = _parse_definitions_relaxed(raw_definitions)

    def _fill_name(text: str) -> str:
        # Substitute the {{char}} placeholder with the actual bot name.
        return text.replace("{{char}}", bot_name)

    return _fill_name(definitions), [_fill_name(x) for x in examples]
def _parse_definitions_strict(definitions: str) -> t.Tuple[str, list[str]]:
    '''
    Strict parsing of a bot's definitions string, assumes END_OF_DIALOG was used
    correctly by the bot's creator.
    '''
    # Collect each matched example chat, dropping the trailing marker.
    examples = [
        match.group().replace("END_OF_DIALOG", "").strip()
        for match in EXAMPLE_CHAT_REGEX.finditer(definitions)
    ]
    remaining_definitions = EXAMPLE_CHAT_REGEX.sub("", definitions)
    return remaining_definitions, examples
def _parse_definitions_relaxed(definitions: str) -> t.Tuple[str, list[str]]:
    '''
    Same as the `_parse_definitions_strict`, but this one is much more relaxed
    and should be used for when the bot creator didn't properly use
    END_OF_DIALOG to delineate example chats.
    '''
    examples = [
        match.group().strip()
        for match in RELAXED_EXAMPLE_CHAT_REGEX.finditer(definitions)
    ]
    remaining_definitions = RELAXED_EXAMPLE_CHAT_REGEX.sub("", definitions)
    return remaining_definitions, examples