From 53494a65678c9fc1170f22ba2ce5e981cfa986b6 Mon Sep 17 00:00:00 2001 From: 0x000011b <0x000011b@waifu.club> Date: Sun, 1 Jan 2023 11:50:23 -0300 Subject: [PATCH] chore: fix linter/style problems --- waifu/datasets/kajiwoto.py | 2 +- waifu/modules/characterai_pdm.py | 2 +- waifu/modules/kajiwoto_pdm.py | 6 +++--- waifu/scripts/build_dataset.py | 3 +-- 4 files changed, 6 insertions(+), 7 deletions(-) diff --git a/waifu/datasets/kajiwoto.py b/waifu/datasets/kajiwoto.py index 039573e..d310bbd 100644 --- a/waifu/datasets/kajiwoto.py +++ b/waifu/datasets/kajiwoto.py @@ -4,8 +4,8 @@ import os import re import typing as t from dataclasses import dataclass -from waifu.core.consts import PromptConstants +from waifu.core.consts import PromptConstants from waifu.datasets import BaseDataset from waifu.utils.dataset import get_data_path diff --git a/waifu/modules/characterai_pdm.py b/waifu/modules/characterai_pdm.py index 37ed0ca..b0b5b3f 100644 --- a/waifu/modules/characterai_pdm.py +++ b/waifu/modules/characterai_pdm.py @@ -143,8 +143,8 @@ def _calculate_similarity_scores(bot_turns: list[str]) -> t.Any: This is a roundabout way to try and _possibly_ detect the post-1.1 CAI looping behavior so we can handle it during the data preprocessing. ''' - from sklearn.metrics.pairwise import cosine_similarity from sklearn.feature_extraction.text import CountVectorizer + from sklearn.metrics.pairwise import cosine_similarity vectorizer = CountVectorizer() x = vectorizer.fit_transform(bot_turns) diff --git a/waifu/modules/kajiwoto_pdm.py b/waifu/modules/kajiwoto_pdm.py index 99fd6ef..4b8d1b4 100644 --- a/waifu/modules/kajiwoto_pdm.py +++ b/waifu/modules/kajiwoto_pdm.py @@ -4,7 +4,6 @@ from waifu.core.consts import PromptConstants from waifu.datasets.kajiwoto import (KajiwotoDataset, generate_variants_for, replace_special_tokens_in) from waifu.modules import BaseModule -from waifu.modules.kajiwoto_vdm import BOT_PREFIX from waifu.utils.strings import uppercase @@ -29,7 +28,7 @@ class KajiwotoPDM(BaseModule): " ").replace( " ", " ") turns.append( - f"{PromptConstants.pdm_prefix_for(BOT_PREFIX)}: {description_string}\n{persona_string}" + f"{PromptConstants.pdm_prefix_for(PromptConstants.BOT_TOKEN)}: {description_string}\n{persona_string}" ) # Empty turn to have a line break separating description/persona @@ -39,7 +38,8 @@ class KajiwotoPDM(BaseModule): for turn in episode: turns.append( f"{PromptConstants.USER_PREFIX}: {turn.user_message}") - turns.append(f"{BOT_PREFIX}: {turn.bot_response}") + turns.append( + f"{PromptConstants.BOT_TOKEN}: {turn.bot_response}") string = "\n".join(turns) processed_string = replace_special_tokens_in(string) diff --git a/waifu/scripts/build_dataset.py b/waifu/scripts/build_dataset.py index cf9a22a..8678da6 100755 --- a/waifu/scripts/build_dataset.py +++ b/waifu/scripts/build_dataset.py @@ -48,8 +48,7 @@ def main() -> None: "-s", "--skip", type=int, - help="If given, skip over this many episodes before printing." - ) + help="If given, skip over this many episodes before printing.") parser.add_argument("-v", "--verbose",