[chatgpt] support colossalai strategy to train rm (#2742)

* [chatgpt]fix train_rm bug with lora

* [chatgpt]support colossalai strategy to train rm

* fix pre-commit

* fix pre-commit 2
Committed by BlueRum via GitHub
parent 648183a960
commit 613efebc5c

@@ -3,10 +3,13 @@ from abc import ABC
 import loralib as lora
 from chatgpt.dataset import RewardDataset
 from chatgpt.nn import PairWiseLoss
-from torch.optim import Adam
+from torch.optim import Adam, Optimizer
 from torch.utils.data import DataLoader
 from tqdm import tqdm
 
+from .strategies import Strategy
+from .utils import is_rank_0
+
 
 class RewardModelTrainer(ABC):
     """
@@ -14,32 +17,41 @@ class RewardModelTrainer(ABC):
 
     Args:
         model (torch.nn.Module): the model to train
+        strategy (Strategy): the strategy to use for training
+        optim(Optimizer): the optimizer to use for training
         train_dataset (RewardDataset): the dataset to use for training
         eval_dataset (RewardDataset): the dataset to use for evaluation
         batch_size (int, defaults to 1): the batch size while training
-        num_epochs (int, defaults to 2): the number of epochs to train
+        max_epochs (int, defaults to 2): the number of epochs to train
         optim_kwargs (dict, defaults to {'lr':1e-4}): the kwargs to use while initializing optimizer
     """
 
-    def __init__(self,
-                 model,
-                 train_dataset: RewardDataset,
-                 eval_dataset: RewardDataset,
-                 batch_size: int = 1,
-                 num_epochs: int = 2,
-                 optim_kwargs: dict = {'lr': 1e-4}) -> None:
+    def __init__(
+        self,
+        model,
+        strategy: Strategy,
+        optim: Optimizer,
+        train_dataset: RewardDataset,
+        eval_dataset: RewardDataset,
+        batch_size: int = 1,
+        max_epochs: int = 2,
+    ) -> None:
         super().__init__()
-        self.model = model
+        self.strategy = strategy
+        self.epochs = max_epochs
         self.train_dataloader = DataLoader(train_dataset, batch_size=batch_size)
         self.eval_dataloader = DataLoader(eval_dataset, batch_size=batch_size)
+
+        self.model = strategy.setup_model(model)
         self.loss_fn = PairWiseLoss()
-        self.optimizer = Adam(self.model.parameters(), **optim_kwargs)
-        self.epochs = num_epochs
+        self.optimizer = strategy.setup_optimizer(optim, self.model)
 
     def fit(self, use_lora):
-        epoch_bar = tqdm(range(self.epochs), desc='Train epoch')
+        epoch_bar = tqdm(range(self.epochs), desc='Train epoch', disable=not is_rank_0())
         for epoch in range(self.epochs):
-            step_bar = tqdm(range(self.train_dataloader.__len__()), desc='Train step of epoch %d' % epoch)
+            step_bar = tqdm(range(self.train_dataloader.__len__()),
+                            desc='Train step of epoch %d' % epoch,
+                            disable=not is_rank_0())
             # train
             if use_lora > 0:
                 print("Using Lora")
@@ -54,8 +66,8 @@ class RewardModelTrainer(ABC):
                 chosen_reward = self.model(chosen_ids, attention_mask=c_mask)
                 reject_reward = self.model(reject_ids, attention_mask=r_mask)
                 loss = self.loss_fn(chosen_reward, reject_reward)
-                loss.backward()
-                self.optimizer.step()
+                self.strategy.backward(loss, self.model, self.optimizer)
+                self.strategy.optimizer_step(self.optimizer)
                 self.optimizer.zero_grad()
                 step_bar.update()
                 step_bar.set_postfix({'loss': loss.item()})
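
The `Strategy` object and the `is_rank_0` helper are imported above, but their definitions are not part of this diff. Inferring only from how the trainer calls them, a naive single-process implementation of that interface might look roughly like the sketch below (illustrative, not the repository's actual code); the DDP and ColossalAI strategies would instead wrap the model and optimizer and override `backward`/`optimizer_step`.

```python
# Illustrative sketch inferred from the trainer code above, not taken from the repo.
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.optim import Optimizer


def is_rank_0() -> bool:
    # A non-distributed run counts as rank 0, so progress bars are still shown.
    return not dist.is_initialized() or dist.get_rank() == 0


class NaiveStrategySketch:
    """Pass-through strategy: single process, no wrapping."""

    def setup_model(self, model: nn.Module) -> nn.Module:
        return model  # DDP/ColossalAI variants would wrap the module here

    def setup_optimizer(self, optim: Optimizer, model: nn.Module) -> Optimizer:
        return optim  # ZeRO variants would shard optimizer state here

    def backward(self, loss: torch.Tensor, model: nn.Module, optim: Optimizer) -> None:
        loss.backward()  # mixed-precision/ZeRO variants hook or scale the loss here

    def optimizer_step(self, optim: Optimizer) -> None:
        optim.step()
```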

@@ -13,6 +13,6 @@ set_n_least_used_CUDA_VISIBLE_DEVICES() {
     echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
 }
 
-set_n_least_used_CUDA_VISIBLE_DEVICES 1
+set_n_least_used_CUDA_VISIBLE_DEVICES 2
 
-python train_dummy.py --model bloom --pretrain '/data2/users/lczht/bloom-560m' --lora_rank 16
+torchrun --standalone --nproc_per_node=2 train_dummy.py --strategy colossalai_zero2

@@ -13,6 +13,6 @@ set_n_least_used_CUDA_VISIBLE_DEVICES() {
     echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
 }
 
-set_n_least_used_CUDA_VISIBLE_DEVICES 1
+set_n_least_used_CUDA_VISIBLE_DEVICES 2
 
-python train_prompts.py prompts.csv --pretrain '/data2/users/lczht/bloom-560m' --lora_rank 16
+torchrun --standalone --nproc_per_node=2 train_prompts.py prompts.csv --strategy colossalai_zero2
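
The launcher scripts above switch from a plain `python` invocation to `torchrun` with two processes, one per visible GPU. `torchrun` exports `RANK`, `LOCAL_RANK`, and `WORLD_SIZE` (plus rendezvous variables under `--standalone`) for each worker, which is what the distributed strategies rely on during initialization. A minimal, repo-independent illustration of what each spawned process sees:

```python
# Minimal illustration of a torchrun-launched worker; not code from this PR.
# Run with: torchrun --standalone --nproc_per_node=2 this_file.py
import os

import torch
import torch.distributed as dist

if __name__ == '__main__':
    dist.init_process_group(backend='nccl')   # reads RANK/WORLD_SIZE set by torchrun
    local_rank = int(os.environ['LOCAL_RANK'])
    torch.cuda.set_device(local_rank)         # bind this process to its own GPU
    print(f'rank {dist.get_rank()} of {dist.get_world_size()} on GPU {local_rank}')
    dist.destroy_process_group()
```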

@@ -5,33 +5,55 @@ import torch
 from chatgpt.dataset import RewardDataset
 from chatgpt.nn import BLOOMRM
 from chatgpt.trainer import RewardModelTrainer
+from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
 from datasets import load_dataset
+from torch.optim import Adam
 from transformers import BloomTokenizerFast
 
+from colossalai.nn.optimizer import HybridAdam
+
 
 def train(args):
+    # configure strategy
+    if args.strategy == 'naive':
+        strategy = NaiveStrategy()
+    elif args.strategy == 'ddp':
+        strategy = DDPStrategy()
+    elif args.strategy == 'colossalai_gemini':
+        strategy = ColossalAIStrategy(stage=3, placement_policy='cuda')
+    elif args.strategy == 'colossalai_zero2':
+        strategy = ColossalAIStrategy(stage=2, placement_policy='cuda')
+    else:
+        raise ValueError(f'Unsupported strategy "{args.strategy}"')
+
+    # configure model
     tokenizer = BloomTokenizerFast.from_pretrained(args.pretrain)
     tokenizer.pad_token = tokenizer.eos_token
-    model = BLOOMRM(pretrained=args.pretrain)
-    model.cuda()
+    model = BLOOMRM(pretrained=args.pretrain).cuda()
     max_len = 1024
 
+    # configure optimizer
+    if args.strategy.startswith('colossalai'):
+        optim = HybridAdam(model.parameters(), lr=5e-5)
+    else:
+        optim = Adam(model.parameters(), lr=5e-5)
+
     # prepare for data and dataset
     data = load_dataset(args.dataset)
-    train_data = data["train"]
-    eval_data = data['test']
+    train_data = data["train"].select(range(100))
+    eval_data = data['test'].select(range(5))
     train_dataset = RewardDataset(train_data, tokenizer, max_len)
     eval_dataset = RewardDataset(eval_data, tokenizer, max_len)
 
     # batch_size here is expected to be C(k,2), k means # response of each prompt
     # be limited with the format of dataset 'Dahoas/rm-static', we'd better use batch_size as 1
     trainer = RewardModelTrainer(model=model,
+                                 strategy=strategy,
+                                 optim=optim,
                                  train_dataset=train_dataset,
                                  eval_dataset=eval_dataset,
                                  batch_size=args.batch_size,
-                                 num_epochs=args.max_epochs)
+                                 max_epochs=args.max_epochs)
 
     trainer.fit(use_lora=args.lora_rank)
@@ -43,6 +65,9 @@ def train(args):
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
+    parser.add_argument('--strategy',
+                        choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'],
+                        default='naive')
     parser.add_argument('--pretrain', type=str, default=None)
     parser.add_argument('--dataset', type=str, default='Dahoas/rm-static')
     parser.add_argument('--save_path', type=str, default='rm_ckpt.pth')
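
For a quick local check, the same trainer API can also be driven directly from Python without a distributed launcher, using the default `naive` strategy. The sketch below follows the constructor signature and dataset handling shown in this diff; the public `bigscience/bloom-560m` checkpoint name is substituted for the local path used above, and it assumes `NaiveStrategy` takes no constructor arguments (as the script's default suggests).

```python
# Hedged usage sketch based on the new RewardModelTrainer signature shown above.
from chatgpt.dataset import RewardDataset
from chatgpt.nn import BLOOMRM
from chatgpt.trainer import RewardModelTrainer
from chatgpt.trainer.strategies import NaiveStrategy
from datasets import load_dataset
from torch.optim import Adam
from transformers import BloomTokenizerFast

pretrain = 'bigscience/bloom-560m'  # assumption: any BLOOM checkpoint name/path works here
tokenizer = BloomTokenizerFast.from_pretrained(pretrain)
tokenizer.pad_token = tokenizer.eos_token
model = BLOOMRM(pretrained=pretrain).cuda()

data = load_dataset('Dahoas/rm-static')
train_dataset = RewardDataset(data['train'].select(range(100)), tokenizer, 1024)
eval_dataset = RewardDataset(data['test'].select(range(5)), tokenizer, 1024)

trainer = RewardModelTrainer(model=model,
                             strategy=NaiveStrategy(),
                             optim=Adam(model.parameters(), lr=5e-5),
                             train_dataset=train_dataset,
                             eval_dataset=eval_dataset,
                             batch_size=1,
                             max_epochs=1)
trainer.fit(use_lora=0)
```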

@@ -13,6 +13,6 @@ set_n_least_used_CUDA_VISIBLE_DEVICES() {
     echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
 }
 
-set_n_least_used_CUDA_VISIBLE_DEVICES 1
+set_n_least_used_CUDA_VISIBLE_DEVICES 2
 
-python train_reward_model.py --pretrain '/data2/users/lczht/bloom-560m' --lora_rank 16
+torchrun --standalone --nproc_per_node=2 train_reward_model.py --pretrain '/data2/users/lczht/bloom-560m' --strategy colossalai_zero2
