fix: human/bot messages being incorrectly labeled as eachother
This commit is contained in:
parent
5b26097905
commit
3bfb623f26
|
@ -28,10 +28,16 @@ class CaiBotInfo:
|
|||
# they'll be of much use.
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CaiMessage:
|
||||
is_human: bool
|
||||
text: str
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CaiChat:
|
||||
# First message is the bot's greeting, the one afterwards is the user.
|
||||
messages: t.List[str]
|
||||
# First message is always the bot's greeting.
|
||||
messages: list[CaiMessage]
|
||||
bot: CaiBotInfo
|
||||
|
||||
|
||||
|
@ -119,11 +125,15 @@ def _bot_info_from_dict(info_dict: dict[str, t.Any]) -> CaiBotInfo:
|
|||
)
|
||||
|
||||
|
||||
def _messages_from_dict(msgs_dict: list[dict[str, t.Any]]) -> list[str]:
|
||||
def _messages_from_dict(msgs_dict: list[dict[str, t.Any]]) -> list[CaiMessage]:
|
||||
'''Builds an array of messages from an entry from the `histories` JSON.'''
|
||||
messages: list[str] = []
|
||||
messages: list[CaiMessage] = []
|
||||
for raw_message in msgs_dict:
|
||||
messages.append(raw_message["text"])
|
||||
message = CaiMessage(
|
||||
text=raw_message["text"],
|
||||
is_human=raw_message["src"]["is_human"],
|
||||
)
|
||||
messages.append(message)
|
||||
return messages
|
||||
|
||||
|
||||
|
|
|
@ -62,14 +62,12 @@ class CharacterAiPDM(BaseModule):
|
|||
# Now, start adding messages and break episodes apart if they get
|
||||
# too big.
|
||||
turns = base_turns.copy()
|
||||
for idx, raw_message in enumerate(chat.messages):
|
||||
# First message is always the bot (since it must send a
|
||||
# greeting), and next up is always the user.
|
||||
if idx % 2 == 0:
|
||||
# TODO(11b): Handle `[NAME_IN_MESSAGE_REDACTED]`.
|
||||
message = f"{chat.bot.name}: {raw_message}"
|
||||
for raw_message in chat.messages:
|
||||
if raw_message.is_human:
|
||||
message = f"{PromptConstants.USER_PREFIX}: {raw_message.text}"
|
||||
else:
|
||||
message = f"{PromptConstants.USER_PREFIX}: {raw_message}"
|
||||
# TODO(11b): Handle `[NAME_IN_MESSAGE_REDACTED]`.
|
||||
message = f"{chat.bot.name}: {raw_message.text}"
|
||||
turns.append(message)
|
||||
|
||||
# Splitting logic.
|
||||
|
|
Loading…
Reference in New Issue