diff --git a/training/inference.py b/training/inference.py
index 3f26e20..a44a1f5 100755
--- a/training/inference.py
+++ b/training/inference.py
@@ -61,7 +61,10 @@ def _build_blacklist_for(bot_name: str) -> list[str]:
     # out to the fine-tuned models as well.
     bad_opt_generations = ["___", "____", "_____"]
 
-    return [pdm_prefix, *bad_opt_generations]
+    # And Pythia likes to do this.
+    bad_pythia_generations = ["...."]
+
+    return [pdm_prefix, *bad_opt_generations, *bad_pythia_generations]
 
 
 def _build_model_and_tokenizer_for(args: argparse.Namespace,
@@ -80,9 +83,12 @@ def _build_model_and_tokenizer_for(args: argparse.Namespace,
         state_dict = torch.load(args.checkpoint,
                                 map_location="cuda").pop("model")
 
+    tokenizer_kwargs = {"add_special_tokens": False}
+    if "facebook/opt-" in args.model_name:
+        tokenizer_kwargs["add_prefix_space"] = True
+
     bad_words_ids = [
-        tokenizer(bad_word, add_prefix_space=True,
-                  add_special_tokens=False).input_ids
+        tokenizer(bad_word, **tokenizer_kwargs).input_ids
        for bad_word in _build_blacklist_for(bot_name)
     ]
 
@@ -104,7 +110,7 @@ def _run_raw_inference(model: t.Any, tokenizer: t.Any, prompt: str,
     logits = model.generate(
         input_ids,
         do_sample=True,
-        max_new_tokens=3,
+        max_new_tokens=32,
         top_k=50,
         top_p=0.90,
     )
@@ -113,39 +119,14 @@ def _run_raw_inference(model: t.Any, tokenizer: t.Any, prompt: str,
     # Then, contrastive search.
     input_ids = tokenizer(output, return_tensors="pt").input_ids.to("cuda")
     logits = model.generate(input_ids,
-                            max_new_tokens=128,
+                            max_new_tokens=96,
                             penalty_alpha=0.6,
                             top_k=6)
 
-    # FIXME(11b): All of these break in different ways. Write a more robust
-    # solution.
-    USE_DUMB_TRIMMING_ALGORITHM = False
-    if USE_DUMB_TRIMMING_ALGORITHM:
-        output = tokenizer.decode(logits[0], skip_special_tokens=True)
-        trimmed_output = output.replace(prompt, "").strip()
-
-        # Set a breakpoint for when trimming goes wrong, so we can investigate.
-        if len(trimmed_output) >= len(output):
-            import pdb
-            pdb.set_trace()
-
-        return trimmed_output
-
-    USE_SLICING_TRIMMING_ALGORITHM = False
-    if USE_SLICING_TRIMMING_ALGORITHM:
-        logger.debug("Untrimmed inference output is: `%s`",
-                     tokenizer.decode(logits[0], skip_special_tokens=True))
-
-        # Slicing logic taken from:
-        # https://github.com/huggingface/transformers/issues/17117#issuecomment-1124497554
-        logits_without_input_prompt = logits[:, input_ids.shape[1]:]
-        output = tokenizer.decode(logits_without_input_prompt[0],
-                                  skip_special_tokens=True)
-        return output
-
+    # Then, we trim out the input prompt from the generated output.
     output = tokenizer.decode(logits[0], skip_special_tokens=True)
     if (idx := prompt.rfind(user_message)) != -1:
-        trimmed_output = output[idx + len(user_message):]
+        trimmed_output = output[idx + len(user_message):].strip()
         return trimmed_output
     else:
         raise ValueError("Couldn't find user message in the prompt. What?")
@@ -175,9 +156,9 @@ def _parse_messages_from_str(string: str, names: list[str]) -> list[str]:
     for match in speaker_regex.finditer(string):
         message_start_indexes.append(match.start())
 
-    if len(message_start_indexes) == 0:
-        # Single message in the string, so no message separators to match.
-        return [string]
+    if len(message_start_indexes) < 2:
+        # Single message in the string.
+        return [string.strip()]
 
     prev_start_idx = message_start_indexes[0]
     messages = []
@@ -208,25 +189,15 @@ def _gr_run_inference(model: t.Any, tokenizer: t.Any, context: str,
     prompt = "\n".join(
         [context, "", *history, f"You: {user_message}", f"{character_name}: "])
 
-    raw_output = _run_raw_inference(model, tokenizer, prompt, user_message)
-    logger.debug("After inference, `raw_output` is: `%s`", raw_output)
-
-    # So there's a bit of a shitty bug here. The tensor slicing logic inside of
-    # `_run_raw_inference` doesn't always slice off the input prompt cleanly,
-    # sometimes it leaves a little bit of it in the beginning of the output. To
-    # work around that, we look for a ":" close to the beginning of the output
-    # string, and if we find it, we trim out everything that came before it.
-    STOP_SEARCHING_AT_IDX = 8
-    if (idx := raw_output.find(":", 0, STOP_SEARCHING_AT_IDX)) != -1:
-        raw_output = raw_output[idx + 1:]
-
-    output = f"{character_name}: {raw_output.strip()}"
+    output = _run_raw_inference(model, tokenizer, prompt, user_message).strip()
+    logger.debug("_run_raw_inference returned `%s` after .strip()", output)
 
     # If there's enough space, the model will likely generate more than just its
     # own message, so we need to trim that out and just remove the first
     # generated message.
     generated_messages = _parse_messages_from_str(output,
                                                   ["You", character_name])
+    logger.debug("Generated messages is `%s`", generated_messages)
     bot_message = generated_messages[0]
     logger.info("Generated message: `%s`", bot_message)
 
@@ -287,8 +258,8 @@ def _build_gradio_ui_for(model: t.Any, tokenizer: t.Any) -> t.Any:
         with gr.Row():
             submit_btn = gr.Button("Submit input")
             submit_fn = lambda context, history, character_name, user_message: _gr_run_inference(
-                model, tokenizer, context, history, character_name,
-                user_message)
+                model, tokenizer, context, history, character_name, user_message
+            )
             submit_btn.click(
                 fn=submit_fn,
                 inputs=[context, history, character_name, user_message],
@@ -296,8 +267,8 @@
 
             regenerate_btn = gr.Button("Regenerate last output")
             regenerate_fn = lambda context, history, character_name, user_message: _gr_regenerate_last_output(
-                model, tokenizer, context, history, character_name,
-                user_message)
+                model, tokenizer, context, history, character_name, user_message
+            )
             regenerate_btn.click(
                 fn=regenerate_fn,
                 inputs=[context, history, character_name, user_message],
@@ -305,8 +276,8 @@
 
             undo_btn = gr.Button("Undo last exchange")
             undo_btn.click(fn=_gr_undo,
-                           inputs=[history],
-                           outputs=[history, history_text])
+                inputs=[history],
+                outputs=[history, history_text])
 
     return interface
 