ColossalAI/examples/language/grok-1/inference.py

import time

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from utils import get_defualt_parser, inference, print_output

if __name__ == "__main__":
    parser = get_defualt_parser()
    args = parser.parse_args()
    start = time.time()

    # Run the model in bfloat16 so freshly created tensors match the
    # checkpoint dtype and memory use is halved relative to fp32.
    torch.set_default_dtype(torch.bfloat16)
    tokenizer = AutoTokenizer.from_pretrained(args.pretrained, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        args.pretrained,
        trust_remote_code=True,
        device_map="auto",
        torch_dtype=torch.bfloat16,
    )
    model.eval()
    init_time = time.time() - start

    # Generate a completion for each input prompt.
    for text in args.text:
        output = inference(
            model,
            tokenizer,
            text,
            max_new_tokens=args.max_new_tokens,
            do_sample=args.do_sample,
            temperature=args.temperature,
            top_k=args.top_k,
            top_p=args.top_p,
        )
        print_output(text, tokenizer.decode(output))

    # Report model-loading time and per-prompt generation latency.
    overall_time = time.time() - start
    gen_latency = overall_time - init_time
    avg_gen_latency = gen_latency / len(args.text)
    print(
        f"Initializing time: {init_time:.2f} seconds.\n"
        f"Overall time: {overall_time:.2f} seconds.\n"
        f"Generation latency: {gen_latency:.2f} seconds.\n"
        f"Average generation latency: {avg_gen_latency:.2f} seconds.\n"
    )
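
The helpers imported from utils are not shown on this page. Below is a minimal sketch of what they plausibly look like, inferred only from how inference.py calls them: the flag names match the args attributes used above, inference() takes (model, tokenizer, text, **generate_kwargs) and returns token ids that are later passed to tokenizer.decode. The actual utils.py in the ColossalAI repository may differ, and the default values shown are hypothetical.

import argparse

import torch


def get_defualt_parser():
    # Flag names inferred from the attributes accessed in inference.py;
    # defaults here are illustrative, not the repository's actual values.
    parser = argparse.ArgumentParser()
    parser.add_argument("--pretrained", type=str, default="hpcaitech/grok-1")
    parser.add_argument("--text", type=str, nargs="+", default=["Hi, what's your name?"])
    parser.add_argument("--max_new_tokens", type=int, default=100)
    parser.add_argument("--do_sample", action="store_true")
    parser.add_argument("--temperature", type=float, default=0.3)
    parser.add_argument("--top_k", type=int, default=50)
    parser.add_argument("--top_p", type=float, default=0.95)
    return parser


@torch.no_grad()
def inference(model, tokenizer, text, **generate_kwargs):
    # Encode the prompt, move it to the model's device, and sample a
    # completion; returns the full token-id sequence (prompt + completion).
    input_ids = tokenizer(text, return_tensors="pt").input_ids.to(model.device)
    output = model.generate(input_ids, **generate_kwargs)
    return output[0]


def print_output(text, output):
    # Print the prompt alongside the decoded model output.
    print(f"Prompt:\n{text}\nOutput:\n{output}")

With a parser like this, a typical invocation would be something along the lines of: python inference.py --pretrained hpcaitech/grok-1 --text "Hello" --do_sample --max_new_tokens 64 (flag names assumed from the code above).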