[chatgpt]fix inference model load (#2988)

* fix lora bug

* polish

* fix lora gemini

* fix inference laod model bug
pull/2880/head^2
BlueRum 2 years ago committed by GitHub
parent 82503a96f2
commit e588703454
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -69,10 +69,13 @@ torchrun --standalone --nproc_per_node=2 train_prompts.py prompts.csv --strategy
## Inference example(After Stage3)
We support naive inference demo after training.
```shell
# inference
python inference.py --pretrain <your actor model path> --model <your model type>
# inference, using pretrain path to configure model
python inference.py --model_path <your actor model path> --model <your model type> --pretrain <your pretrain model name/path>
# example
python inference.py --model_path ./actor_checkpoint_prompts.pt --pretrain bigscience/bloom-560m --model bloom
```
#### data
- [x] [rm-static](https://huggingface.co/datasets/Dahoas/rm-static)
- [x] [hh-rlhf](https://huggingface.co/datasets/Anthropic/hh-rlhf)

@ -1,6 +1,6 @@
import argparse
import torch
import torch
from chatgpt.nn import BLOOMActor, GPTActor, OPTActor
from transformers import AutoTokenizer
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
@ -9,18 +9,17 @@ from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
def eval(args):
# configure model
if args.model == 'gpt2':
actor = GPTActor().to(torch.cuda.current_device())
actor = GPTActor(pretrained=args.pretrain).to(torch.cuda.current_device())
elif args.model == 'bloom':
actor = BLOOMActor().to(torch.cuda.current_device())
actor = BLOOMActor(pretrained=args.pretrain).to(torch.cuda.current_device())
elif args.model == 'opt':
actor = OPTActor().to(torch.cuda.current_device())
actor = OPTActor(pretrained=args.pretrain).to(torch.cuda.current_device())
else:
raise ValueError(f'Unsupported model "{args.model}"')
state_dict = torch.load(args.pretrain)
state_dict = torch.load(args.model_path)
actor.model.load_state_dict(state_dict)
# configure tokenizer
if args.model == 'gpt2':
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
@ -49,7 +48,9 @@ def eval(args):
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt'])
# We suggest to use the pretrained model from HuggingFace, use pretrain to configure model
parser.add_argument('--pretrain', type=str, default=None)
parser.add_argument('--model_path', type=str, default=None)
parser.add_argument('--input', type=str, default='Question: How are you ? Answer:')
parser.add_argument('--max_length', type=int, default=100)
args = parser.parse_args()

Loading…
Cancel
Save