feat: add options to control text generation on the UI
parent 5e34b105dc
commit 4a1784f8a1
@@ -102,26 +102,34 @@ def _build_model_and_tokenizer_for(args: argparse.Namespace,
 
 
 def _run_raw_inference(model: t.Any, tokenizer: t.Any, prompt: str,
-                       user_message: str) -> str:
+                       user_message: str,
+                       sampl_new_tokens, sampl_top_k, sampl_top_p,
+                       cs_new_tokens, cs_alpha, cs_top_k, bad_words_str) -> str:
     '''Runs raw inference on the model, and returns just the generated text.'''
 
     # First, sampling-based generation.
+    bad_words_ids = [
+        tokenizer(bad_word, add_special_tokens=True).input_ids
+        for bad_word in bad_words_str.split(';')
+    ]
+
     input_ids = tokenizer(prompt, return_tensors='pt').input_ids.to("cuda")
     logits = model.generate(
         input_ids,
         do_sample=True,
-        max_new_tokens=32,
-        top_k=50,
-        top_p=0.90,
+        max_new_tokens=int(sampl_new_tokens),
+        top_k=int(sampl_top_k),
+        top_p=sampl_top_p,
+        bad_words_ids=bad_words_ids,
     )
     output = tokenizer.decode(logits[0], skip_special_tokens=True)
 
     # Then, contrastive search.
     input_ids = tokenizer(output, return_tensors="pt").input_ids.to("cuda")
     logits = model.generate(input_ids,
-                            max_new_tokens=96,
-                            penalty_alpha=0.6,
-                            top_k=6)
+                            max_new_tokens=int(cs_new_tokens),
+                            penalty_alpha=cs_alpha,
+                            top_k=int(cs_top_k))
 
     # Then, we trim out the input prompt from the generated output.
     output = tokenizer.decode(logits[0], skip_special_tokens=True)
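
For context, the function this hunk touches generates in two stages: a short sampling pass (now with a bad-word filter), then a longer contrastive-search pass seeded with the sampled text. Below is a minimal standalone sketch of that pipeline, assuming a stock `transformers` causal LM; "gpt2" is only a placeholder checkpoint, and the parameter values mirror the UI defaults added later in this commit. Note that the Hugging Face docs build `bad_words_ids` with `add_special_tokens=False`; the commit passes `True`, which is harmless for GPT-2-style tokenizers that add no special tokens but can pollute the filter on tokenizers that do.

    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder checkpoint
    model = AutoModelForCausalLM.from_pretrained("gpt2").to("cuda")

    prompt = "You: Hello!\nBot:"
    # Token-id sequences to suppress during generation.
    bad_words_ids = [
        tokenizer(bad_word, add_special_tokens=False).input_ids
        for bad_word in "....;.....".split(";")
    ]

    # Stage 1: top-k/top-p sampling with the bad-word filter applied.
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to("cuda")
    sampled = model.generate(input_ids, do_sample=True, max_new_tokens=16,
                             top_k=40, top_p=0.9, bad_words_ids=bad_words_ids)
    text = tokenizer.decode(sampled[0], skip_special_tokens=True)

    # Stage 2: contrastive search (penalty_alpha plus a small top_k)
    # continues from the sampled text, re-tokenized as a fresh prompt.
    input_ids = tokenizer(text, return_tensors="pt").input_ids.to("cuda")
    final = model.generate(input_ids, max_new_tokens=112,
                           penalty_alpha=0.6, top_k=6)
    print(tokenizer.decode(final[0], skip_special_tokens=True))
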
@@ -178,7 +186,9 @@ def _serialize_chat_history(history: list[str]) -> str:
 
 def _gr_run_inference(model: t.Any, tokenizer: t.Any, context: str,
                       history: list[str], character_name: str,
-                      user_message: str) -> t.Tuple[list[str], str]:
+                      user_message: str,
+                      sampl_new_tokens, sampl_top_k, sampl_top_p,
+                      cs_new_tokens, cs_alpha, cs_top_k, bad_words_str) -> t.Tuple[list[str], str]:
     '''
     With `context` and `history` prompt-engineered into the model's input, feed
     it `user_message` and return everything the Gradio UI expects.
@@ -189,8 +199,10 @@ def _gr_run_inference(model: t.Any, tokenizer: t.Any, context: str,
     prompt = "\n".join(
         [context, "", *history, f"You: {user_message}", f"{character_name}: "])
 
-    output = _run_raw_inference(model, tokenizer, prompt, user_message).strip()
-    logger.debug("_run_raw_inference returned `%s` after .strip()", output)
+    output = _run_raw_inference(model, tokenizer, prompt, user_message,
+                                sampl_new_tokens, sampl_top_k, sampl_top_p,
+                                cs_new_tokens, cs_alpha, cs_top_k, bad_words_str)
+    logger.debug("_run_raw_inference returned `%s`", output)
 
     # If there's enough space, the model will likely generate more than just its
     # own message, so we need to trim that out and just remove the first
@@ -210,11 +222,14 @@ def _gr_run_inference(model: t.Any, tokenizer: t.Any, context: str,
 
 def _gr_regenerate_last_output(model: t.Any, tokenizer: t.Any, context: str,
                                history: list[str], character_name: str,
-                               user_message: str) -> t.Tuple[list[str], str]:
+                               user_message: str,
+                               sampl_new_tokens, sampl_top_k, sampl_top_p,
+                               cs_new_tokens, cs_alpha, cs_top_k, bad_words_str) -> t.Tuple[list[str], str]:
     history_without_last_message = history[:-2]
     return _gr_run_inference(model, tokenizer, context,
                              history_without_last_message, character_name,
-                             user_message)
+                             user_message, sampl_new_tokens, sampl_top_k, sampl_top_p,
+                             cs_new_tokens, cs_alpha, cs_top_k, bad_words_str)
 
 
 def _gr_undo(history: list[str]) -> t.Tuple[list[str], str]:
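
Since `history` stores alternating user and bot lines, the `history[:-2]` slice drops exactly the last exchange before `_gr_run_inference` regenerates it. A tiny illustration with made-up history contents (the character name "Mia" is hypothetical):

    # Alternating user/bot lines; slicing off the last two entries removes
    # exactly one exchange, which _gr_run_inference then regenerates.
    history = ["You: Hi", "Mia: Hello!", "You: How are you?", "Mia: Fine."]
    history_without_last_message = history[:-2]
    # -> ["You: Hi", "Mia: Hello!"]
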
@@ -248,6 +263,14 @@ def _build_gradio_ui_for(model: t.Any, tokenizer: t.Any) -> t.Any:
             "Insert the context of your character here, such as personality and scenario. Think of this as akin to CAI's short and long description put together.",
             interactive=True,
         )
 
+        sampl_new_tokens = gr.Number(label="tokens (s)", value=16)
+        sampl_top_k = gr.Number(label="top k (s)", value=40)
+        sampl_top_p = gr.Number(label="top p (s)", value=0.9)
+        cs_new_tokens = gr.Number(label="tokens (cs)", value=112)
+        cs_alpha = gr.Number(label="alpha", value=0.6)
+        cs_top_k = gr.Number(label="top k (cs)", value=6)
+        bad_words_str = gr.Textbox(label="(';' separated) bad words", value="....;.....;......;.......;........;.........")
+
         history_text = gr.Textbox(
             label="Output",
             lines=4,
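
A note on the casts seen earlier: `gr.Number` hands its value to callbacks as a float (at least in the Gradio 3.x API this code appears to target), which is why `_run_raw_inference` wraps the token counts and top-k values in `int(...)`. A hypothetical standalone demo of the same pattern:

    import gradio as gr

    # Numeric fields arrive as floats, so integer-valued settings such as
    # token counts must be cast before being passed to generate().
    def show(n_tokens, top_p):
        return f"max_new_tokens={int(n_tokens)}, top_p={top_p}"

    with gr.Blocks() as demo:
        n_tokens = gr.Number(label="tokens (s)", value=16)
        top_p = gr.Number(label="top p (s)", value=0.9)
        parsed = gr.Textbox(label="Parsed values")
        gr.Button("Show").click(fn=show, inputs=[n_tokens, top_p], outputs=parsed)

    demo.launch()
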
@@ -257,21 +280,21 @@
 
         with gr.Row():
             submit_btn = gr.Button("Submit input")
-            submit_fn = lambda context, history, character_name, user_message: _gr_run_inference(
-                model, tokenizer, context, history, character_name, user_message
+            submit_fn = lambda context, history, character_name, user_message, sampl_new_tokens, sampl_top_k, sampl_top_p, cs_new_tokens, cs_alpha, cs_top_k, bad_words_str: _gr_run_inference(
+                model, tokenizer, context, history, character_name, user_message, sampl_new_tokens, sampl_top_k, sampl_top_p, cs_new_tokens, cs_alpha, cs_top_k, bad_words_str
             )
             submit_btn.click(
                 fn=submit_fn,
-                inputs=[context, history, character_name, user_message],
+                inputs=[context, history, character_name, user_message, sampl_new_tokens, sampl_top_k, sampl_top_p, cs_new_tokens, cs_alpha, cs_top_k, bad_words_str],
                 outputs=[history, history_text])
 
             regenerate_btn = gr.Button("Regenerate last output")
-            regenerate_fn = lambda context, history, character_name, user_message: _gr_regenerate_last_output(
-                model, tokenizer, context, history, character_name, user_message
+            regenerate_fn = lambda context, history, character_name, user_message, sampl_new_tokens, sampl_top_k, sampl_top_p, cs_new_tokens, cs_alpha, cs_top_k, bad_words_str: _gr_regenerate_last_output(
+                model, tokenizer, context, history, character_name, user_message, sampl_new_tokens, sampl_top_k, sampl_top_p, cs_new_tokens, cs_alpha, cs_top_k, bad_words_str
             )
             regenerate_btn.click(
                 fn=regenerate_fn,
-                inputs=[context, history, character_name, user_message],
+                inputs=[context, history, character_name, user_message, sampl_new_tokens, sampl_top_k, sampl_top_p, cs_new_tokens, cs_alpha, cs_top_k, bad_words_str],
                 outputs=[history, history_text])
 
             undo_btn = gr.Button("Undo last exchange")
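
Gradio binds the `inputs` list to the callback positionally: the k-th component becomes the k-th argument, so the lambdas above must list the new controls in exactly the same order as the `inputs` lists. With eleven parameters, a named wrapper might read better than a lambda; a sketch assuming the same enclosing `_build_gradio_ui_for` scope and the names from this diff:

    # Same behavior as the submit_fn lambda, written as a named function.
    # Relies on `model` and `tokenizer` from the enclosing function scope.
    def submit_fn(context, history, character_name, user_message,
                  sampl_new_tokens, sampl_top_k, sampl_top_p,
                  cs_new_tokens, cs_alpha, cs_top_k, bad_words_str):
        return _gr_run_inference(model, tokenizer, context, history,
                                 character_name, user_message,
                                 sampl_new_tokens, sampl_top_k, sampl_top_p,
                                 cs_new_tokens, cs_alpha, cs_top_k, bad_words_str)
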