ColossalAI/applications/Chat/examples/inference.py

73 lines
2.8 KiB
Python

import argparse
import torch
from coati.models.bloom import BLOOMActor
from coati.models.generation import generate
from coati.models.gpt import GPTActor
from coati.models.llama import LlamaActor
from coati.models.opt import OPTActor
from transformers import AutoTokenizer, BloomTokenizerFast, GPT2Tokenizer, LlamaTokenizer
def eval(args):
    """Generate a completion for ``args.input`` with an actor model and print it.

    The actor is built for the architecture named by ``args.model``, moved to
    the current CUDA device, optionally initialised from a checkpoint, and
    then sampled from with top-k / top-p sampling.

    Args:
        args: Parsed command-line namespace. The attributes read are:
            ``model`` (one of ``'gpt2'``, ``'bloom'``, ``'opt'``, ``'llama'``),
            ``pretrain`` (passed to the actor as ``pretrained``),
            ``model_path`` (optional checkpoint given to ``torch.load``),
            ``input`` (prompt text) and ``max_length`` (passed to ``generate``).

    Raises:
        ValueError: If ``args.model`` is not one of the supported names.
    """
    # configure model
    if args.model == 'gpt2':
        actor = GPTActor(pretrained=args.pretrain)
    elif args.model == 'bloom':
        actor = BLOOMActor(pretrained=args.pretrain)
    elif args.model == 'opt':
        actor = OPTActor(pretrained=args.pretrain)
    elif args.model == 'llama':
        actor = LlamaActor(pretrained=args.pretrain)
    else:
        raise ValueError(f'Unsupported model "{args.model}"')

    device = torch.cuda.current_device()
    actor.to(device)

    if args.model_path is not None:
        # Load onto CPU first so a checkpoint saved from a different device
        # index still loads; load_state_dict copies the values into the
        # parameters that already live on `device`.
        # NOTE: torch.load unpickles the file - only use trusted checkpoints.
        state_dict = torch.load(args.model_path, map_location='cpu')
        actor.load_state_dict(state_dict)

    # configure tokenizer
    if args.model == 'gpt2':
        tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
        tokenizer.pad_token = tokenizer.eos_token
    elif args.model == 'bloom':
        tokenizer = BloomTokenizerFast.from_pretrained('bigscience/bloom-560m')
        tokenizer.pad_token = tokenizer.eos_token
    elif args.model == 'opt':
        tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
        tokenizer.pad_token = tokenizer.eos_token
    elif args.model == 'llama':
        tokenizer = LlamaTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
        # LLaMA's end-of-sequence token is '</s>'; the former literal '<\s>'
        # was a typo (and an invalid escape sequence).
        tokenizer.eos_token = '</s>'
        tokenizer.pad_token = tokenizer.unk_token
    else:
        raise ValueError(f'Unsupported model "{args.model}"')

    actor.eval()
    input_ids = tokenizer.encode(args.input, return_tensors='pt').to(device)
    outputs = generate(actor,
                       input_ids,
                       max_length=args.max_length,
                       do_sample=True,
                       top_k=50,
                       top_p=0.95,
                       num_return_sequences=1)

    # Decode the whole sequence in one call. Decoding each token id on its own
    # and joining the pieces loses inter-word spacing with SentencePiece
    # tokenizers and can break characters that span several tokens.
    output = tokenizer.decode(outputs[0], skip_special_tokens=True)
    print(f"[Output]: {output}")
if __name__ == '__main__':
    # Command-line entry point: collect the options and run one generation.
    cli = argparse.ArgumentParser()
    cli.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama'])
    # We suggest starting from a pretrained HuggingFace model: pass it via
    # --pretrain to configure the actor. --model_path optionally points at a
    # saved state dict to load on top of it.
    for optional_path_flag in ('--pretrain', '--model_path'):
        cli.add_argument(optional_path_flag, type=str, default=None)
    cli.add_argument('--input', type=str, default='Question: How are you ? Answer:')
    cli.add_argument('--max_length', type=int, default=100)
    eval(cli.parse_args())