diff --git a/waifu/datasets/characterai.py b/waifu/datasets/characterai.py index b6d4048..805788d 100644 --- a/waifu/datasets/characterai.py +++ b/waifu/datasets/characterai.py @@ -28,10 +28,16 @@ class CaiBotInfo: # they'll be of much use. +@dataclass(frozen=True) +class CaiMessage: + is_human: bool + text: str + + @dataclass(frozen=True) class CaiChat: - # First message is the bot's greeting, the one afterwards is the user. - messages: t.List[str] + # First message is always the bot's greeting. + messages: list[CaiMessage] bot: CaiBotInfo @@ -119,11 +125,15 @@ def _bot_info_from_dict(info_dict: dict[str, t.Any]) -> CaiBotInfo: ) -def _messages_from_dict(msgs_dict: list[dict[str, t.Any]]) -> list[str]: +def _messages_from_dict(msgs_dict: list[dict[str, t.Any]]) -> list[CaiMessage]: '''Builds an array of messages from an entry from the `histories` JSON.''' - messages: list[str] = [] + messages: list[CaiMessage] = [] for raw_message in msgs_dict: - messages.append(raw_message["text"]) + message = CaiMessage( + text=raw_message["text"], + is_human=raw_message["src"]["is_human"], + ) + messages.append(message) return messages diff --git a/waifu/modules/characterai_pdm.py b/waifu/modules/characterai_pdm.py index fb41ee4..5eae512 100644 --- a/waifu/modules/characterai_pdm.py +++ b/waifu/modules/characterai_pdm.py @@ -62,14 +62,12 @@ class CharacterAiPDM(BaseModule): # Now, start adding messages and break episodes apart if they get # too big. turns = base_turns.copy() - for idx, raw_message in enumerate(chat.messages): - # First message is always the bot (since it must send a - # greeting), and next up is always the user. - if idx % 2 == 0: - # TODO(11b): Handle `[NAME_IN_MESSAGE_REDACTED]`. - message = f"{chat.bot.name}: {raw_message}" + for raw_message in chat.messages: + if raw_message.is_human: + message = f"{PromptConstants.USER_PREFIX}: {raw_message.text}" else: - message = f"{PromptConstants.USER_PREFIX}: {raw_message}" + # TODO(11b): Handle `[NAME_IN_MESSAGE_REDACTED]`. + message = f"{chat.bot.name}: {raw_message.text}" turns.append(message) # Splitting logic.