diff --git a/training/inference.py b/training/inference.py
index a44a1f5..3314a14 100755
--- a/training/inference.py
+++ b/training/inference.py
@@ -102,26 +102,34 @@ def _build_model_and_tokenizer_for(args: argparse.Namespace,
 
 
 def _run_raw_inference(model: t.Any, tokenizer: t.Any, prompt: str,
-                       user_message: str) -> str:
+                       user_message: str,
+                       sampl_new_tokens, sampl_top_k, sampl_top_p,
+                       cs_new_tokens, cs_alpha, cs_top_k, bad_words_str) -> str:
     '''Runs raw inference on the model, and returns just the generated text.'''
     # First, sampling-based generation.
+    bad_words_ids = [
+        tokenizer(bad_word, add_special_tokens=True).input_ids
+        for bad_word in bad_words_str.split(';')
+    ]
+
     input_ids = tokenizer(prompt, return_tensors='pt').input_ids.to("cuda")
     logits = model.generate(
         input_ids,
         do_sample=True,
-        max_new_tokens=32,
-        top_k=50,
-        top_p=0.90,
+        max_new_tokens=int(sampl_new_tokens),
+        top_k=int(sampl_top_k),
+        top_p=sampl_top_p,
+        bad_words_ids=bad_words_ids,
     )
     output = tokenizer.decode(logits[0], skip_special_tokens=True)
 
     # Then, contrastive search.
     input_ids = tokenizer(output, return_tensors="pt").input_ids.to("cuda")
     logits = model.generate(input_ids,
-                            max_new_tokens=96,
-                            penalty_alpha=0.6,
-                            top_k=6)
+                            max_new_tokens=int(cs_new_tokens),
+                            penalty_alpha=cs_alpha,
+                            top_k=int(cs_top_k))
 
     # Then, we trim out the input prompt from the generated output.
     output = tokenizer.decode(logits[0], skip_special_tokens=True)
@@ -178,7 +186,9 @@ def _serialize_chat_history(history: list[str]) -> str:
 
 def _gr_run_inference(model: t.Any, tokenizer: t.Any, context: str,
                       history: list[str], character_name: str,
-                      user_message: str) -> t.Tuple[list[str], str]:
+                      user_message: str,
+                      sampl_new_tokens, sampl_top_k, sampl_top_p,
+                      cs_new_tokens, cs_alpha, cs_top_k, bad_words_str) -> t.Tuple[list[str], str]:
     '''
     With `context` and `history` prompt-engineered into the model's input,
     feed it `user_message` and return everything the Gradio UI expects.
@@ -189,8 +199,10 @@ def _gr_run_inference(model: t.Any, tokenizer: t.Any, context: str,
     prompt = "\n".join(
         [context, "", *history, f"You: {user_message}", f"{character_name}: "])
 
-    output = _run_raw_inference(model, tokenizer, prompt, user_message).strip()
-    logger.debug("_run_raw_inference returned `%s` after .strip()", output)
+    output = _run_raw_inference(model, tokenizer, prompt, user_message,
+                                sampl_new_tokens, sampl_top_k, sampl_top_p,
+                                cs_new_tokens, cs_alpha, cs_top_k, bad_words_str)
+    logger.debug("_run_raw_inference returned `%s`", output)
 
     # If there's enough space, the model will likely generate more than just its
     # own message, so we need to trim that out and just remove the first
@@ -210,11 +222,14 @@ def _gr_run_inference(model: t.Any, tokenizer: t.Any, context: str,
 
 def _gr_regenerate_last_output(model: t.Any, tokenizer: t.Any, context: str,
                                history: list[str], character_name: str,
-                               user_message: str) -> t.Tuple[list[str], str]:
+                               user_message: str,
+                               sampl_new_tokens, sampl_top_k, sampl_top_p,
+                               cs_new_tokens, cs_alpha, cs_top_k, bad_words_str) -> t.Tuple[list[str], str]:
     history_without_last_message = history[:-2]
     return _gr_run_inference(model, tokenizer, context,
                              history_without_last_message, character_name,
-                             user_message)
+                             user_message, sampl_new_tokens, sampl_top_k, sampl_top_p,
+                             cs_new_tokens, cs_alpha, cs_top_k, bad_words_str)
 
 
 def _gr_undo(history: list[str]) -> t.Tuple[list[str], str]:
@@ -248,6 +263,14 @@ def _build_gradio_ui_for(model: t.Any, tokenizer: t.Any) -> t.Any:
            "Insert the context of your character here, such as personality and scenario. Think of this as akin to CAI's short and long description put together.",
            interactive=True,
        )
+
+       sampl_new_tokens = gr.Number(label="tokens (s)", value=16)
+       sampl_top_k = gr.Number(label="top k (s)", value=40)
+       sampl_top_p = gr.Number(label="top p (s)", value=0.9)
+       cs_new_tokens = gr.Number(label="tokens (cs)", value=112)
+       cs_alpha = gr.Number(label="alpha", value=0.6)
+       cs_top_k = gr.Number(label="top k (cs)", value=6)
+       bad_words_str = gr.Textbox(label="(';' separated) bad words", value="....;.....;......;.......;........;.........")
        history_text = gr.Textbox(
            label="Output",
            lines=4,
@@ -257,21 +280,21 @@ def _build_gradio_ui_for(model: t.Any, tokenizer: t.Any) -> t.Any:
 
        with gr.Row():
            submit_btn = gr.Button("Submit input")
-           submit_fn = lambda context, history, character_name, user_message: _gr_run_inference(
-               model, tokenizer, context, history, character_name, user_message
+           submit_fn = lambda context, history, character_name, user_message, sampl_new_tokens, sampl_top_k, sampl_top_p, cs_new_tokens, cs_alpha, cs_top_k, bad_words_str: _gr_run_inference(
+               model, tokenizer, context, history, character_name, user_message, sampl_new_tokens, sampl_top_k, sampl_top_p, cs_new_tokens, cs_alpha, cs_top_k, bad_words_str
            )
            submit_btn.click(
                fn=submit_fn,
-               inputs=[context, history, character_name, user_message],
+               inputs=[context, history, character_name, user_message, sampl_new_tokens, sampl_top_k, sampl_top_p, cs_new_tokens, cs_alpha, cs_top_k, bad_words_str],
                outputs=[history, history_text])
 
            regenerate_btn = gr.Button("Regenerate last output")
-           regenerate_fn = lambda context, history, character_name, user_message: _gr_regenerate_last_output(
-               model, tokenizer, context, history, character_name, user_message
+           regenerate_fn = lambda context, history, character_name, user_message, sampl_new_tokens, sampl_top_k, sampl_top_p, cs_new_tokens, cs_alpha, cs_top_k, bad_words_str: _gr_regenerate_last_output(
+               model, tokenizer, context, history, character_name, user_message, sampl_new_tokens, sampl_top_k, sampl_top_p, cs_new_tokens, cs_alpha, cs_top_k, bad_words_str
            )
            regenerate_btn.click(
                fn=regenerate_fn,
-               inputs=[context, history, character_name, user_message],
+               inputs=[context, history, character_name, user_message, sampl_new_tokens, sampl_top_k, sampl_top_p, cs_new_tokens, cs_alpha, cs_top_k, bad_words_str],
                outputs=[history, history_text])
 
            undo_btn = gr.Button("Undo last exchange")
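
For context, below is a minimal standalone sketch of the two-stage generation path this patch parameterizes: a sampling pass (now gated by `bad_words_ids` built from the `;`-separated string) followed by a contrastive-search pass that continues from the sampled text. The model name `"gpt2"`, the example prompt, and the use of `add_special_tokens=False` when tokenizing bad words (as in the Transformers documentation examples) are assumptions for illustration only and are not part of the patch; the generation values shown are the Gradio defaults it introduces.

```python
# Illustrative sketch only; mirrors the flow of _run_raw_inference after the patch.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained("gpt2")          # placeholder model
model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)

# Token-id sequences to suppress during sampling, built from a ';'-separated string.
# (The docs' examples tokenize bad words with add_special_tokens=False.)
bad_words_str = "....;....."
bad_words_ids = [
    tokenizer(bad_word, add_special_tokens=False).input_ids
    for bad_word in bad_words_str.split(';')
]

prompt = "You: Hello!\nCharacter: "                        # placeholder prompt
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)

# First pass: sampling, using the UI-controlled knobs (defaults from the patch).
sampled = model.generate(
    input_ids,
    do_sample=True,
    max_new_tokens=16,
    top_k=40,
    top_p=0.9,
    bad_words_ids=bad_words_ids,
)
intermediate = tokenizer.decode(sampled[0], skip_special_tokens=True)

# Second pass: contrastive search continues from the sampled text.
input_ids = tokenizer(intermediate, return_tensors="pt").input_ids.to(device)
final = model.generate(
    input_ids,
    max_new_tokens=112,
    penalty_alpha=0.6,
    top_k=6,
)
print(tokenizer.decode(final[0], skip_special_tokens=True))
```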