import argparse

import torch
from chatgpt.dataset import RewardDataset
from chatgpt.nn import BLOOMRM, GPTRM, OPTRM
from chatgpt.trainer import RewardModelTrainer
from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
from datasets import load_dataset
from torch.optim import Adam
from transformers import AutoTokenizer, BloomTokenizerFast
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer

from colossalai.nn.optimizer import HybridAdam


def train(args):
    # configure strategy
    if args.strategy == 'naive':
        strategy = NaiveStrategy()
    elif args.strategy == 'ddp':
        strategy = DDPStrategy()
    elif args.strategy == 'colossalai_gemini':
        strategy = ColossalAIStrategy(stage=3, placement_policy='cuda')
    elif args.strategy == 'colossalai_zero2':
        strategy = ColossalAIStrategy(stage=2, placement_policy='cuda')
    else:
        raise ValueError(f'Unsupported strategy "{args.strategy}"')

    # configure model
    with strategy.model_init_context():
        if args.model == 'bloom':
            model = BLOOMRM(pretrained=args.pretrain, lora_rank=args.lora_rank).cuda()
        elif args.model == 'opt':
            model = OPTRM(pretrained=args.pretrain, lora_rank=args.lora_rank).cuda()
        elif args.model == 'gpt2':
            model = GPTRM(pretrained=args.pretrain, lora_rank=args.lora_rank).cuda()
        else:
            raise ValueError(f'Unsupported model "{args.model}"')

    # configure tokenizer
    if args.model == 'gpt2':
        tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
    elif args.model == 'bloom':
        tokenizer = BloomTokenizerFast.from_pretrained(args.pretrain)
    elif args.model == 'opt':
        tokenizer = AutoTokenizer.from_pretrained('facebook/opt-350m')
    else:
        raise ValueError(f'Unsupported model "{args.model}"')
    # these tokenizers have no dedicated pad token, so reuse the EOS token for padding
    tokenizer.pad_token = tokenizer.eos_token

    max_len = 512

    # configure optimizer: HybridAdam is used with the ColossalAI strategies
    if args.strategy.startswith('colossalai'):
        optim = HybridAdam(model.parameters(), lr=5e-5)
    else:
        optim = Adam(model.parameters(), lr=5e-5)

    # prepare the dataset (small subsets are selected here for a quick demo run)
    data = load_dataset(args.dataset)
    train_data = data['train'].select(range(100))
    eval_data = data['test'].select(range(5))
    train_dataset = RewardDataset(train_data, tokenizer, max_len)
    eval_dataset = RewardDataset(eval_data, tokenizer, max_len)

    # batch_size is expected to be C(k, 2), where k is the number of responses per prompt.
    # Since 'Dahoas/rm-static' provides one chosen/rejected pair per prompt, a batch_size of 1 is recommended.
    trainer = RewardModelTrainer(model=model,
                                 strategy=strategy,
                                 optim=optim,
                                 train_dataset=train_dataset,
                                 eval_dataset=eval_dataset,
                                 batch_size=args.batch_size,
                                 max_epochs=args.max_epochs)

    trainer.fit(use_lora=args.lora_rank)

    # save the model checkpoint after fitting, on rank 0 only
    strategy.save_model(model, args.save_path, only_rank0=True)
    # save the optimizer checkpoint on all ranks
    strategy.save_optimizer(optim,
                            'rm_optim_checkpoint_%d.pt' % (torch.cuda.current_device()),
                            only_rank0=False)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--strategy',
                        choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'],
                        default='naive')
    parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt'], default='bloom')
    parser.add_argument('--pretrain', type=str, default=None)
    parser.add_argument('--dataset', type=str, default='Dahoas/rm-static')
    parser.add_argument('--save_path', type=str, default='rm_ckpt.pth')
    parser.add_argument('--max_epochs', type=int, default=10)
    parser.add_argument('--batch_size', type=int, default=4)
    parser.add_argument('--lora_rank', type=int, default=0, help='low-rank adaptation matrices rank')
    args = parser.parse_args()
    train(args)
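# Example invocations (a sketch, not part of the original script): the script filename,
# checkpoint name 'bigscience/bloom-560m', and GPU count below are assumptions; substitute
# any pretrained checkpoint compatible with the chosen --model architecture.
#
#   # single-GPU demo run with the naive strategy
#   python train_reward_model.py --model bloom --pretrain bigscience/bloom-560m \
#       --strategy naive --batch_size 1 --save_path rm_ckpt.pth
#
#   # multi-GPU run with a ColossalAI strategy, launched via torchrun
#   torchrun --standalone --nproc_per_node=2 train_reward_model.py \
#       --model bloom --pretrain bigscience/bloom-560m --strategy colossalai_zero2 --batch_size 1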