mirror of https://github.com/hpcaitech/ColossalAI
[chat] fix train_prompts.py gemini strategy bug (#3666)
* fix gemini strategy bug * add comment * add comment * better solution
pull/3699/head
parent
d556648885
commit
2da5d81dec
|
@ -36,6 +36,7 @@ def main(args):
|
||||||
if args.rm_path is not None:
|
if args.rm_path is not None:
|
||||||
state_dict = torch.load(args.rm_path, map_location='cpu')
|
state_dict = torch.load(args.rm_path, map_location='cpu')
|
||||||
|
|
||||||
|
with strategy.model_init_context():
|
||||||
# configure model
|
# configure model
|
||||||
if args.model == 'gpt2':
|
if args.model == 'gpt2':
|
||||||
initial_model = GPTActor(pretrained=args.pretrain)
|
initial_model = GPTActor(pretrained=args.pretrain)
|
||||||
|
@ -74,7 +75,6 @@ def main(args):
|
||||||
initial_model.to(torch.float16).to(torch.cuda.current_device())
|
initial_model.to(torch.float16).to(torch.cuda.current_device())
|
||||||
reward_model.to(torch.float16).to(torch.cuda.current_device())
|
reward_model.to(torch.float16).to(torch.cuda.current_device())
|
||||||
|
|
||||||
with strategy.model_init_context():
|
|
||||||
if args.model == 'gpt2':
|
if args.model == 'gpt2':
|
||||||
actor = GPTActor(pretrained=args.pretrain, lora_rank=args.lora_rank)
|
actor = GPTActor(pretrained=args.pretrain, lora_rank=args.lora_rank)
|
||||||
elif args.model == 'bloom':
|
elif args.model == 'bloom':
|
||||||
|
|
Loading…
Reference in New Issue