[chatgpt] add inference example (#2944)

* [chatgpt] support inference example

* Create inference.sh

* Update README.md

* Delete inference.sh

* Update inference.py
BlueRum 2023-03-01 13:39:39 +08:00 committed by GitHub
parent 47fb214b3b
commit 489a9566af
2 changed files with 74 additions and 15 deletions

README.md

@@ -6,7 +6,21 @@
pip install -r requirements.txt
```
## Train with dummy prompt data
## Train the reward model (Stage 2)
We use [rm-static](https://huggingface.co/datasets/Dahoas/rm-static) as the dataset to train our reward model. It contains chosen and rejected responses to the same prompts.
The dataset is downloaded from Hugging Face automatically.
Use the following commands to train your reward model.
```shell
# Naive reward model training
python train_reward_model.py --pretrain <your model path>
# To train with LoRA
python train_reward_model.py --pretrain <your model path> --lora_rank 16
```
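For a quick look at the data format, here is a minimal sketch using the Hugging Face `datasets` library; the column names (`prompt`, `chosen`, `rejected`) come from the dataset card, and how `train_reward_model.py` actually loads the data is not shown in this diff:
```python
# Minimal sketch: inspect a Dahoas/rm-static record.
# Assumes the `datasets` library is installed; column names follow the
# dataset card and are not defined by this repo.
from datasets import load_dataset

dataset = load_dataset("Dahoas/rm-static", split="train")
sample = dataset[0]
print(sample["prompt"])    # shared prompt
print(sample["chosen"])    # preferred completion
print(sample["rejected"])  # dispreferred completion
```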
## Train with dummy prompt data (Stage 3)
This script supports 3 strategies:
@@ -33,7 +47,7 @@ torchrun --standalone --nproc_per_node=2 train_dummy.py --strategy ddp
torchrun --standalone --nproc_per_node=2 train_dummy.py --strategy colossalai
```
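As a rough sketch of what the `--strategy` flag selects, the training script typically constructs one strategy object per flag value. The import path and class names below (`chatgpt.trainer.strategies` with `NaiveStrategy`, `DDPStrategy`, `ColossalAIStrategy`) are assumptions about the package layout and are not confirmed by this diff:
```python
# Hedged sketch of mapping --strategy to a strategy object.
# The import path and class names are assumptions about the chatgpt
# package layout; the real train_dummy.py may differ.
from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy


def build_strategy(name: str):
    if name == 'naive':
        return NaiveStrategy()
    if name == 'ddp':
        return DDPStrategy()
    if name == 'colossalai':
        return ColossalAIStrategy()
    raise ValueError(f'Unsupported strategy "{name}"')
```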
## Train with real prompt data
## Train with real prompt data (Stage 3)
We use [awesome-chatgpt-prompts](https://huggingface.co/datasets/fka/awesome-chatgpt-prompts) as an example dataset. It is a small dataset containing a few hundred prompts.
@@ -52,18 +66,11 @@ torchrun --standalone --nproc_per_node=2 train_prompts.py prompts.csv --strategy
torchrun --standalone --nproc_per_node=2 train_prompts.py prompts.csv --strategy colossalai
```
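The commands above expect a local `prompts.csv`. A minimal sketch of exporting the Hugging Face dataset to that file, assuming the `datasets` library; whether `train_prompts.py` expects a particular column layout is not shown in this diff:
```python
# Minimal sketch: export fka/awesome-chatgpt-prompts to prompts.csv.
# Assumes the `datasets` library; the 'act' and 'prompt' columns come
# from the dataset card, and the exact format train_prompts.py expects
# is not shown in this diff.
from datasets import load_dataset

dataset = load_dataset("fka/awesome-chatgpt-prompts", split="train")
print(dataset.column_names)    # ['act', 'prompt'] per the dataset card
dataset.to_csv("prompts.csv")  # file path passed to train_prompts.py
```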
## Train the reward model
We use [rm-static](https://huggingface.co/datasets/Dahoas/rm-static) as dataset to train our reward model. It is a dataset of chosen & rejected response of the same prompt.
You can download the dataset from huggingface automatically.
Use these code to train your reward model.
## Inference example (after Stage 3)
We support a naive inference demo after training.
```shell
# Naive reward model training
python train_reward_model.py --pretrain <your model path>
# if to use LoRA
python train_reward_model.py --pretrain <your model path> --lora_rank 16
# inference
python inference_actor.py --pretrain <your actor model path> --model <your model type>
```
## Support Model
@@ -91,8 +98,8 @@ python train_reward_model.py --pretrain <your model path> --lora_rank 16
### BLOOM
- [x] [BLOOM-560m](https://huggingface.co/bigscience/bloom-560m)
- [x] [BLOOM-1b1](https://huggingface.co/bigscience/bloom-1b1)
- [ ] [BLOOM-3b](https://huggingface.co/bigscience/bloom-3b)
- [ ] [BLOOM-7b](https://huggingface.co/bigscience/bloomz-7b1)
- [x] [BLOOM-3b](https://huggingface.co/bigscience/bloom-3b)
- [x] [BLOOM-7b](https://huggingface.co/bigscience/bloomz-7b1)
- [ ] BLOOM-175b
### OPT

inference.py

@@ -0,0 +1,52 @@
import argparse

import torch
from chatgpt.nn import BLOOMActor, GPTActor, OPTActor
from transformers import AutoTokenizer
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer


def eval(args):
    # configure model
    if args.model == 'gpt2':
        model = GPTActor(pretrained=args.pretrain).to(torch.cuda.current_device())
    elif args.model == 'bloom':
        model = BLOOMActor(pretrained=args.pretrain).to(torch.cuda.current_device())
    elif args.model == 'opt':
        model = OPTActor(pretrained=args.pretrain).to(torch.cuda.current_device())
    else:
        raise ValueError(f'Unsupported model "{args.model}"')

    # configure tokenizer
    if args.model == 'gpt2':
        tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
        tokenizer.pad_token = tokenizer.eos_token
    elif args.model == 'bloom':
        tokenizer = AutoTokenizer.from_pretrained(args.pretrain)
        tokenizer.pad_token = tokenizer.eos_token
    elif args.model == 'opt':
        tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
    else:
        raise ValueError(f'Unsupported model "{args.model}"')

    model.eval()

    # encode the prompt, sample a continuation with top-k/top-p, and decode it
    input = args.input
    input_ids = tokenizer.encode(input, return_tensors='pt').to(torch.cuda.current_device())
    outputs = model.generate(input_ids,
                             max_length=args.max_length,
                             do_sample=True,
                             top_k=50,
                             top_p=0.95,
                             num_return_sequences=1)
    output = tokenizer.batch_decode(outputs[0], skip_special_tokens=True)
    print(output)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt'])
    parser.add_argument('--pretrain', type=str, default=None)
    parser.add_argument('--input', type=str, default='Q: How are you ? A:')
    parser.add_argument('--max_length', type=int, default=100)
    args = parser.parse_args()
    eval(args)