feat: update inference code for pythia/cai data-based models
This commit is contained in:
parent
3bfb623f26
commit
186df60691
|
@ -61,7 +61,10 @@ def _build_blacklist_for(bot_name: str) -> list[str]:
|
||||||
# out to the fine-tuned models as well.
|
# out to the fine-tuned models as well.
|
||||||
bad_opt_generations = ["___", "____", "_____"]
|
bad_opt_generations = ["___", "____", "_____"]
|
||||||
|
|
||||||
return [pdm_prefix, *bad_opt_generations]
|
# And Pythia likes to do this.
|
||||||
|
bad_pythia_generations = ["...."]
|
||||||
|
|
||||||
|
return [pdm_prefix, *bad_opt_generations, *bad_pythia_generations]
|
||||||
|
|
||||||
|
|
||||||
def _build_model_and_tokenizer_for(args: argparse.Namespace,
|
def _build_model_and_tokenizer_for(args: argparse.Namespace,
|
||||||
|
@ -80,9 +83,12 @@ def _build_model_and_tokenizer_for(args: argparse.Namespace,
|
||||||
state_dict = torch.load(args.checkpoint,
|
state_dict = torch.load(args.checkpoint,
|
||||||
map_location="cuda").pop("model")
|
map_location="cuda").pop("model")
|
||||||
|
|
||||||
|
tokenizer_kwargs = {"add_special_tokens": False}
|
||||||
|
if "facebook/opt-" in args.model_name:
|
||||||
|
tokenizer_kwargs["add_prefix_space"] = True
|
||||||
|
|
||||||
bad_words_ids = [
|
bad_words_ids = [
|
||||||
tokenizer(bad_word, add_prefix_space=True,
|
tokenizer(bad_word, **tokenizer_kwargs).input_ids
|
||||||
add_special_tokens=False).input_ids
|
|
||||||
for bad_word in _build_blacklist_for(bot_name)
|
for bad_word in _build_blacklist_for(bot_name)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@ -104,7 +110,7 @@ def _run_raw_inference(model: t.Any, tokenizer: t.Any, prompt: str,
|
||||||
logits = model.generate(
|
logits = model.generate(
|
||||||
input_ids,
|
input_ids,
|
||||||
do_sample=True,
|
do_sample=True,
|
||||||
max_new_tokens=3,
|
max_new_tokens=32,
|
||||||
top_k=50,
|
top_k=50,
|
||||||
top_p=0.90,
|
top_p=0.90,
|
||||||
)
|
)
|
||||||
|
@ -113,39 +119,14 @@ def _run_raw_inference(model: t.Any, tokenizer: t.Any, prompt: str,
|
||||||
# Then, contrastive search.
|
# Then, contrastive search.
|
||||||
input_ids = tokenizer(output, return_tensors="pt").input_ids.to("cuda")
|
input_ids = tokenizer(output, return_tensors="pt").input_ids.to("cuda")
|
||||||
logits = model.generate(input_ids,
|
logits = model.generate(input_ids,
|
||||||
max_new_tokens=128,
|
max_new_tokens=96,
|
||||||
penalty_alpha=0.6,
|
penalty_alpha=0.6,
|
||||||
top_k=6)
|
top_k=6)
|
||||||
|
|
||||||
# FIXME(11b): All of these break in different ways. Write a more robust
|
# Then, we trim out the input prompt from the generated output.
|
||||||
# solution.
|
|
||||||
USE_DUMB_TRIMMING_ALGORITHM = False
|
|
||||||
if USE_DUMB_TRIMMING_ALGORITHM:
|
|
||||||
output = tokenizer.decode(logits[0], skip_special_tokens=True)
|
|
||||||
trimmed_output = output.replace(prompt, "").strip()
|
|
||||||
|
|
||||||
# Set a breakpoint for when trimming goes wrong, so we can investigate.
|
|
||||||
if len(trimmed_output) >= len(output):
|
|
||||||
import pdb
|
|
||||||
pdb.set_trace()
|
|
||||||
|
|
||||||
return trimmed_output
|
|
||||||
|
|
||||||
USE_SLICING_TRIMMING_ALGORITHM = False
|
|
||||||
if USE_SLICING_TRIMMING_ALGORITHM:
|
|
||||||
logger.debug("Untrimmed inference output is: `%s`",
|
|
||||||
tokenizer.decode(logits[0], skip_special_tokens=True))
|
|
||||||
|
|
||||||
# Slicing logic taken from:
|
|
||||||
# https://github.com/huggingface/transformers/issues/17117#issuecomment-1124497554
|
|
||||||
logits_without_input_prompt = logits[:, input_ids.shape[1]:]
|
|
||||||
output = tokenizer.decode(logits_without_input_prompt[0],
|
|
||||||
skip_special_tokens=True)
|
|
||||||
return output
|
|
||||||
|
|
||||||
output = tokenizer.decode(logits[0], skip_special_tokens=True)
|
output = tokenizer.decode(logits[0], skip_special_tokens=True)
|
||||||
if (idx := prompt.rfind(user_message)) != -1:
|
if (idx := prompt.rfind(user_message)) != -1:
|
||||||
trimmed_output = output[idx + len(user_message):]
|
trimmed_output = output[idx + len(user_message):].strip()
|
||||||
return trimmed_output
|
return trimmed_output
|
||||||
else:
|
else:
|
||||||
raise ValueError("Couldn't find user message in the prompt. What?")
|
raise ValueError("Couldn't find user message in the prompt. What?")
|
||||||
|
@ -175,9 +156,9 @@ def _parse_messages_from_str(string: str, names: list[str]) -> list[str]:
|
||||||
for match in speaker_regex.finditer(string):
|
for match in speaker_regex.finditer(string):
|
||||||
message_start_indexes.append(match.start())
|
message_start_indexes.append(match.start())
|
||||||
|
|
||||||
if len(message_start_indexes) == 0:
|
if len(message_start_indexes) < 2:
|
||||||
# Single message in the string, so no message separators to match.
|
# Single message in the string.
|
||||||
return [string]
|
return [string.strip()]
|
||||||
|
|
||||||
prev_start_idx = message_start_indexes[0]
|
prev_start_idx = message_start_indexes[0]
|
||||||
messages = []
|
messages = []
|
||||||
|
@ -208,25 +189,15 @@ def _gr_run_inference(model: t.Any, tokenizer: t.Any, context: str,
|
||||||
prompt = "\n".join(
|
prompt = "\n".join(
|
||||||
[context, "", *history, f"You: {user_message}", f"{character_name}: "])
|
[context, "", *history, f"You: {user_message}", f"{character_name}: "])
|
||||||
|
|
||||||
raw_output = _run_raw_inference(model, tokenizer, prompt, user_message)
|
output = _run_raw_inference(model, tokenizer, prompt, user_message).strip()
|
||||||
logger.debug("After inference, `raw_output` is: `%s`", raw_output)
|
logger.debug("_run_raw_inference returned `%s` after .strip()", output)
|
||||||
|
|
||||||
# So there's a bit of a shitty bug here. The tensor slicing logic inside of
|
|
||||||
# `_run_raw_inference` doesn't always slice off the input prompt cleanly,
|
|
||||||
# sometimes it leaves a little bit of it in the beginning of the output. To
|
|
||||||
# work around that, we look for a ":" close to the beginning of the output
|
|
||||||
# string, and if we find it, we trim out everything that came before it.
|
|
||||||
STOP_SEARCHING_AT_IDX = 8
|
|
||||||
if (idx := raw_output.find(":", 0, STOP_SEARCHING_AT_IDX)) != -1:
|
|
||||||
raw_output = raw_output[idx + 1:]
|
|
||||||
|
|
||||||
output = f"{character_name}: {raw_output.strip()}"
|
|
||||||
|
|
||||||
# If there's enough space, the model will likely generate more than just its
|
# If there's enough space, the model will likely generate more than just its
|
||||||
# own message, so we need to trim that out and just remove the first
|
# own message, so we need to trim that out and just remove the first
|
||||||
# generated message.
|
# generated message.
|
||||||
generated_messages = _parse_messages_from_str(output,
|
generated_messages = _parse_messages_from_str(output,
|
||||||
["You", character_name])
|
["You", character_name])
|
||||||
|
logger.debug("Generated messages is `%s`", generated_messages)
|
||||||
bot_message = generated_messages[0]
|
bot_message = generated_messages[0]
|
||||||
|
|
||||||
logger.info("Generated message: `%s`", bot_message)
|
logger.info("Generated message: `%s`", bot_message)
|
||||||
|
@ -287,8 +258,8 @@ def _build_gradio_ui_for(model: t.Any, tokenizer: t.Any) -> t.Any:
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
submit_btn = gr.Button("Submit input")
|
submit_btn = gr.Button("Submit input")
|
||||||
submit_fn = lambda context, history, character_name, user_message: _gr_run_inference(
|
submit_fn = lambda context, history, character_name, user_message: _gr_run_inference(
|
||||||
model, tokenizer, context, history, character_name,
|
model, tokenizer, context, history, character_name, user_message
|
||||||
user_message)
|
)
|
||||||
submit_btn.click(
|
submit_btn.click(
|
||||||
fn=submit_fn,
|
fn=submit_fn,
|
||||||
inputs=[context, history, character_name, user_message],
|
inputs=[context, history, character_name, user_message],
|
||||||
|
@ -296,8 +267,8 @@ def _build_gradio_ui_for(model: t.Any, tokenizer: t.Any) -> t.Any:
|
||||||
|
|
||||||
regenerate_btn = gr.Button("Regenerate last output")
|
regenerate_btn = gr.Button("Regenerate last output")
|
||||||
regenerate_fn = lambda context, history, character_name, user_message: _gr_regenerate_last_output(
|
regenerate_fn = lambda context, history, character_name, user_message: _gr_regenerate_last_output(
|
||||||
model, tokenizer, context, history, character_name,
|
model, tokenizer, context, history, character_name, user_message
|
||||||
user_message)
|
)
|
||||||
regenerate_btn.click(
|
regenerate_btn.click(
|
||||||
fn=regenerate_fn,
|
fn=regenerate_fn,
|
||||||
inputs=[context, history, character_name, user_message],
|
inputs=[context, history, character_name, user_message],
|
||||||
|
@ -305,8 +276,8 @@ def _build_gradio_ui_for(model: t.Any, tokenizer: t.Any) -> t.Any:
|
||||||
|
|
||||||
undo_btn = gr.Button("Undo last exchange")
|
undo_btn = gr.Button("Undo last exchange")
|
||||||
undo_btn.click(fn=_gr_undo,
|
undo_btn.click(fn=_gr_undo,
|
||||||
inputs=[history],
|
inputs=[history],
|
||||||
outputs=[history, history_text])
|
outputs=[history, history_text])
|
||||||
|
|
||||||
return interface
|
return interface
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue