You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
ColossalAI/applications/Colossal-LLaMA-2/stream_chat_example.py

55 lines
2.4 KiB

import os
import argparse
from transformers import AutoTokenizer, AutoModelForCausalLM
from colossal_llama2.utils.stream_chat_patch import streaming_chat
# System prompt placed as the first history entry of every conversation.
SYSTEM = (
    "A chat between a curious human and an artificial intelligence assistant. "
    "The assistant gives helpful, detailed, and polite answers to the human's questions."
)
def main(args):
    """Run an interactive, streaming command-line chat with a causal LM.

    Loads the model (moved to CUDA, eval mode) and its tokenizer, then loops:
    read a query from stdin, stream the reply produced by ``streaming_chat``
    to stdout, and carry the returned history / KV cache into the next turn.

    Special inputs:
        ``exit``  -- leave the loop (end of input, e.g. Ctrl-D, does the same).
        ``clear`` -- forget the conversation and the cached key/values; the
                     next turn starts again from the system prompt only.

    Args:
        args: namespace providing ``model_path``, ``tokenizer_path`` (falls
            back to ``model_path`` when None/empty) and the generation
            settings ``temperature``, ``top_p``, ``top_k``, ``do_sample``,
            ``length_penalty`` and ``max_new_tokens``.
    """
    model = AutoModelForCausalLM.from_pretrained(args.model_path).cuda().eval()
    # Use the model path for the tokenizer when no separate path is given,
    # instead of passing None to from_pretrained.
    tokenizer_path = args.tokenizer_path or args.model_path
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)

    # roles[0] labels the system prompt, roles[1] the user prompt shown on
    # input, roles[2] the label printed before each model reply.
    roles = ["", "Human", "Assistant"]

    def fresh_history():
        # Every conversation -- including one restarted with "clear" -- must
        # begin with the system prompt.
        return [{"role": roles[0], "message": SYSTEM}]

    past_key_values, history = None, fresh_history()

    while True:
        try:
            input_query = input(f"\n{roles[1]}: ")
        except EOFError:
            # stdin closed (Ctrl-D / piped input exhausted): behave like "exit".
            break
        command = input_query.strip()
        if command == "exit":
            break
        if command == "clear":
            # Drop both the textual history and the KV cache, but keep the
            # system prompt so later turns are conditioned like the first one.
            past_key_values, history = None, fresh_history()
            continue

        print(f"\n{roles[2]}: ", end="")
        # Each yielded `response` is treated as the cumulative text so far,
        # so only the suffix not yet printed is written out.
        gen_len = 0
        for response, history, past_key_values in streaming_chat(
            model,
            tokenizer,
            input_query,
            history=history,
            roles=roles,
            temperature=args.temperature,
            top_p=args.top_p,
            top_k=args.top_k,
            do_sample=args.do_sample,
            length_penalty=args.length_penalty,
            max_new_tokens=args.max_new_tokens,
            past_key_values=past_key_values,
            return_past_key_values=True,
        ):
            print(response[gen_len:], end="", flush=True)
            gen_len = len(response)
def str2bool(value):
    """Convert a command-line string into a bool.

    ``argparse`` with ``type=bool`` calls ``bool(str)``, which is True for
    every non-empty string (including ``"False"``), so boolean options need
    an explicit converter.

    Args:
        value: the raw option string (a bool is passed through unchanged).

    Returns:
        True for "true"/"t"/"yes"/"y"/"1", False for "false"/"f"/"no"/"n"/"0"
        (case-insensitive, surrounding whitespace ignored).

    Raises:
        argparse.ArgumentTypeError: for any other string, so argparse reports
            a clean usage error instead of silently picking a value.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ("true", "t", "yes", "y", "1"):
        return True
    if normalized in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def build_parser():
    """Build the command-line parser for the streaming chat example.

    Returns:
        An ``argparse.ArgumentParser`` accepting the model/tokenizer paths
        and the generation settings used by ``main``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_path', type=str, required=True, help="path to chat version model")
    parser.add_argument('--tokenizer_path', type=str, default=None,
                        help="path to chat version tokenizer (defaults to --model_path)")
    parser.add_argument('--temperature', type=float, default=0.8, help="set temperature")
    parser.add_argument('--top_p', type=float, default=0.95, help="set top p value")
    parser.add_argument('--top_k', type=int, default=50, help="set top k value")
    # str2bool instead of bool: bool("False") would be True.
    parser.add_argument('--do_sample', type=str2bool, default=True,
                        help="whether turn on do_sample or not (true/false)")
    parser.add_argument('--length_penalty', type=float, default=1.2, help="set length penalty")
    parser.add_argument('--max_new_tokens', type=int, default=512, help="set max new tokens")
    return parser


if __name__ == "__main__":
    args = build_parser().parse_args()
    # Use the model path for the tokenizer when no separate path was given.
    if args.tokenizer_path is None:
        args.tokenizer_path = args.model_path
    main(args)