@@ -36,45 +36,45 @@ def main(args):
     if args.rm_path is not None:
         state_dict = torch.load(args.rm_path, map_location='cpu')
 
-    # configure model
-    if args.model == 'gpt2':
-        initial_model = GPTActor(pretrained=args.pretrain)
-    elif args.model == 'bloom':
-        initial_model = BLOOMActor(pretrained=args.pretrain)
-    elif args.model == 'opt':
-        initial_model = OPTActor(pretrained=args.pretrain)
-    elif args.model == 'llama':
-        initial_model = LlamaActor(pretrained=args.pretrain)
-    elif args.model == 'roberta':
-        initial_model = RoBERTaActor(pretrained=args.pretrain)
-    else:
-        raise ValueError(f'Unsupported actor model "{args.model}"')
+    with strategy.model_init_context():
+        # configure model
+        if args.model == 'gpt2':
+            initial_model = GPTActor(pretrained=args.pretrain)
+        elif args.model == 'bloom':
+            initial_model = BLOOMActor(pretrained=args.pretrain)
+        elif args.model == 'opt':
+            initial_model = OPTActor(pretrained=args.pretrain)
+        elif args.model == 'llama':
+            initial_model = LlamaActor(pretrained=args.pretrain)
+        elif args.model == 'roberta':
+            initial_model = RoBERTaActor(pretrained=args.pretrain)
+        else:
+            raise ValueError(f'Unsupported actor model "{args.model}"')
 
-    if args.rm_model == None:
-        rm_model_name = args.model
-    else:
-        rm_model_name = args.rm_model
+        if args.rm_model == None:
+            rm_model_name = args.model
+        else:
+            rm_model_name = args.rm_model
 
-    if rm_model_name == 'gpt2':
-        reward_model = GPTRM(pretrained=args.rm_pretrain)
-    elif rm_model_name == 'bloom':
-        reward_model = BLOOMRM(pretrained=args.rm_pretrain)
-    elif rm_model_name == 'opt':
-        reward_model = OPTRM(pretrained=args.rm_pretrain)
-    elif rm_model_name == 'llama':
-        reward_model = LlamaRM(pretrained=args.rm_pretrain)
-    elif rm_model_name == 'roberta':
-        reward_model = RoBERTaRM(pretrained=args.rm_pretrain)
-    else:
-        raise ValueError(f'Unsupported reward model "{rm_model_name}"')
+        if rm_model_name == 'gpt2':
+            reward_model = GPTRM(pretrained=args.rm_pretrain)
+        elif rm_model_name == 'bloom':
+            reward_model = BLOOMRM(pretrained=args.rm_pretrain)
+        elif rm_model_name == 'opt':
+            reward_model = OPTRM(pretrained=args.rm_pretrain)
+        elif rm_model_name == 'llama':
+            reward_model = LlamaRM(pretrained=args.rm_pretrain)
+        elif rm_model_name == 'roberta':
+            reward_model = RoBERTaRM(pretrained=args.rm_pretrain)
+        else:
+            raise ValueError(f'Unsupported reward model "{rm_model_name}"')
 
-    if args.rm_path is not None:
-        reward_model.load_state_dict(state_dict)
+        if args.rm_path is not None:
+            reward_model.load_state_dict(state_dict)
 
-    initial_model.to(torch.float16).to(torch.cuda.current_device())
-    reward_model.to(torch.float16).to(torch.cuda.current_device())
+        initial_model.to(torch.float16).to(torch.cuda.current_device())
+        reward_model.to(torch.float16).to(torch.cuda.current_device())
 
     with strategy.model_init_context():
         if args.model == 'gpt2':
             actor = GPTActor(pretrained=args.pretrain, lora_rank=args.lora_rank)
         elif args.model == 'bloom':