#!/usr/bin/env python3
import argparse
import logging
import typing as t
import re
import torch
import transformers
import gradio as gr
# Module-level logger. DEBUG level so every diagnostic message from the
# inference helpers below is visible while experimenting.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.DEBUG)

# TODO(11b): Type these functions up properly.
def main() -> None:
    '''Script entrypoint: parse CLI args, load the model, serve the UI.'''
    args = _parse_args_from_argv()

    # TODO(11b): We don't have the bot name at this point, since it's dynamic
    # on the UI, so we can't build `bad_word_ids` as perfectly as I'd like. See
    # if we can improve this later.
    model, tokenizer = _build_model_and_tokenizer_for(args, bot_name="")

    interface = _build_gradio_ui_for(model, tokenizer)
    interface.launch(server_port=3000, share=False)
def _parse_args_from_argv ( ) - > argparse . Namespace :
''' Parses arguments coming in from the command line. '''
2022-12-18 21:26:32 +01:00
parser = argparse . ArgumentParser ( )
2022-12-24 16:12:55 +01:00
parser . add_argument (
" -m " ,
" --model-name " ,
default = " facebook/opt-350m " ,
help = " HuggingFace Transformers model name. " ,
)
2022-12-18 21:26:32 +01:00
parser . add_argument (
" -c " ,
" --checkpoint " ,
2022-12-24 16:12:55 +01:00
help = " Fine-tune checkpoint to load into the base model. Optional. " ,
)
return parser . parse_args ( )
def _build_blacklist_for ( bot_name : str ) - > list [ str ] :
'''
Builds a blacklist for the given bot name .
This is used to stop the model from invoking modules when we haven ' t
prompted it to .
'''
2022-12-18 21:26:32 +01:00
2022-12-24 16:12:55 +01:00
# NOTE(11b): This should _ideally_ be shared with the actual implementations
# inside the package's .core.consts, but for simplicity's sake I'm
# re-implementing here (so there's no need to install the package just to
# run inference).
pdm_prefix = f " { bot_name } ' s Persona: "
2022-12-18 21:26:32 +01:00
2022-12-24 16:12:55 +01:00
# Not sure why, but the pre-trained OPT likes to generate these and it leaks
# out to the fine-tuned models as well.
bad_opt_generations = [ " ___ " , " ____ " , " _____ " ]
2022-12-25 19:37:34 +01:00
# And Pythia likes to do this.
bad_pythia_generations = [ " .... " ]
return [ pdm_prefix , * bad_opt_generations , * bad_pythia_generations ]
def _build_model_and_tokenizer_for(args: argparse.Namespace,
                                   bot_name: str) -> t.Tuple[t.Any, t.Any]:
    '''Sets up the model and accompanying tokenizer.'''
    logger.info(f"Loading tokenizer for {args.model_name}")
    tokenizer = transformers.AutoTokenizer.from_pretrained(args.model_name)

    # Optionally overlay a fine-tune checkpoint on top of the base weights.
    state_dict = None
    if args.checkpoint is not None:
        logger.info(f"Loading checkpoint from {args.checkpoint}")
        # NOTE(11b): `.pop("model")` is specific to checkpoints saved by
        # the ColossalAI helper. If using a regular HF Transformers checkpoint,
        # comment that out.
        checkpoint_data = torch.load(args.checkpoint, map_location="cuda")
        state_dict = checkpoint_data.pop("model")

    # NOTE(review): `add_prefix_space` is presumably set so blacklist entries
    # tokenize the same way they appear mid-sentence for OPT — confirm.
    tokenizer_kwargs = {"add_special_tokens": False}
    if "facebook/opt-" in args.model_name:
        tokenizer_kwargs["add_prefix_space"] = True

    # Token sequences the generator should never emit.
    bad_words_ids = [
        tokenizer(phrase, **tokenizer_kwargs).input_ids
        for phrase in _build_blacklist_for(bot_name)
    ]

    logger.info(f"Loading the {args.model_name} model")
    model = transformers.AutoModelForCausalLM.from_pretrained(
        args.model_name, state_dict=state_dict, bad_words_ids=bad_words_ids)
    model.eval().half().to("cuda")

    logger.info("Model and tokenizer are ready")
    return model, tokenizer
def _run_raw_inference ( model : t . Any , tokenizer : t . Any , prompt : str ,
user_message : str ) - > str :
''' Runs raw inference on the model, and returns just the generated text. '''
# First, sampling-based generation.
input_ids = tokenizer ( prompt , return_tensors = ' pt ' ) . input_ids . to ( " cuda " )
logits = model . generate (
input_ids ,
do_sample = True ,
2022-12-25 19:37:34 +01:00
max_new_tokens = 32 ,
2022-12-24 16:12:55 +01:00
top_k = 50 ,
top_p = 0.90 ,
)
output = tokenizer . decode ( logits [ 0 ] , skip_special_tokens = True )
# Then, contrastive search.
input_ids = tokenizer ( output , return_tensors = " pt " ) . input_ids . to ( " cuda " )
logits = model . generate ( input_ids ,
2022-12-25 19:37:34 +01:00
max_new_tokens = 96 ,
2022-12-24 16:12:55 +01:00
penalty_alpha = 0.6 ,
top_k = 6 )
2022-12-25 19:37:34 +01:00
# Then, we trim out the input prompt from the generated output.
2022-12-24 16:12:55 +01:00
output = tokenizer . decode ( logits [ 0 ] , skip_special_tokens = True )
if ( idx := prompt . rfind ( user_message ) ) != - 1 :
2022-12-25 19:37:34 +01:00
trimmed_output = output [ idx + len ( user_message ) : ] . strip ( )
2022-12-24 16:12:55 +01:00
return trimmed_output
else :
raise ValueError ( " Couldn ' t find user message in the prompt. What? " )
BAD_CHARS_FOR_REGEX_REGEX = re . compile ( r " [- \ / \\ ^$*+?.()|[ \ ] {} ] " )
def _sanitize_string_for_use_in_a_regex ( string : str ) - > str :
''' Sanitizes `string` so it can be used inside of a regexp. '''
return BAD_CHARS_FOR_REGEX_REGEX . sub ( r " \\ \ g<0> " , string )
def _parse_messages_from_str ( string : str , names : list [ str ] ) - > list [ str ] :
'''
Given a big string containing raw chat history , this function attempts to
parse it out into a list where each item is an individual message .
'''
sanitized_names = [
_sanitize_string_for_use_in_a_regex ( name ) for name in names
]
speaker_regex = re . compile ( rf " ^( { ' | ' . join ( sanitized_names ) } ): " ,
re . MULTILINE )
message_start_indexes = [ ]
for match in speaker_regex . finditer ( string ) :
message_start_indexes . append ( match . start ( ) )
2022-12-25 19:37:34 +01:00
if len ( message_start_indexes ) < 2 :
# Single message in the string.
return [ string . strip ( ) ]
2022-12-24 16:12:55 +01:00
prev_start_idx = message_start_indexes [ 0 ]
messages = [ ]
for start_idx in message_start_indexes [ 1 : ] :
message = string [ prev_start_idx : start_idx ] . strip ( )
messages . append ( message )
prev_start_idx = start_idx
return messages
def _serialize_chat_history ( history : list [ str ] ) - > str :
''' Given a structured chat history object, collapses it down to a string. '''
return " \n " . join ( history )
def _gr_run_inference(model: t.Any, tokenizer: t.Any, context: str,
                      history: list[str], character_name: str,
                      user_message: str) -> t.Tuple[list[str], str]:
    '''
    Prompt-engineers `context` and `history` into the model's input, feeds it
    `user_message`, and returns everything the Gradio UI expects: the updated
    history list plus its serialized transcript.
    '''
    # TODO(11b): Lots of assumptions to fix here. We need to make sure
    # everything fits, we need to use "You" from the `.core.consts` module, etc.
    prompt_lines = [
        context,
        "",
        *history,
        f"You: {user_message}",
        f"{character_name}:",
    ]
    prompt = "\n".join(prompt_lines)

    output = _run_raw_inference(model, tokenizer, prompt, user_message).strip()
    logger.debug("_run_raw_inference returned `%s` after .strip()", output)

    # If there's enough space, the model will likely generate more than just
    # its own message, so keep only the first message it produced.
    generated_messages = _parse_messages_from_str(output,
                                                  ["You", character_name])
    logger.debug("Generated messages is `%s`", generated_messages)
    bot_message = generated_messages[0]
    logger.info("Generated message: `%s`", bot_message)

    history.append(f"You: {user_message}")
    history.append(bot_message)
    return history, _serialize_chat_history(history)
def _gr_regenerate_last_output(model: t.Any, tokenizer: t.Any, context: str,
                               history: list[str], character_name: str,
                               user_message: str) -> t.Tuple[list[str], str]:
    '''Re-rolls the last bot reply: rewinds one exchange, then re-infers.'''
    rewound_history = history[:-2]
    return _gr_run_inference(model, tokenizer, context, rewound_history,
                             character_name, user_message)
def _gr_undo ( history : list [ str ] ) - > t . Tuple [ list [ str ] , str ] :
updated_history = history [ : - 2 ]
return updated_history , _serialize_chat_history ( updated_history )
def _build_gradio_ui_for(model: t.Any, tokenizer: t.Any) -> t.Any:
    '''
    Builds a Gradio UI to interact with the model. Big thanks to TearGosling for
    the initial version of this.
    '''
    with gr.Blocks() as interface:
        # Structured chat history (a list of "Speaker: message" strings),
        # shared across callbacks via Gradio session state.
        history = gr.State([])

        with gr.Row():
            with gr.Column():
                # Free-form text the user sends to the character.
                user_message = gr.Textbox(
                    label="Input",
                    placeholder="Say something here",
                    interactive=True,
                )
                character_name = gr.Textbox(
                    label="Name of character",
                    placeholder="Insert the name of your character here",
                )
                context = gr.Textbox(
                    label="Long context",
                    lines=4,
                    placeholder=
                    "Insert the context of your character here, such as personality and scenario. Think of this as akin to CAI's short and long description put together.",
                    interactive=True,
                )
                # Read-only view of the serialized conversation.
                history_text = gr.Textbox(
                    label="Output",
                    lines=4,
                    placeholder="Your conversation will show up here!",
                    interactive=False,
                )

        with gr.Row():
            submit_btn = gr.Button("Submit input")
            # The lambdas close over `model`/`tokenizer` so the click handlers
            # only need the UI-provided values as inputs.
            submit_fn = lambda context, history, character_name, user_message: _gr_run_inference(
                model, tokenizer, context, history, character_name, user_message
            )
            submit_btn.click(
                fn=submit_fn,
                inputs=[context, history, character_name, user_message],
                outputs=[history, history_text])

            regenerate_btn = gr.Button("Regenerate last output")
            regenerate_fn = lambda context, history, character_name, user_message: _gr_regenerate_last_output(
                model, tokenizer, context, history, character_name, user_message
            )
            regenerate_btn.click(
                fn=regenerate_fn,
                inputs=[context, history, character_name, user_message],
                outputs=[history, history_text])

            undo_btn = gr.Button("Undo last exchange")
            undo_btn.click(fn=_gr_undo,
                           inputs=[history],
                           outputs=[history, history_text])

    return interface
# Standard script entry-point guard.
if __name__ == "__main__":
    main()