From 82149e9d1b0d9e1e9eeb643af7c7e19fbf503ee4 Mon Sep 17 00:00:00 2001 From: BlueRum <70618399+ht-zhou@users.noreply.github.com> Date: Thu, 2 Mar 2023 16:18:33 +0800 Subject: [PATCH] [chatgpt] fix inference demo loading bug (#2969) * [chatgpt] fix inference demo loading bug * polish --- applications/ChatGPT/examples/inference.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/applications/ChatGPT/examples/inference.py b/applications/ChatGPT/examples/inference.py index ba055d81f..a2682277d 100644 --- a/applications/ChatGPT/examples/inference.py +++ b/applications/ChatGPT/examples/inference.py @@ -9,30 +9,34 @@ from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer def eval(args): # configure model if args.model == 'gpt2': - model = GPTActor(pretrained=args.pretrain).to(torch.cuda.current_device()) + actor = GPTActor().to(torch.cuda.current_device()) elif args.model == 'bloom': - model = BLOOMActor(pretrained=args.pretrain).to(torch.cuda.current_device()) + actor = BLOOMActor().to(torch.cuda.current_device()) elif args.model == 'opt': - model = OPTActor(pretrained=args.pretrain).to(torch.cuda.current_device()) + actor = OPTActor().to(torch.cuda.current_device()) else: raise ValueError(f'Unsupported model "{args.model}"') + state_dict = torch.load(args.pretrain) + actor.model.load_state_dict(state_dict) + + # configure tokenizer if args.model == 'gpt2': tokenizer = GPT2Tokenizer.from_pretrained('gpt2') tokenizer.pad_token = tokenizer.eos_token elif args.model == 'bloom': - tokenizer = AutoTokenizer.from_pretrained(args.pretrain) + tokenizer = AutoTokenizer.from_pretrained('bigscience/bloom-560m') tokenizer.pad_token = tokenizer.eos_token elif args.model == 'opt': - tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m") + tokenizer = AutoTokenizer.from_pretrained('facebook/opt-350m') else: raise ValueError(f'Unsupported model "{args.model}"') - model.eval() + actor.eval() input = args.input input_ids = tokenizer.encode(input, return_tensors='pt').to(torch.cuda.current_device()) - outputs = model.generate(input_ids, + outputs = actor.generate(input_ids, max_length=args.max_length, do_sample=True, top_k=50, @@ -46,7 +50,7 @@ if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt']) parser.add_argument('--pretrain', type=str, default=None) - parser.add_argument('--input', type=str, default='Q: How are you ? A:') + parser.add_argument('--input', type=str, default='Question: How are you ? Answer:') parser.add_argument('--max_length', type=int, default=100) args = parser.parse_args() eval(args)