mirror of https://github.com/hpcaitech/ColossalAI
[chatgpt]fix inference model load (#2988)
* fix lora bug * polish * fix lora gemini * fix inference load model bug pull/2880/head^2
parent
82503a96f2
commit
e588703454
|
@ -69,10 +69,13 @@ torchrun --standalone --nproc_per_node=2 train_prompts.py prompts.csv --strategy
|
|||
## Inference example (after Stage 3)
|
||||
We support a naive inference demo after training.
|
||||
```shell
|
||||
# inference
|
||||
python inference.py --pretrain <your actor model path> --model <your model type>
|
||||
# inference, using the --pretrain name/path to configure the model
|
||||
python inference.py --model_path <your actor model path> --model <your model type> --pretrain <your pretrain model name/path>
|
||||
# example
|
||||
python inference.py --model_path ./actor_checkpoint_prompts.pt --pretrain bigscience/bloom-560m --model bloom
|
||||
```
|
||||
|
||||
|
||||
#### data
|
||||
- [x] [rm-static](https://huggingface.co/datasets/Dahoas/rm-static)
|
||||
- [x] [hh-rlhf](https://huggingface.co/datasets/Anthropic/hh-rlhf)
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
import argparse
|
||||
import torch
|
||||
|
||||
import torch
|
||||
from chatgpt.nn import BLOOMActor, GPTActor, OPTActor
|
||||
from transformers import AutoTokenizer
|
||||
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
|
||||
|
@ -9,18 +9,17 @@ from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
|
|||
def eval(args):
|
||||
# configure model
|
||||
if args.model == 'gpt2':
|
||||
actor = GPTActor().to(torch.cuda.current_device())
|
||||
actor = GPTActor(pretrained=args.pretrain).to(torch.cuda.current_device())
|
||||
elif args.model == 'bloom':
|
||||
actor = BLOOMActor().to(torch.cuda.current_device())
|
||||
actor = BLOOMActor(pretrained=args.pretrain).to(torch.cuda.current_device())
|
||||
elif args.model == 'opt':
|
||||
actor = OPTActor().to(torch.cuda.current_device())
|
||||
actor = OPTActor(pretrained=args.pretrain).to(torch.cuda.current_device())
|
||||
else:
|
||||
raise ValueError(f'Unsupported model "{args.model}"')
|
||||
|
||||
state_dict = torch.load(args.pretrain)
|
||||
state_dict = torch.load(args.model_path)
|
||||
actor.model.load_state_dict(state_dict)
|
||||
|
||||
|
||||
|
||||
# configure tokenizer
|
||||
if args.model == 'gpt2':
|
||||
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
|
||||
|
@ -49,7 +48,9 @@ def eval(args):
|
|||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt'])
|
||||
# We suggest using a pretrained model from HuggingFace; use --pretrain to configure the model
|
||||
parser.add_argument('--pretrain', type=str, default=None)
|
||||
parser.add_argument('--model_path', type=str, default=None)
|
||||
parser.add_argument('--input', type=str, default='Question: How are you ? Answer:')
|
||||
parser.add_argument('--max_length', type=int, default=100)
|
||||
args = parser.parse_args()
|
||||
|
|
Loading…
Reference in New Issue