@@ -32,19 +32,20 @@ def main(args):
     # configure model
     with strategy.model_init_context():
         if args.model == 'gpt2':
-            actor = GPTActor().cuda()
-            critic = GPTCritic().cuda()
+            actor = GPTActor(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
+            critic = GPTCritic(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
         elif args.model == 'bloom':
-            actor = BLOOMActor(pretrained=args.pretrain, lora_rank=args.lora_rank).cuda()
-            critic = BLOOMCritic(pretrained=args.pretrain, lora_rank=args.lora_rank).cuda()
+            actor = BLOOMActor(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
+            critic = BLOOMCritic(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
         elif args.model == 'opt':
-            actor = OPTActor(lora_rank=args.lora_rank).cuda()
-            critic = OPTCritic(lora_rank=args.lora_rank).cuda()
+            actor = OPTActor(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
+            critic = OPTCritic(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
         else:
             raise ValueError(f'Unsupported model "{args.model}"')

         initial_model = deepcopy(actor)
-        reward_model = RewardModel(deepcopy(critic.model), deepcopy(critic.value_head)).cuda()
+        reward_model = RewardModel(deepcopy(critic.model), deepcopy(critic.value_head)).to(torch.cuda.current_device())

     # configure optimizer
     if args.strategy.startswith('colossalai'):
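Note on the hunk above: it threads pretrained and lora_rank through every actor/critic constructor and replaces bare .cuda() with .to(torch.cuda.current_device()). The two placements are equivalent once each rank has pinned its own GPU; the explicit form just documents that intent. A minimal sketch of the per-rank setup this relies on (the LOCAL_RANK plumbing here is illustrative, not part of the patch):

    import os
    import torch

    # Each launched process pins one GPU; after set_device, both
    # model.cuda() and model.to(torch.cuda.current_device()) target
    # that same rank-local GPU rather than an explicit index.
    local_rank = int(os.environ.get('LOCAL_RANK', 0))
    torch.cuda.set_device(local_rank)

    model = torch.nn.Linear(8, 8).to(torch.cuda.current_device())
    assert next(model.parameters()).device.index == local_rank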
@@ -100,12 +101,13 @@ def main(args):
                 num_episodes=args.num_episodes,
                 max_timesteps=args.max_timesteps,
                 update_timesteps=args.update_timesteps)

-    # save model checkpoint after fitting on only rank0
+    # save model checkpoint after fitting
     strategy.save_model(actor, 'actor_checkpoint_prompts.pt', only_rank0=True)
-    # save optimizer checkpoint on all ranks
-    strategy.save_optimizer(actor_optim,
-                            'actor_optim_checkpoint_prompts_%d.pt' % (torch.cuda.current_device()),
-                            only_rank0=False)
+    if args.need_optim_ckpt:
+        strategy.save_optimizer(actor_optim,
+                                'actor_optim_checkpoint_prompts_%d.pt' % (torch.cuda.current_device()),
+                                only_rank0=False)

 if __name__ == '__main__':
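The second hunk makes the optimizer checkpoint opt-in: with only_rank0=False every process writes its own shard, so the filename embeds torch.cuda.current_device() to keep ranks from clobbering each other. A rough sketch of what such a rank-suffixed save reduces to for a plain, non-sharded strategy — using torch.save on the state dict is an assumption about the naive case, not ColossalAI's actual save_optimizer implementation:

    import torch

    def save_optimizer_naive(optim: torch.optim.Optimizer, path_template: str) -> None:
        # Every rank writes its own file,
        # e.g. actor_optim_checkpoint_prompts_0.pt on device 0.
        path = path_template % torch.cuda.current_device()
        torch.save(optim.state_dict(), path)

    # save_optimizer_naive(actor_optim, 'actor_optim_checkpoint_prompts_%d.pt')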
@@ -116,6 +118,7 @@ if __name__ == '__main__':
                         default='naive')
     parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt'])
     parser.add_argument('--pretrain', type=str, default=None)
+    parser.add_argument('--need_optim_ckpt', type=bool, default=False)
     parser.add_argument('--num_episodes', type=int, default=10)
     parser.add_argument('--max_timesteps', type=int, default=10)
     parser.add_argument('--update_timesteps', type=int, default=10)
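One caveat on the new --need_optim_ckpt flag: argparse's type=bool applies bool() to the raw argument string, and any non-empty string is truthy, so --need_optim_ckpt False still parses as True. A short demonstration (the store_true alternative is a suggestion, not what the patch does):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--need_optim_ckpt', type=bool, default=False)
    args = parser.parse_args(['--need_optim_ckpt', 'False'])
    print(args.need_optim_ckpt)  # True -- bool('False') is truthy

    # A presence flag avoids the surprise:
    # parser.add_argument('--need_optim_ckpt', action='store_true')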