mirror of https://github.com/hpcaitech/ColossalAI
parent 0672b5afac
commit 68577fbc43
@@ -38,19 +38,19 @@ def main(args):
     # configure model
     with strategy.model_init_context():
         if args.model == 'gpt2':
-            actor = GPTActor().cuda()
-            critic = GPTCritic().cuda()
+            actor = GPTActor(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
+            critic = GPTCritic(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
         elif args.model == 'bloom':
-            actor = BLOOMActor(pretrained=args.pretrain, lora_rank=args.lora_rank).cuda()
-            critic = BLOOMCritic(pretrained=args.pretrain, lora_rank=args.lora_rank).cuda()
+            actor = BLOOMActor(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
+            critic = BLOOMCritic(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
         elif args.model == 'opt':
-            actor = OPTActor().cuda()
-            critic = OPTCritic().cuda()
+            actor = OPTActor(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
+            critic = OPTCritic(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
         else:
             raise ValueError(f'Unsupported model "{args.model}"')
 
-        initial_model = deepcopy(actor).cuda()
-        reward_model = RewardModel(deepcopy(critic.model), deepcopy(critic.value_head)).cuda()
+        initial_model = deepcopy(actor).to(torch.cuda.current_device())
+        reward_model = RewardModel(deepcopy(critic.model), deepcopy(critic.value_head)).to(torch.cuda.current_device())
 
     # configure optimizer
     if args.strategy.startswith('colossalai'):
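Note on the hunk above: both .cuda() and .to(torch.cuda.current_device()) move a module to the process's current CUDA device; the explicit form simply spells out that the target depends on which device the strategy has already bound to the local rank. A minimal sketch of that dependency (illustrative only, not code from this commit; the function name is made up):

import torch

def place_on_local_gpu(model: torch.nn.Module, local_rank: int) -> torch.nn.Module:
    # Bind this process to its own GPU first; after this call,
    # torch.cuda.current_device() (and a bare .cuda()) resolve to local_rank.
    torch.cuda.set_device(local_rank)
    return model.to(torch.cuda.current_device())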
@@ -114,12 +114,13 @@ def main(args):
                 max_timesteps=args.max_timesteps,
                 update_timesteps=args.update_timesteps)
 
-    # save model checkpoint after fitting on only rank0
+    # save model checkpoint after fitting
     strategy.save_model(actor, 'actor_checkpoint_dummy.pt', only_rank0=True)
     # save optimizer checkpoint on all ranks
-    strategy.save_optimizer(actor_optim,
-                            'actor_optim_checkpoint_dummy_%d.pt' % (torch.cuda.current_device()),
-                            only_rank0=False)
+    if args.need_optim_ckpt:
+        strategy.save_optimizer(actor_optim,
+                                'actor_optim_checkpoint_dummy_%d.pt' % (torch.cuda.current_device()),
+                                only_rank0=False)
 
 
 if __name__ == '__main__':
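Aside on the checkpoint logic above (not part of the commit): only_rank0=False means every process writes its own optimizer file, which is why the file name embeds torch.cuda.current_device(); the new args.need_optim_ckpt guard just skips that per-rank write unless it is requested. Roughly, assuming one process per GPU:

import torch

# Each rank produces its own file under the naming pattern used above.
ckpt_name = 'actor_optim_checkpoint_dummy_%d.pt' % torch.cuda.current_device()
# rank 0 -> actor_optim_checkpoint_dummy_0.pt, rank 1 -> actor_optim_checkpoint_dummy_1.pt, ...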
@@ -129,6 +130,7 @@ if __name__ == '__main__':
                         default='naive')
     parser.add_argument('--model', type=str, default='gpt2', choices=['gpt2', 'bloom', 'opt'])
     parser.add_argument('--pretrain', type=str, default=None)
+    parser.add_argument('--need_optim_ckpt', type=bool, default=False)
     parser.add_argument('--num_episodes', type=int, default=50)
     parser.add_argument('--max_timesteps', type=int, default=10)
     parser.add_argument('--update_timesteps', type=int, default=10)
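One caveat about the flag added above (an observation, not something this diff changes): argparse applies type=bool to the raw string, and bool('False') is True, so passing --need_optim_ckpt False still enables the optimizer checkpoint; in practice only omitting the flag keeps the default False. A store_true action is the usual way to avoid that surprise; a sketch of the alternative (a suggestion, not the committed code):

import argparse

parser = argparse.ArgumentParser()
# store_true makes the flag a plain on/off switch with no string-to-bool pitfall.
parser.add_argument('--need_optim_ckpt', action='store_true',
                    help='also save per-rank optimizer checkpoints after fitting')
args = parser.parse_args(['--need_optim_ckpt'])
assert args.need_optim_ckpt is True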
@@ -32,19 +32,20 @@ def main(args):
     # configure model
     with strategy.model_init_context():
         if args.model == 'gpt2':
-            actor = GPTActor().cuda()
-            critic = GPTCritic().cuda()
+            actor = GPTActor(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
+            critic = GPTCritic(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
         elif args.model == 'bloom':
-            actor = BLOOMActor(pretrained=args.pretrain, lora_rank=args.lora_rank).cuda()
-            critic = BLOOMCritic(pretrained=args.pretrain, lora_rank=args.lora_rank).cuda()
+            actor = BLOOMActor(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
+            critic = BLOOMCritic(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
         elif args.model == 'opt':
-            actor = OPTActor(lora_rank=args.lora_rank).cuda()
-            critic = OPTCritic(lora_rank=args.lora_rank).cuda()
+            actor = OPTActor(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
+            critic = OPTCritic(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
         else:
             raise ValueError(f'Unsupported model "{args.model}"')
 
         initial_model = deepcopy(actor)
-        reward_model = RewardModel(deepcopy(critic.model), deepcopy(critic.value_head)).cuda()
+        reward_model = RewardModel(deepcopy(critic.model), deepcopy(critic.value_head)).to(torch.cuda.current_device())
 
+
     # configure optimizer
     if args.strategy.startswith('colossalai'):
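A small difference from the earlier hunk: here initial_model = deepcopy(actor) is left as is, which works because copy.deepcopy preserves the device of the copied parameters, so the clone already lives on the same GPU as the actor. A minimal illustration (assumes a CUDA device is available; not code from this commit):

import copy
import torch

if torch.cuda.is_available():
    src = torch.nn.Linear(4, 4).to(torch.cuda.current_device())
    clone = copy.deepcopy(src)
    # The copy stays on the same device, so no extra .to(...) call is needed.
    assert clone.weight.device == src.weight.device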
@@ -100,12 +101,13 @@ def main(args):
                 num_episodes=args.num_episodes,
                 max_timesteps=args.max_timesteps,
                 update_timesteps=args.update_timesteps)
-    # save model checkpoint after fitting on only rank0
+    # save model checkpoint after fitting
     strategy.save_model(actor, 'actor_checkpoint_prompts.pt', only_rank0=True)
     # save optimizer checkpoint on all ranks
-    strategy.save_optimizer(actor_optim,
-                            'actor_optim_checkpoint_prompts_%d.pt' % (torch.cuda.current_device()),
-                            only_rank0=False)
+    if args.need_optim_ckpt:
+        strategy.save_optimizer(actor_optim,
+                                'actor_optim_checkpoint_prompts_%d.pt' % (torch.cuda.current_device()),
+                                only_rank0=False)
 
 
 if __name__ == '__main__':
@@ -116,6 +118,7 @@ if __name__ == '__main__':
                         default='naive')
     parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt'])
     parser.add_argument('--pretrain', type=str, default=None)
+    parser.add_argument('--need_optim_ckpt', type=bool, default=False)
     parser.add_argument('--num_episodes', type=int, default=10)
     parser.add_argument('--max_timesteps', type=int, default=10)
     parser.add_argument('--update_timesteps', type=int, default=10)