diff --git a/applications/ChatGPT/examples/README.md b/applications/ChatGPT/examples/README.md index c411c880b..bf3daf5ec 100644 --- a/applications/ChatGPT/examples/README.md +++ b/applications/ChatGPT/examples/README.md @@ -69,10 +69,13 @@ torchrun --standalone --nproc_per_node=2 train_prompts.py prompts.csv --strategy ## Inference example(After Stage3) We support naive inference demo after training. ```shell -# inference -python inference.py --pretrain --model +# inference, using pretrain path to configure model +python inference.py --model_path --model --pretrain +# example +python inference.py --model_path ./actor_checkpoint_prompts.pt --pretrain bigscience/bloom-560m --model bloom ``` + #### data - [x] [rm-static](https://huggingface.co/datasets/Dahoas/rm-static) - [x] [hh-rlhf](https://huggingface.co/datasets/Anthropic/hh-rlhf) diff --git a/applications/ChatGPT/examples/inference.py b/applications/ChatGPT/examples/inference.py index a2682277d..239b6e19b 100644 --- a/applications/ChatGPT/examples/inference.py +++ b/applications/ChatGPT/examples/inference.py @@ -1,6 +1,6 @@ import argparse -import torch +import torch from chatgpt.nn import BLOOMActor, GPTActor, OPTActor from transformers import AutoTokenizer from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer @@ -9,18 +9,17 @@ from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer def eval(args): # configure model if args.model == 'gpt2': - actor = GPTActor().to(torch.cuda.current_device()) + actor = GPTActor(pretrained=args.pretrain).to(torch.cuda.current_device()) elif args.model == 'bloom': - actor = BLOOMActor().to(torch.cuda.current_device()) + actor = BLOOMActor(pretrained=args.pretrain).to(torch.cuda.current_device()) elif args.model == 'opt': - actor = OPTActor().to(torch.cuda.current_device()) + actor = OPTActor(pretrained=args.pretrain).to(torch.cuda.current_device()) else: raise ValueError(f'Unsupported model "{args.model}"') - state_dict = torch.load(args.pretrain) + state_dict = torch.load(args.model_path) actor.model.load_state_dict(state_dict) - - + # configure tokenizer if args.model == 'gpt2': tokenizer = GPT2Tokenizer.from_pretrained('gpt2') @@ -49,7 +48,9 @@ def eval(args): if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt']) + # We suggest to use the pretrained model from HuggingFace, use pretrain to configure model parser.add_argument('--pretrain', type=str, default=None) + parser.add_argument('--model_path', type=str, default=None) parser.add_argument('--input', type=str, default='Question: How are you ? Answer:') parser.add_argument('--max_length', type=int, default=100) args = parser.parse_args()