From 68577fbc4399b0e8333ad958959ac09e5c54033d Mon Sep 17 00:00:00 2001 From: BlueRum <70618399+ht-zhou@users.noreply.github.com> Date: Mon, 13 Mar 2023 11:12:22 +0800 Subject: [PATCH] [chatgpt]Fix examples (#3116) * fix train_dummy * fix train-prompts --- applications/ChatGPT/examples/train_dummy.py | 26 ++++++++++--------- .../ChatGPT/examples/train_prompts.py | 25 ++++++++++-------- 2 files changed, 28 insertions(+), 23 deletions(-) diff --git a/applications/ChatGPT/examples/train_dummy.py b/applications/ChatGPT/examples/train_dummy.py index 27ee7f0f1..4c81f2f72 100644 --- a/applications/ChatGPT/examples/train_dummy.py +++ b/applications/ChatGPT/examples/train_dummy.py @@ -38,19 +38,19 @@ def main(args): # configure model with strategy.model_init_context(): if args.model == 'gpt2': - actor = GPTActor().cuda() - critic = GPTCritic().cuda() + actor = GPTActor(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device()) + critic = GPTCritic(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device()) elif args.model == 'bloom': - actor = BLOOMActor(pretrained=args.pretrain, lora_rank=args.lora_rank).cuda() - critic = BLOOMCritic(pretrained=args.pretrain, lora_rank=args.lora_rank).cuda() + actor = BLOOMActor(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device()) + critic = BLOOMCritic(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device()) elif args.model == 'opt': - actor = OPTActor().cuda() - critic = OPTCritic().cuda() + actor = OPTActor(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device()) + critic = OPTCritic(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device()) else: raise ValueError(f'Unsupported model "{args.model}"') - initial_model = deepcopy(actor).cuda() - reward_model = RewardModel(deepcopy(critic.model), deepcopy(critic.value_head)).cuda() + initial_model = deepcopy(actor).to(torch.cuda.current_device()) + reward_model = RewardModel(deepcopy(critic.model), deepcopy(critic.value_head)).to(torch.cuda.current_device()) # configure optimizer if args.strategy.startswith('colossalai'): @@ -114,12 +114,13 @@ def main(args): max_timesteps=args.max_timesteps, update_timesteps=args.update_timesteps) - # save model checkpoint after fitting on only rank0 + # save model checkpoint after fitting strategy.save_model(actor, 'actor_checkpoint_dummy.pt', only_rank0=True) # save optimizer checkpoint on all ranks - strategy.save_optimizer(actor_optim, - 'actor_optim_checkpoint_dummy_%d.pt' % (torch.cuda.current_device()), - only_rank0=False) + if args.need_optim_ckpt: + strategy.save_optimizer(actor_optim, + 'actor_optim_checkpoint_dummy_%d.pt' % (torch.cuda.current_device()), + only_rank0=False) if __name__ == '__main__': @@ -129,6 +130,7 @@ if __name__ == '__main__': default='naive') parser.add_argument('--model', type=str, default='gpt2', choices=['gpt2', 'bloom', 'opt']) parser.add_argument('--pretrain', type=str, default=None) + parser.add_argument('--need_optim_ckpt', type=bool, default=False) parser.add_argument('--num_episodes', type=int, default=50) parser.add_argument('--max_timesteps', type=int, default=10) parser.add_argument('--update_timesteps', type=int, default=10) diff --git a/applications/ChatGPT/examples/train_prompts.py b/applications/ChatGPT/examples/train_prompts.py index 576685234..49f0e2c4a 100644 --- a/applications/ChatGPT/examples/train_prompts.py +++ b/applications/ChatGPT/examples/train_prompts.py @@ -32,19 +32,20 @@ def main(args): # configure model with strategy.model_init_context(): if args.model == 'gpt2': - actor = GPTActor().cuda() - critic = GPTCritic().cuda() + actor = GPTActor(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device()) + critic = GPTCritic(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device()) elif args.model == 'bloom': - actor = BLOOMActor(pretrained=args.pretrain, lora_rank=args.lora_rank).cuda() - critic = BLOOMCritic(pretrained=args.pretrain, lora_rank=args.lora_rank).cuda() + actor = BLOOMActor(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device()) + critic = BLOOMCritic(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device()) elif args.model == 'opt': - actor = OPTActor(lora_rank=args.lora_rank).cuda() - critic = OPTCritic(lora_rank=args.lora_rank).cuda() + actor = OPTActor(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device()) + critic = OPTCritic(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device()) else: raise ValueError(f'Unsupported model "{args.model}"') initial_model = deepcopy(actor) - reward_model = RewardModel(deepcopy(critic.model), deepcopy(critic.value_head)).cuda() + reward_model = RewardModel(deepcopy(critic.model), deepcopy(critic.value_head)).to(torch.cuda.current_device()) + # configure optimizer if args.strategy.startswith('colossalai'): @@ -100,12 +101,13 @@ def main(args): num_episodes=args.num_episodes, max_timesteps=args.max_timesteps, update_timesteps=args.update_timesteps) - # save model checkpoint after fitting on only rank0 + # save model checkpoint after fitting strategy.save_model(actor, 'actor_checkpoint_prompts.pt', only_rank0=True) # save optimizer checkpoint on all ranks - strategy.save_optimizer(actor_optim, - 'actor_optim_checkpoint_prompts_%d.pt' % (torch.cuda.current_device()), - only_rank0=False) + if args.need_optim_ckpt: + strategy.save_optimizer(actor_optim, + 'actor_optim_checkpoint_prompts_%d.pt' % (torch.cuda.current_device()), + only_rank0=False) if __name__ == '__main__': @@ -116,6 +118,7 @@ if __name__ == '__main__': default='naive') parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt']) parser.add_argument('--pretrain', type=str, default=None) + parser.add_argument('--need_optim_ckpt', type=bool, default=False) parser.add_argument('--num_episodes', type=int, default=10) parser.add_argument('--max_timesteps', type=int, default=10) parser.add_argument('--update_timesteps', type=int, default=10)