@@ -32,19 +32,20 @@ def main(args):
     # configure model
     with strategy.model_init_context():
         if args.model == 'gpt2':
-            actor = GPTActor().cuda()
-            critic = GPTCritic().cuda()
+            actor = GPTActor(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
+            critic = GPTCritic(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
         elif args.model == 'bloom':
-            actor = BLOOMActor(pretrained=args.pretrain, lora_rank=args.lora_rank).cuda()
-            critic = BLOOMCritic(pretrained=args.pretrain, lora_rank=args.lora_rank).cuda()
+            actor = BLOOMActor(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
+            critic = BLOOMCritic(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
         elif args.model == 'opt':
-            actor = OPTActor(lora_rank=args.lora_rank).cuda()
-            critic = OPTCritic(lora_rank=args.lora_rank).cuda()
+            actor = OPTActor(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
+            critic = OPTCritic(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
         else:
             raise ValueError(f'Unsupported model "{args.model}"')
 
         initial_model = deepcopy(actor)
-        reward_model = RewardModel(deepcopy(critic.model), deepcopy(critic.value_head)).cuda()
+        reward_model = RewardModel(deepcopy(critic.model), deepcopy(critic.value_head)).to(torch.cuda.current_device())
+
 
     # configure optimizer
     if args.strategy.startswith('colossalai'):
@@ -100,9 +101,10 @@ def main(args):
                 num_episodes=args.num_episodes,
                 max_timesteps=args.max_timesteps,
                 update_timesteps=args.update_timesteps)
-    # save model checkpoint after fitting on only rank0
+    # save model checkpoint after fitting
     strategy.save_model(actor, 'actor_checkpoint_prompts.pt', only_rank0=True)
+    # save optimizer checkpoint on all ranks
     if args.need_optim_ckpt:
         strategy.save_optimizer(actor_optim,
                                 'actor_optim_checkpoint_prompts_%d.pt' % (torch.cuda.current_device()),
                                 only_rank0=False)
@@ -116,6 +118,7 @@ if __name__ == '__main__':
                         default='naive')
     parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt'])
+    parser.add_argument('--pretrain', type=str, default=None)
     parser.add_argument('--need_optim_ckpt', type=bool, default=False)
     parser.add_argument('--num_episodes', type=int, default=10)
     parser.add_argument('--max_timesteps', type=int, default=10)
     parser.add_argument('--update_timesteps', type=int, default=10)
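
A note on the device change in the first hunk: once each worker process has bound its GPU, a bare .cuda() and the explicit .to(torch.cuda.current_device()) place parameters on the same device; the rewrite just makes the target explicit. A minimal sketch of that equivalence, assuming a CUDA-capable machine and a hypothetical local_rank of 0 (a real launcher such as torchrun supplies one per process):

import torch
import torch.nn as nn

# Hypothetical value for illustration; in a real multi-process run each
# process reads its own local rank from the launcher environment.
local_rank = 0
torch.cuda.set_device(local_rank)

# After set_device(), both forms land on the same GPU; the explicit
# form simply spells out the intended device.
model_a = nn.Linear(4, 4).cuda()                           # implicit: current device
model_b = nn.Linear(4, 4).to(torch.cuda.current_device())  # explicit: same device
assert model_a.weight.device == model_b.weight.device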
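
The deepcopy calls in the same hunk are what keep the reference and reward models frozen: each copy owns its own parameter tensors, so later PPO updates to the actor and critic never leak into them. A small self-contained illustration, using a generic nn.Linear as a stand-in for the actual actor class:

from copy import deepcopy

import torch
import torch.nn as nn

# Stand-in for the actor; the real classes come from the repository.
actor = nn.Linear(4, 4)
initial_model = deepcopy(actor)  # independent parameter storage

# Simulate a training update touching only the live actor.
with torch.no_grad():
    actor.weight.add_(1.0)

# The frozen reference copy is unaffected.
assert not torch.equal(actor.weight, initial_model.weight)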
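
One caveat about a context line the diff leaves untouched: argparse's type=bool applies Python's bool() to the raw string, so --need_optim_ckpt False still parses as True (only the empty string is falsy). A sketch of the conventional workaround, keeping the argument name from the diff but switching to a store_true flag:

import argparse

parser = argparse.ArgumentParser()
# bool('False') == True, so type=bool only behaves for the default.
# A store_true flag is unambiguous: absent -> False, present -> True.
parser.add_argument('--need_optim_ckpt', action='store_true')

assert parser.parse_args(['--need_optim_ckpt']).need_optim_ckpt is True
assert parser.parse_args([]).need_optim_ckpt is False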