chore: fix linter/style problems

This commit is contained in:
11b 2023-01-01 11:50:23 -03:00
parent e4594338d2
commit 53494a6567
4 changed files with 6 additions and 7 deletions

View File

@ -4,8 +4,8 @@ import os
import re import re
import typing as t import typing as t
from dataclasses import dataclass from dataclasses import dataclass
from waifu.core.consts import PromptConstants
from waifu.core.consts import PromptConstants
from waifu.datasets import BaseDataset from waifu.datasets import BaseDataset
from waifu.utils.dataset import get_data_path from waifu.utils.dataset import get_data_path

View File

@ -143,8 +143,8 @@ def _calculate_similarity_scores(bot_turns: list[str]) -> t.Any:
This is a roundabout way to try and _possibly_ detect the post-1.1 CAI This is a roundabout way to try and _possibly_ detect the post-1.1 CAI
looping behavior so we can handle it during the data preprocessing. looping behavior so we can handle it during the data preprocessing.
''' '''
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.feature_extraction.text import CountVectorizer from sklearn.feature_extraction.text import CountVectorizer
from sklearn.metrics.pairwise import cosine_similarity
vectorizer = CountVectorizer() vectorizer = CountVectorizer()
x = vectorizer.fit_transform(bot_turns) x = vectorizer.fit_transform(bot_turns)

View File

@ -4,7 +4,6 @@ from waifu.core.consts import PromptConstants
from waifu.datasets.kajiwoto import (KajiwotoDataset, generate_variants_for, from waifu.datasets.kajiwoto import (KajiwotoDataset, generate_variants_for,
replace_special_tokens_in) replace_special_tokens_in)
from waifu.modules import BaseModule from waifu.modules import BaseModule
from waifu.modules.kajiwoto_vdm import BOT_PREFIX
from waifu.utils.strings import uppercase from waifu.utils.strings import uppercase
@ -29,7 +28,7 @@ class KajiwotoPDM(BaseModule):
" ").replace( " ").replace(
" ", " ") " ", " ")
turns.append( turns.append(
f"{PromptConstants.pdm_prefix_for(BOT_PREFIX)}: {description_string}\n{persona_string}" f"{PromptConstants.pdm_prefix_for(PromptConstants.BOT_TOKEN)}: {description_string}\n{persona_string}"
) )
# Empty turn to have a line break separating description/persona # Empty turn to have a line break separating description/persona
@ -39,7 +38,8 @@ class KajiwotoPDM(BaseModule):
for turn in episode: for turn in episode:
turns.append( turns.append(
f"{PromptConstants.USER_PREFIX}: {turn.user_message}") f"{PromptConstants.USER_PREFIX}: {turn.user_message}")
turns.append(f"{BOT_PREFIX}: {turn.bot_response}") turns.append(
f"{PromptConstants.BOT_TOKEN}: {turn.bot_response}")
string = "\n".join(turns) string = "\n".join(turns)
processed_string = replace_special_tokens_in(string) processed_string = replace_special_tokens_in(string)

View File

@ -48,8 +48,7 @@ def main() -> None:
"-s", "-s",
"--skip", "--skip",
type=int, type=int,
help="If given, skip over this many episodes before printing." help="If given, skip over this many episodes before printing.")
)
parser.add_argument("-v", parser.add_argument("-v",
"--verbose", "--verbose",