# ColossalAI/applications/ChatGPT/examples/train_reward_model.py
#
# 101 lines
# 4.2 KiB
# Python

import argparse
import loralib as lora
import torch
from chatgpt.dataset import RewardDataset
from chatgpt.nn import BLOOMRM, GPTRM, OPTRM
from chatgpt.trainer import RewardModelTrainer
from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
from datasets import load_dataset
from torch.optim import Adam
from transformers import AutoTokenizer, BloomTokenizerFast
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
from colossalai.nn.optimizer import HybridAdam
def train(args):
    """Train a reward model and write model and optimizer checkpoints.

    The function builds a training strategy, a reward model, a tokenizer and
    an optimizer from ``args``, fits the model with ``RewardModelTrainer`` and
    finally saves the model and the optimizer state.

    Args:
        args: Namespace-like object providing ``strategy``, ``model``,
            ``pretrain``, ``dataset``, ``batch_size``, ``max_epochs`` and
            ``lora_rank``. ``save_path`` is optional; when the attribute is
            missing the model checkpoint is written to ``rm_checkpoint.pt``.

    Raises:
        ValueError: If ``args.strategy`` or ``args.model`` is not one of the
            supported names.
    """
    # configure strategy
    if args.strategy == 'naive':
        strategy = NaiveStrategy()
    elif args.strategy == 'ddp':
        strategy = DDPStrategy()
    elif args.strategy == 'colossalai_gemini':
        strategy = ColossalAIStrategy(stage=3, placement_policy='cuda')
    elif args.strategy == 'colossalai_zero2':
        strategy = ColossalAIStrategy(stage=2, placement_policy='cuda')
    else:
        raise ValueError(f'Unsupported strategy "{args.strategy}"')

    # configure model (constructed inside the strategy's init context)
    with strategy.model_init_context():
        if args.model == 'bloom':
            model = BLOOMRM(pretrained=args.pretrain, lora_rank=args.lora_rank).cuda()
        elif args.model == 'opt':
            model = OPTRM(pretrained=args.pretrain, lora_rank=args.lora_rank).cuda()
        elif args.model == 'gpt2':
            model = GPTRM(pretrained=args.pretrain, lora_rank=args.lora_rank).cuda()
        else:
            raise ValueError(f'Unsupported model "{args.model}"')

    # configure tokenizer
    # NOTE(review): gpt2 and opt always load a fixed tokenizer name, while
    # bloom loads it from args.pretrain -- presumably --pretrain must be given
    # for bloom; confirm against the intended usage.
    if args.model == 'gpt2':
        tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
    elif args.model == 'bloom':
        tokenizer = BloomTokenizerFast.from_pretrained(args.pretrain)
    elif args.model == 'opt':
        tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
    else:
        raise ValueError(f'Unsupported model "{args.model}"')
    # Every supported tokenizer pads with its end-of-sequence token, so the
    # assignment is done once here instead of being repeated per branch.
    tokenizer.pad_token = tokenizer.eos_token
    max_len = 512

    # configure optimizer: HybridAdam for the ColossalAI strategies, plain
    # torch Adam otherwise
    if args.strategy.startswith('colossalai'):
        optim = HybridAdam(model.parameters(), lr=5e-5)
    else:
        optim = Adam(model.parameters(), lr=5e-5)

    # prepare data and datasets; only the first 100 training rows and the
    # first 5 test rows are used
    data = load_dataset(args.dataset)
    train_data = data["train"].select(range(100))
    eval_data = data['test'].select(range(5))
    train_dataset = RewardDataset(train_data, tokenizer, max_len)
    eval_dataset = RewardDataset(eval_data, tokenizer, max_len)

    # batch_size is expected to be C(k, 2), where k is the number of responses
    # per prompt. Because of the format of the 'Dahoas/rm-static' dataset, a
    # batch_size of 1 is the better choice there.
    trainer = RewardModelTrainer(model=model,
                                 strategy=strategy,
                                 optim=optim,
                                 train_dataset=train_dataset,
                                 eval_dataset=eval_dataset,
                                 batch_size=args.batch_size,
                                 max_epochs=args.max_epochs)
    trainer.fit(use_lora=args.lora_rank)

    # save model checkpoint after fitting, on rank 0 only. The destination
    # comes from args.save_path (previously this option was ignored); the old
    # hard-coded file name is kept as the fallback for callers whose args
    # object has no save_path attribute.
    save_path = getattr(args, 'save_path', 'rm_checkpoint.pt')
    strategy.save_model(model, save_path, only_rank0=True)

    # save optimizer checkpoint on all ranks, one file per CUDA device index
    optim_path = f'rm_optim_checkpoint_{torch.cuda.current_device()}.pt'
    strategy.save_optimizer(optim, optim_path, only_rank0=False)
if __name__ == '__main__':
    # Script entry point: declare the command-line options as
    # (flag, add_argument keyword arguments) pairs, register them in order,
    # then hand the parsed namespace to train().
    cli_options = (
        ('--strategy', {
            'choices': ['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'],
            'default': 'naive',
        }),
        ('--model', {'choices': ['gpt2', 'bloom', 'opt'], 'default': 'bloom'}),
        ('--pretrain', {'type': str, 'default': None}),
        ('--dataset', {'type': str, 'default': 'Dahoas/rm-static'}),
        ('--save_path', {'type': str, 'default': 'rm_ckpt.pth'}),
        ('--max_epochs', {'type': int, 'default': 10}),
        ('--batch_size', {'type': int, 'default': 4}),
        ('--lora_rank', {
            'type': int,
            'default': 0,
            'help': "low-rank adaptation matrices rank",
        }),
    )

    arg_parser = argparse.ArgumentParser()
    for flag, settings in cli_options:
        arg_parser.add_argument(flag, **settings)

    args = arg_parser.parse_args()
    train(args)