mirror of https://github.com/hpcaitech/ColossalAI

[chatgpt] add inference example (#2944)

* [chatgpt] support inference example
* Create inference.sh
* Update README.md
* Delete inference.sh
* Update inference.py

parent 47fb214b3b
commit 489a9566af
README.md

@@ -6,7 +6,21 @@
pip install -r requirements.txt
```

## Train with dummy prompt data

## Train the reward model (Stage 2)

We use [rm-static](https://huggingface.co/datasets/Dahoas/rm-static) as the dataset to train our reward model. It is a dataset of chosen and rejected responses to the same prompt.

The dataset is downloaded from Hugging Face automatically.

Use the following commands to train your reward model.

```shell
# Naive reward model training
python train_reward_model.py --pretrain <your model path>
# use LoRA
python train_reward_model.py --pretrain <your model path> --lora_rank 16
```
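As background, reward models trained on chosen/rejected pairs typically minimize a pairwise ranking loss. The sketch below only illustrates that idea; it is not necessarily the exact loss implemented in `train_reward_model.py`.

```python
import torch
import torch.nn.functional as F


def pairwise_ranking_loss(chosen_reward: torch.Tensor, rejected_reward: torch.Tensor) -> torch.Tensor:
    """Push the reward for the chosen response above the rejected one:
    loss = -log(sigmoid(r_chosen - r_rejected)), averaged over the batch."""
    return -F.logsigmoid(chosen_reward - rejected_reward).mean()


# Toy scores for four chosen/rejected pairs (illustration only).
chosen = torch.tensor([1.2, 0.8, 0.5, 2.0])
rejected = torch.tensor([0.3, 1.0, -0.2, 0.5])
print(pairwise_ranking_loss(chosen, rejected))
```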
## Train with dummy prompt data (Stage 3)

This script supports 3 strategies:

@@ -33,7 +47,7 @@ torchrun --standalone --nproc_per_node=2 train_dummy.py --strategy ddp
torchrun --standalone --nproc_per_node=2 train_dummy.py --strategy colossalai
```

## Train with real prompt data

## Train with real prompt data (Stage 3)

We use [awesome-chatgpt-prompts](https://huggingface.co/datasets/fka/awesome-chatgpt-prompts) as the example dataset. It is a small dataset with hundreds of prompts.
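One possible way to materialize this dataset as the `prompts.csv` file consumed by `train_prompts.py` is sketched below; the split and column names are assumptions taken from the dataset card, not from this repository.

```python
# Hypothetical helper: export the prompt dataset to prompts.csv.
# Assumes the dataset has a "train" split containing a "prompt" column.
from datasets import load_dataset

ds = load_dataset("fka/awesome-chatgpt-prompts", split="train")
print(ds.column_names)       # inspect available columns
ds.to_csv("prompts.csv")     # file then passed to train_prompts.py
```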
@@ -52,18 +66,11 @@ torchrun --standalone --nproc_per_node=2 train_prompts.py prompts.csv --strategy
torchrun --standalone --nproc_per_node=2 train_prompts.py prompts.csv --strategy colossalai
```

## Train the reward model

We use [rm-static](https://huggingface.co/datasets/Dahoas/rm-static) as the dataset to train our reward model. It is a dataset of chosen and rejected responses to the same prompt.

The dataset is downloaded from Hugging Face automatically.

Use the following commands to train your reward model.

## Inference example (After Stage 3)

We support a naive inference demo after training.

```shell
# Naive reward model training
python train_reward_model.py --pretrain <your model path>
# use LoRA
python train_reward_model.py --pretrain <your model path> --lora_rank 16
# inference
python inference_actor.py --pretrain <your actor model path> --model <your model type>
```

## Support Model

@@ -91,8 +98,8 @@ python train_reward_model.py --pretrain <your model path> --lora_rank 16
### BLOOM
- [x] [BLOOM-560m](https://huggingface.co/bigscience/bloom-560m)
- [x] [BLOOM-1b1](https://huggingface.co/bigscience/bloom-1b1)
- [ ] [BLOOM-3b](https://huggingface.co/bigscience/bloom-3b)
- [ ] [BLOOM-7b](https://huggingface.co/bigscience/bloomz-7b1)
- [x] [BLOOM-3b](https://huggingface.co/bigscience/bloom-3b)
- [x] [BLOOM-7b](https://huggingface.co/bigscience/bloomz-7b1)
- [ ] BLOOM-175b

### OPT

inference.py

@@ -0,0 +1,52 @@
```python
import argparse
import torch

from chatgpt.nn import BLOOMActor, GPTActor, OPTActor
from transformers import AutoTokenizer
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer


def eval(args):
    # configure model
    if args.model == 'gpt2':
        model = GPTActor(pretrained=args.pretrain).to(torch.cuda.current_device())
    elif args.model == 'bloom':
        model = BLOOMActor(pretrained=args.pretrain).to(torch.cuda.current_device())
    elif args.model == 'opt':
        model = OPTActor(pretrained=args.pretrain).to(torch.cuda.current_device())
    else:
        raise ValueError(f'Unsupported model "{args.model}"')

    # configure tokenizer
    if args.model == 'gpt2':
        tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
        tokenizer.pad_token = tokenizer.eos_token
    elif args.model == 'bloom':
        tokenizer = AutoTokenizer.from_pretrained(args.pretrain)
        tokenizer.pad_token = tokenizer.eos_token
    elif args.model == 'opt':
        tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
    else:
        raise ValueError(f'Unsupported model "{args.model}"')

    model.eval()
    # tokenize the prompt and generate with top-k/top-p sampling
    input = args.input
    input_ids = tokenizer.encode(input, return_tensors='pt').to(torch.cuda.current_device())
    outputs = model.generate(input_ids,
                             max_length=args.max_length,
                             do_sample=True,
                             top_k=50,
                             top_p=0.95,
                             num_return_sequences=1)
    # decode the generated sequences and print the result
    output = tokenizer.batch_decode(outputs[0], skip_special_tokens=True)
    print(output)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt'])
    parser.add_argument('--pretrain', type=str, default=None)
    parser.add_argument('--input', type=str, default='Q: How are you ? A:')
    parser.add_argument('--max_length', type=int, default=100)
    args = parser.parse_args()
    eval(args)
```