[chatgpt] Fix examples (#3116)

* fix train_dummy

* fix train-prompts
commit 68577fbc43, parent 0672b5afac
Author: BlueRum (committed by GitHub)

train_dummy.py
@@ -38,19 +38,19 @@ def main(args):
     # configure model
     with strategy.model_init_context():
         if args.model == 'gpt2':
-            actor = GPTActor().cuda()
-            critic = GPTCritic().cuda()
+            actor = GPTActor(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
+            critic = GPTCritic(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
         elif args.model == 'bloom':
-            actor = BLOOMActor(pretrained=args.pretrain, lora_rank=args.lora_rank).cuda()
-            critic = BLOOMCritic(pretrained=args.pretrain, lora_rank=args.lora_rank).cuda()
+            actor = BLOOMActor(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
+            critic = BLOOMCritic(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
         elif args.model == 'opt':
-            actor = OPTActor().cuda()
-            critic = OPTCritic().cuda()
+            actor = OPTActor(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
+            critic = OPTCritic(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
         else:
             raise ValueError(f'Unsupported model "{args.model}"')

-        initial_model = deepcopy(actor).cuda()
-        reward_model = RewardModel(deepcopy(critic.model), deepcopy(critic.value_head)).cuda()
+        initial_model = deepcopy(actor).to(torch.cuda.current_device())
+        reward_model = RewardModel(deepcopy(critic.model), deepcopy(critic.value_head)).to(torch.cuda.current_device())

     # configure optimizer
     if args.strategy.startswith('colossalai'):
@@ -114,9 +114,10 @@ def main(args):
                 max_timesteps=args.max_timesteps,
                 update_timesteps=args.update_timesteps)

-    # save model checkpoint after fitting on only rank0
+    # save model checkpoint after fitting
     strategy.save_model(actor, 'actor_checkpoint_dummy.pt', only_rank0=True)
     # save optimizer checkpoint on all ranks
-    strategy.save_optimizer(actor_optim,
-                            'actor_optim_checkpoint_dummy_%d.pt' % (torch.cuda.current_device()),
-                            only_rank0=False)
+    if args.need_optim_ckpt:
+        strategy.save_optimizer(actor_optim,
+                                'actor_optim_checkpoint_dummy_%d.pt' % (torch.cuda.current_device()),
+                                only_rank0=False)
@@ -129,6 +130,7 @@ if __name__ == '__main__':
                         default='naive')
     parser.add_argument('--model', type=str, default='gpt2', choices=['gpt2', 'bloom', 'opt'])
     parser.add_argument('--pretrain', type=str, default=None)
+    parser.add_argument('--need_optim_ckpt', type=bool, default=False)
     parser.add_argument('--num_episodes', type=int, default=50)
     parser.add_argument('--max_timesteps', type=int, default=10)
     parser.add_argument('--update_timesteps', type=int, default=10)
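
The recurring fix in both files: bare .cuda() becomes an explicit .to(torch.cuda.current_device()), and pretrained=args.pretrain / lora_rank=args.lora_rank now reach every model constructor, so GPT-2 and OPT get the same pretrained-weight and LoRA handling that BLOOM already had. A minimal sketch of the device-placement idea, not code from this patch (the LOCAL_RANK handling is an assumption about the launcher):

import os
import torch
import torch.nn as nn

# Each data-parallel rank pins one GPU before building any model.
local_rank = int(os.environ.get("LOCAL_RANK", 0))  # assumption: set by the launcher
torch.cuda.set_device(local_rank)

model = nn.Linear(16, 16)
# .cuda() and .to(torch.cuda.current_device()) both land on the selected GPU;
# the explicit form records which device this rank was given.
model = model.to(torch.cuda.current_device())
print(next(model.parameters()).device)  # cuda:0 on rank 0, cuda:1 on rank 1, ...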

train_prompts.py
@@ -32,19 +32,20 @@ def main(args):
     # configure model
     with strategy.model_init_context():
         if args.model == 'gpt2':
-            actor = GPTActor().cuda()
-            critic = GPTCritic().cuda()
+            actor = GPTActor(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
+            critic = GPTCritic(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
         elif args.model == 'bloom':
-            actor = BLOOMActor(pretrained=args.pretrain, lora_rank=args.lora_rank).cuda()
-            critic = BLOOMCritic(pretrained=args.pretrain, lora_rank=args.lora_rank).cuda()
+            actor = BLOOMActor(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
+            critic = BLOOMCritic(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
         elif args.model == 'opt':
-            actor = OPTActor(lora_rank=args.lora_rank).cuda()
-            critic = OPTCritic(lora_rank=args.lora_rank).cuda()
+            actor = OPTActor(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
+            critic = OPTCritic(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
         else:
             raise ValueError(f'Unsupported model "{args.model}"')

         initial_model = deepcopy(actor)
-        reward_model = RewardModel(deepcopy(critic.model), deepcopy(critic.value_head)).cuda()
+        reward_model = RewardModel(deepcopy(critic.model), deepcopy(critic.value_head)).to(torch.cuda.current_device())

     # configure optimizer
     if args.strategy.startswith('colossalai'):
@@ -100,9 +101,10 @@ def main(args):
                 num_episodes=args.num_episodes,
                 max_timesteps=args.max_timesteps,
                 update_timesteps=args.update_timesteps)
-    # save model checkpoint after fitting on only rank0
+    # save model checkpoint after fitting
     strategy.save_model(actor, 'actor_checkpoint_prompts.pt', only_rank0=True)
     # save optimizer checkpoint on all ranks
-    strategy.save_optimizer(actor_optim,
-                            'actor_optim_checkpoint_prompts_%d.pt' % (torch.cuda.current_device()),
-                            only_rank0=False)
+    if args.need_optim_ckpt:
+        strategy.save_optimizer(actor_optim,
+                                'actor_optim_checkpoint_prompts_%d.pt' % (torch.cuda.current_device()),
+                                only_rank0=False)
@@ -116,6 +118,7 @@ if __name__ == '__main__':
                         default='naive')
     parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt'])
     parser.add_argument('--pretrain', type=str, default=None)
+    parser.add_argument('--need_optim_ckpt', type=bool, default=False)
    parser.add_argument('--num_episodes', type=int, default=10)
     parser.add_argument('--max_timesteps', type=int, default=10)
     parser.add_argument('--update_timesteps', type=int, default=10)
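
One caveat on the new flag: argparse's type=bool applies bool() to the raw argument string, and bool of any non-empty string is True, so "--need_optim_ckpt False" would actually enable the checkpoint. The default of False still works when the flag is omitted. A short sketch of the pitfall and the usual store_true alternative (illustrative only, not part of the patch):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--need_optim_ckpt', type=bool, default=False)
# bool('False') == True: even the literal string "False" enables the flag.
print(parser.parse_args(['--need_optim_ckpt', 'False']).need_optim_ckpt)  # True

# The conventional alternative is an on/off switch:
safer = argparse.ArgumentParser()
safer.add_argument('--need_optim_ckpt', action='store_true')
print(safer.parse_args([]).need_optim_ckpt)                     # False
print(safer.parse_args(['--need_optim_ckpt']).need_optim_ckpt)  # True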
