chore: fix linter/style problems

This commit is contained in:
11b 2023-01-01 11:50:23 -03:00
parent e4594338d2
commit 53494a6567
4 changed files with 6 additions and 7 deletions

View File

@ -4,8 +4,8 @@ import os
import re
import typing as t
from dataclasses import dataclass
from waifu.core.consts import PromptConstants
from waifu.core.consts import PromptConstants
from waifu.datasets import BaseDataset
from waifu.utils.dataset import get_data_path

View File

@ -143,8 +143,8 @@ def _calculate_similarity_scores(bot_turns: list[str]) -> t.Any:
This is a roundabout way to try and _possibly_ detect the post-1.1 CAI
looping behavior so we can handle it during the data preprocessing.
'''
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.metrics.pairwise import cosine_similarity
vectorizer = CountVectorizer()
x = vectorizer.fit_transform(bot_turns)

View File

@ -4,7 +4,6 @@ from waifu.core.consts import PromptConstants
from waifu.datasets.kajiwoto import (KajiwotoDataset, generate_variants_for,
replace_special_tokens_in)
from waifu.modules import BaseModule
from waifu.modules.kajiwoto_vdm import BOT_PREFIX
from waifu.utils.strings import uppercase
@ -29,7 +28,7 @@ class KajiwotoPDM(BaseModule):
" ").replace(
" ", " ")
turns.append(
f"{PromptConstants.pdm_prefix_for(BOT_PREFIX)}: {description_string}\n{persona_string}"
f"{PromptConstants.pdm_prefix_for(PromptConstants.BOT_TOKEN)}: {description_string}\n{persona_string}"
)
# Empty turn to have a line break separating description/persona
@ -39,7 +38,8 @@ class KajiwotoPDM(BaseModule):
for turn in episode:
turns.append(
f"{PromptConstants.USER_PREFIX}: {turn.user_message}")
turns.append(f"{BOT_PREFIX}: {turn.bot_response}")
turns.append(
f"{PromptConstants.BOT_TOKEN}: {turn.bot_response}")
string = "\n".join(turns)
processed_string = replace_special_tokens_in(string)

View File

@ -48,8 +48,7 @@ def main() -> None:
"-s",
"--skip",
type=int,
help="If given, skip over this many episodes before printing."
)
help="If given, skip over this many episodes before printing.")
parser.add_argument("-v",
"--verbose",