fix: human/bot messages being incorrectly labeled as eachother

This commit is contained in:
11b 2022-12-24 17:58:33 -03:00
parent 5b26097905
commit 3bfb623f26
2 changed files with 20 additions and 12 deletions

View File

@ -28,10 +28,16 @@ class CaiBotInfo:
# they'll be of much use.
@dataclass(frozen=True)
class CaiMessage:
is_human: bool
text: str
@dataclass(frozen=True)
class CaiChat:
# First message is the bot's greeting, the one afterwards is the user.
messages: t.List[str]
# First message is always the bot's greeting.
messages: list[CaiMessage]
bot: CaiBotInfo
@ -119,11 +125,15 @@ def _bot_info_from_dict(info_dict: dict[str, t.Any]) -> CaiBotInfo:
)
def _messages_from_dict(msgs_dict: list[dict[str, t.Any]]) -> list[str]:
def _messages_from_dict(msgs_dict: list[dict[str, t.Any]]) -> list[CaiMessage]:
'''Builds an array of messages from an entry from the `histories` JSON.'''
messages: list[str] = []
messages: list[CaiMessage] = []
for raw_message in msgs_dict:
messages.append(raw_message["text"])
message = CaiMessage(
text=raw_message["text"],
is_human=raw_message["src"]["is_human"],
)
messages.append(message)
return messages

View File

@ -62,14 +62,12 @@ class CharacterAiPDM(BaseModule):
# Now, start adding messages and break episodes apart if they get
# too big.
turns = base_turns.copy()
for idx, raw_message in enumerate(chat.messages):
# First message is always the bot (since it must send a
# greeting), and next up is always the user.
if idx % 2 == 0:
# TODO(11b): Handle `[NAME_IN_MESSAGE_REDACTED]`.
message = f"{chat.bot.name}: {raw_message}"
for raw_message in chat.messages:
if raw_message.is_human:
message = f"{PromptConstants.USER_PREFIX}: {raw_message.text}"
else:
message = f"{PromptConstants.USER_PREFIX}: {raw_message}"
# TODO(11b): Handle `[NAME_IN_MESSAGE_REDACTED]`.
message = f"{chat.bot.name}: {raw_message.text}"
turns.append(message)
# Splitting logic.