mirror of https://github.com/hpcaitech/ColossalAI
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
33 lines
990 B
33 lines
990 B
import time
|
|
|
|
import torch
|
|
from sentencepiece import SentencePieceProcessor
|
|
from transformers import AutoModelForCausalLM
|
|
from utils import get_defualt_parser, inference, print_output
|
|
|
|
if __name__ == "__main__":
    # Script entry point: load a causal LM and a SentencePiece tokenizer,
    # run inference on every prompt in args.text, print each result, and
    # report the total elapsed time (model loading included).

    # NOTE: `get_defualt_parser` is the name exported by `utils`; the spelling
    # must match that module, so it is intentionally left as-is here.
    parser = get_defualt_parser()
    args = parser.parse_args()

    # perf_counter() is monotonic and high-resolution, so the measured
    # duration cannot be distorted by system clock changes (time.time() can).
    start = time.perf_counter()

    # Make bfloat16 the default floating-point dtype for tensors created
    # without an explicit dtype -- presumably so that modules built by
    # trust_remote_code model code also end up in bf16.
    torch.set_default_dtype(torch.bfloat16)

    # args.pretrained is passed straight to from_pretrained (a model name or
    # local path). device_map="auto" lets transformers place the weights
    # across the available devices (requires `accelerate` to be installed).
    model = AutoModelForCausalLM.from_pretrained(
        args.pretrained,
        trust_remote_code=True,
        device_map="auto",
        torch_dtype=torch.bfloat16,
    )

    # args.tokenizer is the path to a SentencePiece model file.
    sp = SentencePieceProcessor(model_file=args.tokenizer)

    # args.text is iterated, so it is presumably a list of prompts
    # (e.g. an argparse nargs="+" option) -- TODO confirm in utils.
    for text in args.text:
        # `inference` (from utils) returns something `sp.decode` accepts,
        # presumably the generated token ids -- verify against utils.
        output = inference(
            model,
            sp,
            text,
            max_new_tokens=args.max_new_tokens,
            do_sample=args.do_sample,
            temperature=args.temperature,
            top_k=args.top_k,
            top_p=args.top_p,
        )
        print_output(text, sp.decode(output))

    print(f"Overall time: {time.perf_counter() - start} seconds.")
|