2024-03-07 06:58:56 +00:00
import argparse
from colossal_llama2 . utils . stream_chat_patch import streaming_chat
2024-06-14 05:02:37 +00:00
from transformers import AutoModelForCausalLM , AutoTokenizer
2024-03-07 06:58:56 +00:00
# System prompt used as the opening message of the chat history in main().
SYSTEM = (
    "A chat between a curious human and an artificial intelligence assistant. "
    "The assistant gives helpful, detailed, and polite answers to the human's questions."
)
def main(args):
    """Run an interactive, streaming chat session in the terminal.

    Loads the model (moved to CUDA and put in eval mode) and the tokenizer,
    then repeatedly reads a user turn from stdin and streams the assistant's
    reply to stdout via ``streaming_chat``.

    Special inputs at the prompt:

    * ``exit`` ends the session. EOF (Ctrl-D) and Ctrl-C at the prompt end
      it as well.
    * ``clear`` drops the conversation and the cached key/values. The system
      prompt is kept.

    Args:
        args: Parsed CLI namespace. Reads ``model_path``, ``tokenizer_path``,
            ``temperature``, ``top_p``, ``top_k``, ``do_sample``,
            ``length_penalty`` and ``max_new_tokens``.
    """
    model = AutoModelForCausalLM.from_pretrained(args.model_path).cuda().eval()
    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path)

    # Role labels passed to streaming_chat.
    # Index 0 is the (unnamed) system role, 1 the user and 2 the model.
    roles = ["", "Human", "Assistant"]
    history = _initial_history(roles)
    past_key_values = None

    while True:
        try:
            input_query = input(f"\n{roles[1]}: ")
        except (EOFError, KeyboardInterrupt):
            # End the session quietly instead of crashing with a traceback.
            print()
            break

        command = input_query.strip()
        if command == "exit":
            break
        if command == "clear":
            # Start over but keep the system prompt. Resetting to an empty
            # list would silently drop it for the rest of the session.
            past_key_values, history = None, _initial_history(roles)
            continue

        print(f"\n{roles[2]}: ", end="")
        gen_len = 0
        # The slicing below assumes each yielded ``response`` is the
        # cumulative text so far. Only the newly generated suffix is printed.
        for response, history, past_key_values in streaming_chat(
            model,
            tokenizer,
            input_query,
            history=history,
            roles=roles,
            temperature=args.temperature,
            top_p=args.top_p,
            top_k=args.top_k,
            do_sample=args.do_sample,
            length_penalty=args.length_penalty,
            max_new_tokens=args.max_new_tokens,
            past_key_values=past_key_values,
            return_past_key_values=True,
        ):
            print(response[gen_len:], end="", flush=True)
            gen_len = len(response)


def _initial_history(roles):
    """Return a new chat history containing only the system prompt."""
    return [{"role": roles[0], "message": SYSTEM}]
def str2bool(value):
    """Parse a command-line boolean such as ``true``/``false``/``yes``/``no``/``1``/``0``.

    ``argparse`` with ``type=bool`` treats every non-empty string, including
    ``"False"``, as ``True``. This function parses the text explicitly instead.

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ("true", "t", "yes", "y", "1"):
        return True
    if normalized in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_cli_args(argv=None):
    """Parse command-line options for the streaming chat example.

    Args:
        argv: Argument list to parse. ``None`` means ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with the generation settings. When
        ``--tokenizer_path`` is omitted it is set to ``--model_path``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", type=str, required=True, help="path to chat version model")
    parser.add_argument(
        "--tokenizer_path",
        type=str,
        default=None,
        help="path to chat version tokenizer (defaults to --model_path)",
    )
    parser.add_argument("--temperature", type=float, default=0.8, help="set temperature")
    parser.add_argument("--top_p", type=float, default=0.95, help="set top p value")
    parser.add_argument("--top_k", type=int, default=50, help="set top k value")
    # type=bool would turn "--do_sample False" into True; parse the text explicitly.
    parser.add_argument("--do_sample", type=str2bool, default=True, help="whether turn on do_sample or not")
    parser.add_argument("--length_penalty", type=float, default=1.2, help="set length penalty")
    parser.add_argument("--max_new_tokens", type=int, default=512, help="set max new tokens")
    args = parser.parse_args(argv)
    if args.tokenizer_path is None:
        args.tokenizer_path = args.model_path
    return args


if __name__ == "__main__":
    main(parse_cli_args())