2023-03-28 12:25:36 +00:00
|
|
|
import argparse
|
|
|
|
from random import randint
|
|
|
|
|
|
|
|
import loralib as lora
|
|
|
|
import torch
|
|
|
|
from coati.dataset import HhRlhfDataset, RmStaticDataset
|
|
|
|
from coati.models import LogExpLoss, LogSigLoss
|
|
|
|
from coati.models.base import RewardModel
|
|
|
|
from coati.models.bloom import BLOOMRM
|
|
|
|
from coati.models.deberta import DebertaRM
|
|
|
|
from coati.models.gpt import GPTRM
|
|
|
|
from coati.models.llama import LlamaRM
|
|
|
|
from coati.models.opt import OPTRM
|
2023-04-03 02:11:03 +00:00
|
|
|
from coati.models.roberta import RoBERTaRM
|
2023-03-28 12:25:36 +00:00
|
|
|
from coati.trainer import RewardModelTrainer
|
|
|
|
from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
|
|
|
|
from coati.utils import prepare_llama_tokenizer_and_embedding
|
|
|
|
from datasets import load_dataset
|
|
|
|
from torch.optim import Adam
|
2023-04-03 02:11:03 +00:00
|
|
|
from transformers import AutoTokenizer, BloomTokenizerFast, DebertaV2Tokenizer, LlamaTokenizer, RobertaTokenizer
|
2023-03-28 12:25:36 +00:00
|
|
|
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
|
|
|
|
|
|
|
|
from colossalai.nn.optimizer import HybridAdam
|
|
|
|
|
|
|
|
|
|
|
|
def _sample_validation_indices(num_rows, divisor=5):
    """Return sorted, distinct row indices for a random validation subset.

    Draws ``num_rows // divisor`` indices uniformly at random *without*
    replacement from ``range(num_rows)``, so the validation subset never
    contains the same row twice.

    Args:
        num_rows: Number of rows in the split being sampled from.
        divisor: Ratio of the split size to the subset size; must be >= 1.

    Returns:
        A sorted list of unique indices (empty when ``num_rows < divisor``).
    """
    from random import sample  # local import: the module header only imports randint
    return sorted(sample(range(num_rows), num_rows // divisor))


def train(args):
    """Train a reward model on pairwise preference data and save the result.

    Builds the training strategy, reward model, tokenizer, optimizer, loss
    function and datasets selected by ``args``, runs
    ``RewardModelTrainer.fit()``, then saves the model (rank 0 only) and,
    optionally, the optimizer state (on every rank).

    Args:
        args: Parsed command-line namespace. Fields read: ``strategy``,
            ``model``, ``pretrain``, ``lora_rank``, ``model_path``,
            ``max_len``, ``loss_fn``, ``dataset``, ``subset``, ``test``,
            ``batch_size``, ``max_epochs``, ``save_path`` and
            ``need_optim_ckpt``.

    Raises:
        ValueError: If ``args.strategy``, ``args.model``, ``args.loss_fn`` or
            ``args.dataset`` names an unsupported option.
    """
    # configure strategy
    if args.strategy == 'naive':
        strategy = NaiveStrategy()
    elif args.strategy == 'ddp':
        strategy = DDPStrategy()
    elif args.strategy == 'colossalai_gemini':
        strategy = ColossalAIStrategy(stage=3, placement_policy='cuda')
    elif args.strategy == 'colossalai_zero2':
        strategy = ColossalAIStrategy(stage=2, placement_policy='cuda')
    else:
        raise ValueError(f'Unsupported strategy "{args.strategy}"')

    # configure model; every supported reward model is built with the same keyword arguments
    reward_model_classes = {
        'bloom': BLOOMRM,
        'opt': OPTRM,
        'gpt2': GPTRM,
        'deberta': DebertaRM,
        'llama': LlamaRM,
        'roberta': RoBERTaRM,
    }
    with strategy.model_init_context():
        model_cls = reward_model_classes.get(args.model)
        if model_cls is None:
            raise ValueError(f'Unsupported model "{args.model}"')
        model = model_cls(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())

        if args.model_path is not None:
            # Deserialize onto CPU: without map_location, torch.load restores tensors onto the
            # device they were saved from (e.g. cuda:0 on every rank); load_state_dict then
            # copies the values into the model's own parameters.
            state_dict = torch.load(args.model_path, map_location='cpu')
            model.load_state_dict(state_dict)

    model = model.to(torch.float16)

    # configure tokenizer; factories are lazy so only the selected tokenizer is loaded
    tokenizer_factories = {
        'gpt2': lambda: GPT2Tokenizer.from_pretrained('gpt2'),
        'bloom': lambda: BloomTokenizerFast.from_pretrained('bigscience/bloom-560m'),
        'opt': lambda: AutoTokenizer.from_pretrained("facebook/opt-350m"),
        'deberta': lambda: DebertaV2Tokenizer.from_pretrained('microsoft/deberta-v3-large'),
        'llama': lambda: LlamaTokenizer.from_pretrained(args.pretrain),
        'roberta': lambda: RobertaTokenizer.from_pretrained("roberta-base"),
    }
    tokenizer_factory = tokenizer_factories.get(args.model)
    if tokenizer_factory is None:
        raise ValueError(f'Unsupported model "{args.model}"')
    tokenizer = tokenizer_factory()
    max_len = args.max_len

    if args.model == 'llama':
        # NOTE(review): presumably sets up padding and resizes the model embeddings to match;
        # confirm against coati.utils.prepare_llama_tokenizer_and_embedding.
        tokenizer = prepare_llama_tokenizer_and_embedding(tokenizer, model)
    else:
        # all other tokenizers pad with their end-of-sequence token
        tokenizer.pad_token = tokenizer.eos_token

    # configure optimizer: ColossalAI strategies use HybridAdam, the others plain torch Adam
    optim_cls = HybridAdam if args.strategy.startswith('colossalai') else Adam
    optim = optim_cls(model.parameters(), lr=5e-6)

    # configure loss function
    if args.loss_fn == 'log_sig':
        loss_fn = LogSigLoss()
    elif args.loss_fn == 'log_exp':
        loss_fn = LogExpLoss()
    else:
        raise ValueError(f'Unsupported loss function "{args.loss_fn}"')

    # prepare for data and dataset; validate the name before downloading anything
    dataset_classes = {
        'Dahoas/rm-static': RmStaticDataset,
        'Anthropic/hh-rlhf': HhRlhfDataset,
    }
    dataset_cls = dataset_classes.get(args.dataset)
    if dataset_cls is None:
        raise ValueError(f'Unsupported dataset "{args.dataset}"')

    # data_dir=None is load_dataset's own default, so this also covers the no-subset case
    data = load_dataset(args.dataset, data_dir=args.subset)

    if args.test:
        # small slices for a quick smoke run
        train_data = data['train'].select(range(100))
        eval_data = data['test'].select(range(10))
    else:
        train_data = data['train']
        eval_data = data['test']
    # validation set: a random fifth of the evaluation split, drawn without replacement
    valid_data = eval_data.select(_sample_validation_indices(len(eval_data)))

    train_dataset = dataset_cls(train_data, tokenizer, max_len)
    valid_dataset = dataset_cls(valid_data, tokenizer, max_len)
    eval_dataset = dataset_cls(eval_data, tokenizer, max_len)

    trainer = RewardModelTrainer(model=model,
                                 strategy=strategy,
                                 optim=optim,
                                 loss_fn=loss_fn,
                                 train_dataset=train_dataset,
                                 valid_dataset=valid_dataset,
                                 eval_dataset=eval_dataset,
                                 batch_size=args.batch_size,
                                 max_epochs=args.max_epochs)

    trainer.fit()
    # save model checkpoint after fitting on only rank0
    trainer.save_model(path=args.save_path, only_rank0=True, tokenizer=tokenizer)
    # save optimizer checkpoint on all ranks; the file name is keyed by the current CUDA device index
    if args.need_optim_ckpt:
        strategy.save_optimizer(trainer.optimizer,
                                f'rm_optim_checkpoint_{torch.cuda.current_device()}.pt',
                                only_rank0=False)
|
|
|
|
|
|
|
|
|
|
|
|
def _str_to_bool(value):
    """Parse a command-line boolean such as ``true``/``false``/``yes``/``no``/``1``/``0``.

    ``argparse`` with ``type=bool`` treats every non-empty string (including
    ``"False"``) as true, so boolean options need an explicit parser.

    Args:
        value: The raw option string (a ``bool`` is passed through unchanged).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean spelling.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


def build_parser():
    """Build the command-line parser for reward-model training.

    Returns:
        An ``argparse.ArgumentParser`` whose parsed namespace is what ``train`` expects.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--strategy',
                        choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'],
                        default='naive')
    parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt', 'deberta', 'llama', 'roberta'], default='bloom')
    parser.add_argument('--pretrain', type=str, default=None)
    parser.add_argument('--model_path', type=str, default=None)
    # Boolean options: a bare flag means True; an explicit value such as `False` or `0` is honoured.
    parser.add_argument('--need_optim_ckpt', type=_str_to_bool, nargs='?', const=True, default=False)
    parser.add_argument('--dataset',
                        type=str,
                        choices=['Anthropic/hh-rlhf', 'Dahoas/rm-static'],
                        default='Dahoas/rm-static')
    parser.add_argument('--subset', type=str, default=None)
    parser.add_argument('--save_path', type=str, default='rm_ckpt')
    parser.add_argument('--max_epochs', type=int, default=1)
    parser.add_argument('--batch_size', type=int, default=1)
    parser.add_argument('--max_len', type=int, default=512)
    parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
    parser.add_argument('--loss_fn', type=str, default='log_sig', choices=['log_sig', 'log_exp'])
    parser.add_argument('--test', type=_str_to_bool, nargs='?', const=True, default=False)
    return parser


if __name__ == '__main__':
    args = build_parser().parse_args()
    train(args)
|