chore: fix linter/style problems
This commit is contained in:
parent
e4594338d2
commit
53494a6567
|
@ -4,8 +4,8 @@ import os
|
|||
import re
|
||||
import typing as t
|
||||
from dataclasses import dataclass
|
||||
from waifu.core.consts import PromptConstants
|
||||
|
||||
from waifu.core.consts import PromptConstants
|
||||
from waifu.datasets import BaseDataset
|
||||
from waifu.utils.dataset import get_data_path
|
||||
|
||||
|
|
|
@ -143,8 +143,8 @@ def _calculate_similarity_scores(bot_turns: list[str]) -> t.Any:
|
|||
This is a roundabout way to try and _possibly_ detect the post-1.1 CAI
|
||||
looping behavior so we can handle it during the data preprocessing.
|
||||
'''
|
||||
from sklearn.metrics.pairwise import cosine_similarity
|
||||
from sklearn.feature_extraction.text import CountVectorizer
|
||||
from sklearn.metrics.pairwise import cosine_similarity
|
||||
|
||||
vectorizer = CountVectorizer()
|
||||
x = vectorizer.fit_transform(bot_turns)
|
||||
|
|
|
@ -4,7 +4,6 @@ from waifu.core.consts import PromptConstants
|
|||
from waifu.datasets.kajiwoto import (KajiwotoDataset, generate_variants_for,
|
||||
replace_special_tokens_in)
|
||||
from waifu.modules import BaseModule
|
||||
from waifu.modules.kajiwoto_vdm import BOT_PREFIX
|
||||
from waifu.utils.strings import uppercase
|
||||
|
||||
|
||||
|
@ -29,7 +28,7 @@ class KajiwotoPDM(BaseModule):
|
|||
" ").replace(
|
||||
" ", " ")
|
||||
turns.append(
|
||||
f"{PromptConstants.pdm_prefix_for(BOT_PREFIX)}: {description_string}\n{persona_string}"
|
||||
f"{PromptConstants.pdm_prefix_for(PromptConstants.BOT_TOKEN)}: {description_string}\n{persona_string}"
|
||||
)
|
||||
|
||||
# Empty turn to have a line break separating description/persona
|
||||
|
@ -39,7 +38,8 @@ class KajiwotoPDM(BaseModule):
|
|||
for turn in episode:
|
||||
turns.append(
|
||||
f"{PromptConstants.USER_PREFIX}: {turn.user_message}")
|
||||
turns.append(f"{BOT_PREFIX}: {turn.bot_response}")
|
||||
turns.append(
|
||||
f"{PromptConstants.BOT_TOKEN}: {turn.bot_response}")
|
||||
|
||||
string = "\n".join(turns)
|
||||
processed_string = replace_special_tokens_in(string)
|
||||
|
|
|
@ -48,8 +48,7 @@ def main() -> None:
|
|||
"-s",
|
||||
"--skip",
|
||||
type=int,
|
||||
help="If given, skip over this many episodes before printing."
|
||||
)
|
||||
help="If given, skip over this many episodes before printing.")
|
||||
|
||||
parser.add_argument("-v",
|
||||
"--verbose",
|
||||
|
|
Loading…
Reference in New Issue