diff --git a/applications/ChatGPT/examples/README.md b/applications/ChatGPT/examples/README.md
index 5f9d8698d..e5522f087 100644
--- a/applications/ChatGPT/examples/README.md
+++ b/applications/ChatGPT/examples/README.md
@@ -6,7 +6,21 @@ pip install -r requirements.txt
 ```
 
-## Train with dummy prompt data
+## Train the reward model (Stage 2)
+We use [rm-static](https://huggingface.co/datasets/Dahoas/rm-static) as the dataset to train our reward model. It contains pairs of chosen and rejected responses to the same prompt.
+
+The dataset is downloaded from Hugging Face automatically.
+
+Use the following commands to train your reward model.
+
+```shell
+# Naive reward model training
+python train_reward_model.py --pretrain <your model path>
+# Train with LoRA (low-rank adaptation)
+python train_reward_model.py --pretrain <your model path> --lora_rank 16
+```
+
+## Train with dummy prompt data (Stage 3)
 
 This script supports 3 strategies:
@@ -33,7 +47,7 @@ torchrun --standalone --nproc_per_node=2 train_dummy.py --strategy ddp
 torchrun --standalone --nproc_per_node=2 train_dummy.py --strategy colossalai
 ```
 
-## Train with real prompt data
+## Train with real prompt data (Stage 3)
 
 We use [awesome-chatgpt-prompts](https://huggingface.co/datasets/fka/awesome-chatgpt-prompts) as example dataset. It is a small dataset with hundreds of prompts.
@@ -52,18 +66,11 @@ torchrun --standalone --nproc_per_node=2 train_prompts.py prompts.csv --strategy
 torchrun --standalone --nproc_per_node=2 train_prompts.py prompts.csv --strategy colossalai
 ```
 
-## Train the reward model
-We use [rm-static](https://huggingface.co/datasets/Dahoas/rm-static) as dataset to train our reward model. It is a dataset of chosen & rejected response of the same prompt.
-
-You can download the dataset from huggingface automatically.
-
-Use these code to train your reward model.
-
+## Inference example (after Stage 3)
+We provide a naive inference demo after training.
 ```shell
-# Naive reward model training
-python train_reward_model.py --pretrain
-# if to use LoRA
-python train_reward_model.py --pretrain --lora_rank 16
+# Inference
+python inference.py --model <model type> --pretrain <your actor model path>
 ```
 
 ## Support Model
@@ -91,8 +98,8 @@ python train_reward_model.py --pretrain --lora_rank 16
 ### BLOOM
 - [x] [BLOOM-560m](https://huggingface.co/bigscience/bloom-560m)
 - [x] [BLOOM-1b1](https://huggingface.co/bigscience/bloom-1b1)
-- [ ] [BLOOM-3b](https://huggingface.co/bigscience/bloom-3b)
-- [ ] [BLOOM-7b](https://huggingface.co/bigscience/bloomz-7b1)
+- [x] [BLOOM-3b](https://huggingface.co/bigscience/bloom-3b)
+- [x] [BLOOM-7b](https://huggingface.co/bigscience/bloomz-7b1)
 - [ ] BLOOM-175b
 
 ### OPT
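A note on Stage 2 for readers new to RLHF: with chosen/rejected pairs like rm-static's, reward models are typically trained with a pairwise ranking loss that pushes the score of the chosen response above that of the rejected one. The sketch below illustrates only that objective; the `RewardModel` class, shapes, and hyperparameters are hypothetical stand-ins, not the `train_reward_model.py` implementation.

```python
# Minimal sketch of the pairwise ranking objective behind reward-model
# training (illustrative only -- not the repo's train_reward_model.py).
# `RewardModel` is a toy stand-in that maps a token sequence to a scalar score.
import torch
import torch.nn as nn
import torch.nn.functional as F


class RewardModel(nn.Module):
    """Toy scorer: embeds tokens, mean-pools, maps to a scalar reward."""

    def __init__(self, vocab_size: int = 50257, hidden: int = 64) -> None:
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden)
        self.value_head = nn.Linear(hidden, 1)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        # (batch, seq_len) -> (batch,): one scalar reward per sequence
        return self.value_head(self.embed(input_ids).mean(dim=1)).squeeze(-1)


def pairwise_loss(chosen_reward: torch.Tensor, rejected_reward: torch.Tensor) -> torch.Tensor:
    # Maximize the margin r(chosen) - r(rejected):
    # loss = -log(sigmoid(r_chosen - r_rejected))
    return -F.logsigmoid(chosen_reward - rejected_reward).mean()


if __name__ == '__main__':
    model = RewardModel()
    # Dummy token ids standing in for tokenized (prompt + chosen/rejected response)
    chosen_ids = torch.randint(0, 50257, (4, 32))
    rejected_ids = torch.randint(0, 50257, (4, 32))
    loss = pairwise_loss(model(chosen_ids), model(rejected_ids))
    loss.backward()
    print(f'pairwise ranking loss: {loss.item():.4f}')
```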
diff --git a/applications/ChatGPT/examples/inference.py b/applications/ChatGPT/examples/inference.py
new file mode 100644
index 000000000..ba055d81f
--- /dev/null
+++ b/applications/ChatGPT/examples/inference.py
@@ -0,0 +1,52 @@
+import argparse
+
+import torch
+
+from chatgpt.nn import BLOOMActor, GPTActor, OPTActor
+from transformers import AutoTokenizer
+from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
+
+
+def eval(args):
+    # configure model: load the actor, optionally from a trained checkpoint
+    if args.model == 'gpt2':
+        model = GPTActor(pretrained=args.pretrain).to(torch.cuda.current_device())
+    elif args.model == 'bloom':
+        model = BLOOMActor(pretrained=args.pretrain).to(torch.cuda.current_device())
+    elif args.model == 'opt':
+        model = OPTActor(pretrained=args.pretrain).to(torch.cuda.current_device())
+    else:
+        raise ValueError(f'Unsupported model "{args.model}"')
+
+    # configure tokenizer
+    if args.model == 'gpt2':
+        tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
+        tokenizer.pad_token = tokenizer.eos_token
+    elif args.model == 'bloom':
+        # NOTE: for BLOOM, --pretrain must point to a valid checkpoint,
+        # since the tokenizer is loaded from the same path
+        tokenizer = AutoTokenizer.from_pretrained(args.pretrain)
+        tokenizer.pad_token = tokenizer.eos_token
+    elif args.model == 'opt':
+        tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
+    else:
+        raise ValueError(f'Unsupported model "{args.model}"')
+
+    model.eval()
+    # sample one continuation of the prompt with top-k/top-p sampling
+    input_ids = tokenizer.encode(args.input, return_tensors='pt').to(torch.cuda.current_device())
+    outputs = model.generate(input_ids,
+                             max_length=args.max_length,
+                             do_sample=True,
+                             top_k=50,
+                             top_p=0.95,
+                             num_return_sequences=1)
+    output = tokenizer.batch_decode(outputs[0], skip_special_tokens=True)
+    print(output)
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt'])
+    parser.add_argument('--pretrain', type=str, default=None)
+    parser.add_argument('--input', type=str, default='Q: How are you ? A:')
+    parser.add_argument('--max_length', type=int, default=100)
+    args = parser.parse_args()
+    eval(args)
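The demo above samples with `do_sample=True, top_k=50, top_p=0.95`. For reference, here is a minimal, self-contained sketch of what top-k/top-p (nucleus) filtering does to next-token logits before sampling; `top_k_top_p_filter` is an illustrative helper for a 1-D logits vector, not a function from this repo or from transformers.

```python
# Illustrative sketch of the top-k / top-p (nucleus) filtering that
# `do_sample=True, top_k=50, top_p=0.95` enables in the demo above.
# Not the repo's implementation -- just the standard technique, 1-D case.
import torch


def top_k_top_p_filter(logits: torch.Tensor, top_k: int = 50, top_p: float = 0.95) -> torch.Tensor:
    logits = logits.clone()
    # Top-k: keep only the k highest-scoring tokens
    if top_k > 0:
        kth_best = torch.topk(logits, top_k).values[..., -1, None]
        logits[logits < kth_best] = float('-inf')
    # Top-p: keep the smallest set of tokens whose cumulative probability
    # (in descending order of score) reaches top_p
    if 0.0 < top_p < 1.0:
        sorted_logits, sorted_idx = torch.sort(logits, descending=True)
        cum_probs = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
        remove = cum_probs > top_p
        remove[..., 1:] = remove[..., :-1].clone()  # keep the token that crosses the threshold
        remove[..., 0] = False                      # always keep the single best token
        logits[sorted_idx[remove]] = float('-inf')
    return logits


if __name__ == '__main__':
    vocab_logits = torch.randn(100)  # dummy next-token logits
    filtered = top_k_top_p_filter(vocab_logits)
    # tokens filtered to -inf get zero probability under softmax
    next_token = torch.multinomial(torch.softmax(filtered, dim=-1), num_samples=1)
    print(f'sampled token id: {next_token.item()}')
```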