diff --git a/applications/Chat/examples/train_prompts.py b/applications/Chat/examples/train_prompts.py
index a584991cd..134f21f80 100644
--- a/applications/Chat/examples/train_prompts.py
+++ b/applications/Chat/examples/train_prompts.py
@@ -36,45 +36,45 @@ def main(args):
     if args.rm_path is not None:
         state_dict = torch.load(args.rm_path, map_location='cpu')
 
-    # configure model
-    if args.model == 'gpt2':
-        initial_model = GPTActor(pretrained=args.pretrain)
-    elif args.model == 'bloom':
-        initial_model = BLOOMActor(pretrained=args.pretrain)
-    elif args.model == 'opt':
-        initial_model = OPTActor(pretrained=args.pretrain)
-    elif args.model == 'llama':
-        initial_model = LlamaActor(pretrained=args.pretrain)
-    elif args.model == 'roberta':
-        initial_model = RoBERTaActor(pretrained=args.pretrain)
-    else:
-        raise ValueError(f'Unsupported actor model "{args.model}"')
+    with strategy.model_init_context():
+        # configure model
+        if args.model == 'gpt2':
+            initial_model = GPTActor(pretrained=args.pretrain)
+        elif args.model == 'bloom':
+            initial_model = BLOOMActor(pretrained=args.pretrain)
+        elif args.model == 'opt':
+            initial_model = OPTActor(pretrained=args.pretrain)
+        elif args.model == 'llama':
+            initial_model = LlamaActor(pretrained=args.pretrain)
+        elif args.model == 'roberta':
+            initial_model = RoBERTaActor(pretrained=args.pretrain)
+        else:
+            raise ValueError(f'Unsupported actor model "{args.model}"')
 
-    if args.rm_model == None:
-        rm_model_name = args.model
-    else:
-        rm_model_name = args.rm_model
-
-    if rm_model_name == 'gpt2':
-        reward_model = GPTRM(pretrained=args.rm_pretrain)
-    elif rm_model_name == 'bloom':
-        reward_model = BLOOMRM(pretrained=args.rm_pretrain)
-    elif rm_model_name == 'opt':
-        reward_model = OPTRM(pretrained=args.rm_pretrain)
-    elif rm_model_name == 'llama':
-        reward_model = LlamaRM(pretrained=args.rm_pretrain)
-    elif rm_model_name == 'roberta':
-        reward_model = RoBERTaRM(pretrained=args.rm_pretrain)
-    else:
-        raise ValueError(f'Unsupported reward model "{rm_model_name}"')
+        if args.rm_model == None:
+            rm_model_name = args.model
+        else:
+            rm_model_name = args.rm_model
 
-    if args.rm_path is not None:
-        reward_model.load_state_dict(state_dict)
+        if rm_model_name == 'gpt2':
+            reward_model = GPTRM(pretrained=args.rm_pretrain)
+        elif rm_model_name == 'bloom':
+            reward_model = BLOOMRM(pretrained=args.rm_pretrain)
+        elif rm_model_name == 'opt':
+            reward_model = OPTRM(pretrained=args.rm_pretrain)
+        elif rm_model_name == 'llama':
+            reward_model = LlamaRM(pretrained=args.rm_pretrain)
+        elif rm_model_name == 'roberta':
+            reward_model = RoBERTaRM(pretrained=args.rm_pretrain)
+        else:
+            raise ValueError(f'Unsupported reward model "{rm_model_name}"')
 
-    initial_model.to(torch.float16).to(torch.cuda.current_device())
-    reward_model.to(torch.float16).to(torch.cuda.current_device())
+        if args.rm_path is not None:
+            reward_model.load_state_dict(state_dict)
+
+        initial_model.to(torch.float16).to(torch.cuda.current_device())
+        reward_model.to(torch.float16).to(torch.cuda.current_device())
 
-    with strategy.model_init_context():
         if args.model == 'gpt2':
             actor = GPTActor(pretrained=args.pretrain, lora_rank=args.lora_rank)
         elif args.model == 'bloom':
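
Note on the change: the diff moves construction of initial_model and reward_model inside strategy.model_init_context(), so the strategy can intercept parameter allocation while each module is being built, rather than receiving fully materialized models afterwards. The sketch below illustrates that pattern, assuming a ZeRO/offload-style use case; DemoStrategy is a hypothetical stand-in written for this note (only the model_init_context() method name comes from the diff), and it swaps the default dtype purely to make the context's effect observable.

    # Minimal sketch of the init-context pattern applied in the diff.
    # DemoStrategy is hypothetical; a real strategy would hook tensor
    # allocation here (e.g. shard or offload parameters as modules are
    # created) instead of changing the default dtype.
    from contextlib import contextmanager

    import torch
    import torch.nn as nn


    class DemoStrategy:
        @contextmanager
        def model_init_context(self):
            # Change how parameters are allocated for the duration of
            # the block, then restore the previous behavior.
            default_dtype = torch.get_default_dtype()
            torch.set_default_dtype(torch.float16)
            try:
                yield
            finally:
                torch.set_default_dtype(default_dtype)


    strategy = DemoStrategy()
    with strategy.model_init_context():
        # Every module built here inherits the context's allocation
        # policy, which is why the diff moves the model construction
        # (and the reward-model state_dict load) inside the block.
        initial_model = nn.Linear(4, 4)

    assert initial_model.weight.dtype == torch.float16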