feat: add options to control text generation on the UI

Author: 11b
Date: 2022-12-25 15:41:10 -03:00
Parent: 5e34b105dc
Commit: 4a1784f8a1

1 changed file with 41 additions and 18 deletions


@@ -102,26 +102,34 @@ def _build_model_and_tokenizer_for(args: argparse.Namespace,
 def _run_raw_inference(model: t.Any, tokenizer: t.Any, prompt: str,
-                       user_message: str) -> str:
+                       user_message: str,
+                       sampl_new_tokens, sampl_top_k, sampl_top_p,
+                       cs_new_tokens, cs_alpha, cs_top_k, bad_words_str) -> str:
     '''Runs raw inference on the model, and returns just the generated text.'''
     # First, sampling-based generation.
+    bad_words_ids = [
+        tokenizer(bad_word, add_special_tokens=True).input_ids
+        for bad_word in bad_words_str.split(';')
+    ]
     input_ids = tokenizer(prompt, return_tensors='pt').input_ids.to("cuda")
     logits = model.generate(
         input_ids,
         do_sample=True,
-        max_new_tokens=32,
-        top_k=50,
-        top_p=0.90,
+        max_new_tokens=int(sampl_new_tokens),
+        top_k=int(sampl_top_k),
+        top_p=sampl_top_p,
+        bad_words_ids=bad_words_ids,
     )
     output = tokenizer.decode(logits[0], skip_special_tokens=True)
 
     # Then, contrastive search.
     input_ids = tokenizer(output, return_tensors="pt").input_ids.to("cuda")
     logits = model.generate(input_ids,
-                            max_new_tokens=96,
-                            penalty_alpha=0.6,
-                            top_k=6)
+                            max_new_tokens=int(cs_new_tokens),
+                            penalty_alpha=cs_alpha,
+                            top_k=int(cs_top_k))
 
     # Then, we trim out the input prompt from the generated output.
     output = tokenizer.decode(logits[0], skip_special_tokens=True)
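
Editor's note: the hunk above parameterizes a two-pass decoding scheme: a short sampling pass, then contrastive search continuing from the sampled text, with banned phrases compiled into bad_words_ids. Below is a minimal standalone sketch of that scheme, assuming the Hugging Face transformers library (>= 4.24, which introduced contrastive search); "gpt2" and the prompt are placeholders, not part of the commit. One deviation is flagged in the comments: the transformers docs suggest tokenizing banned words with add_special_tokens=False, whereas the commit passes True, which can leave special tokens inside the banned sequences for some tokenizers.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# "gpt2" and the prompt below are stand-ins, not from the commit.
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)

# The UI sends banned phrases as one ';'-separated string; generate() wants
# a list of token-id lists. (Docs recommend add_special_tokens=False here;
# the commit itself passes True.)
bad_words_ids = [
    tokenizer(bad_word, add_special_tokens=False).input_ids
    for bad_word in "....;.....".split(';')
]

input_ids = tokenizer("You: Hello!\nBot:", return_tensors="pt").input_ids.to(device)

# Pass 1: top-k/top-p sampling. gr.Number delivers floats to callbacks,
# hence the int() casts mirrored from the hunk.
output_ids = model.generate(
    input_ids,
    do_sample=True,
    max_new_tokens=int(16.0),
    top_k=int(40.0),
    top_p=0.9,
    bad_words_ids=bad_words_ids,
)
output = tokenizer.decode(output_ids[0], skip_special_tokens=True)

# Pass 2: contrastive search (penalty_alpha plus a small top_k), fed the
# full text produced by the sampling pass.
output_ids = model.generate(
    tokenizer(output, return_tensors="pt").input_ids.to(device),
    max_new_tokens=112,
    penalty_alpha=0.6,
    top_k=6,
)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))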
@@ -178,7 +186,9 @@ def _serialize_chat_history(history: list[str]) -> str:
 def _gr_run_inference(model: t.Any, tokenizer: t.Any, context: str,
                       history: list[str], character_name: str,
-                      user_message: str) -> t.Tuple[list[str], str]:
+                      user_message: str,
+                      sampl_new_tokens, sampl_top_k, sampl_top_p,
+                      cs_new_tokens, cs_alpha, cs_top_k, bad_words_str) -> t.Tuple[list[str], str]:
     '''
     With `context` and `history` prompt-engineered into the model's input, feed
     it `user_message` and return everything the Gradio UI expects.
@@ -189,8 +199,10 @@ def _gr_run_inference(model: t.Any, tokenizer: t.Any, context: str,
     prompt = "\n".join(
         [context, "", *history, f"You: {user_message}", f"{character_name}: "])
 
-    output = _run_raw_inference(model, tokenizer, prompt, user_message).strip()
-    logger.debug("_run_raw_inference returned `%s` after .strip()", output)
+    output = _run_raw_inference(model, tokenizer, prompt, user_message,
+                                sampl_new_tokens, sampl_top_k, sampl_top_p,
+                                cs_new_tokens, cs_alpha, cs_top_k, bad_words_str)
+    logger.debug("_run_raw_inference returned `%s`", output)
 
     # If there's enough space, the model will likely generate more than just its
     # own message, so we need to trim that out and just remove the first
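
Editor's note: to make the prompt shape concrete, here is a toy illustration of what the "\n".join(...) above produces; all values are invented. The trailing f"{character_name}: " deliberately leaves the prompt hanging on the bot's turn, so the model completes that line.

context = "Asuna is a cheerful swordswoman."
history = ["You: Hi!", "Asuna: Hello there!"]
character_name = "Asuna"
user_message = "How are you?"

prompt = "\n".join(
    [context, "", *history, f"You: {user_message}", f"{character_name}: "])
print(prompt)
# Asuna is a cheerful swordswoman.
#
# You: Hi!
# Asuna: Hello there!
# You: How are you?
# Asuna: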
@@ -210,11 +222,14 @@ def _gr_run_inference(model: t.Any, tokenizer: t.Any, context: str,
 def _gr_regenerate_last_output(model: t.Any, tokenizer: t.Any, context: str,
                                history: list[str], character_name: str,
-                               user_message: str) -> t.Tuple[list[str], str]:
+                               user_message: str,
+                               sampl_new_tokens, sampl_top_k, sampl_top_p,
+                               cs_new_tokens, cs_alpha, cs_top_k, bad_words_str) -> t.Tuple[list[str], str]:
     history_without_last_message = history[:-2]
     return _gr_run_inference(model, tokenizer, context,
                              history_without_last_message, character_name,
-                             user_message)
+                             user_message, sampl_new_tokens, sampl_top_k, sampl_top_p,
+                             cs_new_tokens, cs_alpha, cs_top_k, bad_words_str)
 
 
 def _gr_undo(history: list[str]) -> t.Tuple[list[str], str]:
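
Editor's note: regeneration works because each exchange appends exactly two entries to history (the user's line and the character's reply), so slicing off the last two and re-running with the same user_message re-rolls the bot's latest message. A toy illustration with invented values:

history = ["You: Hi!", "Asuna: Hello there!",
           "You: How are you?", "Asuna: Great!"]
# Drop the last exchange; _gr_run_inference re-appends "You: How are you?"
# and generates a fresh reply for it.
history_without_last_message = history[:-2]
print(history_without_last_message)  # ['You: Hi!', 'Asuna: Hello there!']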
@@ -248,6 +263,14 @@ def _build_gradio_ui_for(model: t.Any, tokenizer: t.Any) -> t.Any:
             "Insert the context of your character here, such as personality and scenario. Think of this as akin to CAI's short and long description put together.",
             interactive=True,
         )
+        sampl_new_tokens = gr.Number(label="tokens (s)", value=16)
+        sampl_top_k = gr.Number(label="top k (s)", value=40)
+        sampl_top_p = gr.Number(label="top p (s)", value=0.9)
+        cs_new_tokens = gr.Number(label="tokens (cs)", value=112)
+        cs_alpha = gr.Number(label="alpha", value=0.6)
+        cs_top_k = gr.Number(label="top k (cs)", value=6)
+        bad_words_str = gr.Textbox(label="(';' separated) bad words", value="....;.....;......;.......;........;.........")
+
         history_text = gr.Textbox(
             label="Output",
             lines=4,
@@ -257,21 +280,21 @@ def _build_gradio_ui_for(model: t.Any, tokenizer: t.Any) -> t.Any:
         with gr.Row():
             submit_btn = gr.Button("Submit input")
-            submit_fn = lambda context, history, character_name, user_message: _gr_run_inference(
-                model, tokenizer, context, history, character_name, user_message
+            submit_fn = lambda context, history, character_name, user_message, sampl_new_tokens, sampl_top_k, sampl_top_p, cs_new_tokens, cs_alpha, cs_top_k, bad_words_str: _gr_run_inference(
+                model, tokenizer, context, history, character_name, user_message, sampl_new_tokens, sampl_top_k, sampl_top_p, cs_new_tokens, cs_alpha, cs_top_k, bad_words_str
             )
             submit_btn.click(
                 fn=submit_fn,
-                inputs=[context, history, character_name, user_message],
+                inputs=[context, history, character_name, user_message, sampl_new_tokens, sampl_top_k, sampl_top_p, cs_new_tokens, cs_alpha, cs_top_k, bad_words_str],
                 outputs=[history, history_text])
 
             regenerate_btn = gr.Button("Regenerate last output")
-            regenerate_fn = lambda context, history, character_name, user_message: _gr_regenerate_last_output(
-                model, tokenizer, context, history, character_name, user_message
+            regenerate_fn = lambda context, history, character_name, user_message, sampl_new_tokens, sampl_top_k, sampl_top_p, cs_new_tokens, cs_alpha, cs_top_k, bad_words_str: _gr_regenerate_last_output(
+                model, tokenizer, context, history, character_name, user_message, sampl_new_tokens, sampl_top_k, sampl_top_p, cs_new_tokens, cs_alpha, cs_top_k, bad_words_str
             )
             regenerate_btn.click(
                 fn=regenerate_fn,
-                inputs=[context, history, character_name, user_message],
+                inputs=[context, history, character_name, user_message, sampl_new_tokens, sampl_top_k, sampl_top_p, cs_new_tokens, cs_alpha, cs_top_k, bad_words_str],
                 outputs=[history, history_text])
 
             undo_btn = gr.Button("Undo last exchange")
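
Editor's note: the wiring in the last two hunks follows a standard Gradio Blocks pattern: every component whose value the handler needs is listed in inputs, in the same order as the handler's parameters, and the lambda exists only to close over non-component objects like model and tokenizer. A cut-down sketch of the same pattern, with an invented fake_infer standing in for the real inference call:

import gradio as gr

def fake_infer(message, new_tokens, top_k, top_p):
    # gr.Number delivers floats; cast where an int is required.
    return f"{message} (tokens={int(new_tokens)}, top_k={int(top_k)}, top_p={top_p})"

with gr.Blocks() as demo:
    message = gr.Textbox(label="Message")
    new_tokens = gr.Number(label="tokens (s)", value=16)
    top_k = gr.Number(label="top k (s)", value=40)
    top_p = gr.Number(label="top p (s)", value=0.9)
    output = gr.Textbox(label="Output")

    submit_btn = gr.Button("Submit input")
    # inputs must match fake_infer's parameters one-to-one, in order.
    submit_btn.click(fn=fake_infer,
                     inputs=[message, new_tokens, top_k, top_p],
                     outputs=[output])

demo.launch()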