[chatgpt] support colossalai strategy to train rm (#2742)

* [chatgpt]fix train_rm bug with lora

* [chatgpt]support colossalai strategy to train rm

* fix pre-commit

* fix pre-commit 2
pull/2744/head
BlueRum 2023-02-16 11:24:07 +08:00 committed by GitHub
parent 648183a960
commit 613efebc5c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 66 additions and 29 deletions

View File

@ -3,10 +3,13 @@ from abc import ABC
import loralib as lora
from chatgpt.dataset import RewardDataset
from chatgpt.nn import PairWiseLoss
from torch.optim import Adam
from torch.optim import Adam, Optimizer
from torch.utils.data import DataLoader
from tqdm import tqdm
from .strategies import Strategy
from .utils import is_rank_0
class RewardModelTrainer(ABC):
"""
@ -14,32 +17,41 @@ class RewardModelTrainer(ABC):
Args:
model (torch.nn.Module): the model to train
strategy (Strategy): the strategy to use for training
optim(Optimizer): the optimizer to use for training
train_dataset (RewardDataset): the dataset to use for training
eval_dataset (RewardDataset): the dataset to use for evaluation
batch_size (int, defaults to 1): the batch size while training
num_epochs (int, defaults to 2): the number of epochs to train
max_epochs (int, defaults to 2): the number of epochs to train
optim_kwargs (dict, defaults to {'lr':1e-4}): the kwargs to use while initializing optimizer
"""
def __init__(self,
model,
train_dataset: RewardDataset,
eval_dataset: RewardDataset,
batch_size: int = 1,
num_epochs: int = 2,
optim_kwargs: dict = {'lr': 1e-4}) -> None:
def __init__(
self,
model,
strategy: Strategy,
optim: Optimizer,
train_dataset: RewardDataset,
eval_dataset: RewardDataset,
batch_size: int = 1,
max_epochs: int = 2,
) -> None:
super().__init__()
self.model = model
self.strategy = strategy
self.epochs = max_epochs
self.train_dataloader = DataLoader(train_dataset, batch_size=batch_size)
self.eval_dataloader = DataLoader(eval_dataset, batch_size=batch_size)
self.model = strategy.setup_model(model)
self.loss_fn = PairWiseLoss()
self.optimizer = Adam(self.model.parameters(), **optim_kwargs)
self.epochs = num_epochs
self.optimizer = strategy.setup_optimizer(optim, self.model)
def fit(self, use_lora):
epoch_bar = tqdm(range(self.epochs), desc='Train epoch')
epoch_bar = tqdm(range(self.epochs), desc='Train epoch', disable=not is_rank_0())
for epoch in range(self.epochs):
step_bar = tqdm(range(self.train_dataloader.__len__()), desc='Train step of epoch %d' % epoch)
step_bar = tqdm(range(self.train_dataloader.__len__()),
desc='Train step of epoch %d' % epoch,
disable=not is_rank_0())
# train
if use_lora > 0:
print("Using Lora")
@ -54,8 +66,8 @@ class RewardModelTrainer(ABC):
chosen_reward = self.model(chosen_ids, attention_mask=c_mask)
reject_reward = self.model(reject_ids, attention_mask=r_mask)
loss = self.loss_fn(chosen_reward, reject_reward)
loss.backward()
self.optimizer.step()
self.strategy.backward(loss, self.model, self.optimizer)
self.strategy.optimizer_step(self.optimizer)
self.optimizer.zero_grad()
step_bar.update()
step_bar.set_postfix({'loss': loss.item()})

View File

@ -13,6 +13,6 @@ set_n_least_used_CUDA_VISIBLE_DEVICES() {
echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
}
set_n_least_used_CUDA_VISIBLE_DEVICES 1
set_n_least_used_CUDA_VISIBLE_DEVICES 2
python train_dummy.py --model bloom --pretrain '/data2/users/lczht/bloom-560m' --lora_rank 16
torchrun --standalone --nproc_per_node=2 train_dummy.py --strategy colossalai_zero2

View File

@ -13,6 +13,6 @@ set_n_least_used_CUDA_VISIBLE_DEVICES() {
echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
}
set_n_least_used_CUDA_VISIBLE_DEVICES 1
set_n_least_used_CUDA_VISIBLE_DEVICES 2
python train_prompts.py prompts.csv --pretrain '/data2/users/lczht/bloom-560m' --lora_rank 16
torchrun --standalone --nproc_per_node=2 train_prompts.py prompts.csv --strategy colossalai_zero2

View File

@ -5,33 +5,55 @@ import torch
from chatgpt.dataset import RewardDataset
from chatgpt.nn import BLOOMRM
from chatgpt.trainer import RewardModelTrainer
from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
from datasets import load_dataset
from torch.optim import Adam
from transformers import BloomTokenizerFast
from colossalai.nn.optimizer import HybridAdam
def train(args):
# configure strategy
if args.strategy == 'naive':
strategy = NaiveStrategy()
elif args.strategy == 'ddp':
strategy = DDPStrategy()
elif args.strategy == 'colossalai_gemini':
strategy = ColossalAIStrategy(stage=3, placement_policy='cuda')
elif args.strategy == 'colossalai_zero2':
strategy = ColossalAIStrategy(stage=2, placement_policy='cuda')
else:
raise ValueError(f'Unsupported strategy "{args.strategy}"')
# configure model
tokenizer = BloomTokenizerFast.from_pretrained(args.pretrain)
tokenizer.pad_token = tokenizer.eos_token
model = BLOOMRM(pretrained=args.pretrain)
model.cuda()
model = BLOOMRM(pretrained=args.pretrain).cuda()
max_len = 1024
# configure optimizer
if args.strategy.startswith('colossalai'):
optim = HybridAdam(model.parameters(), lr=5e-5)
else:
optim = Adam(model.parameters(), lr=5e-5)
# prepare for data and dataset
data = load_dataset(args.dataset)
train_data = data["train"]
eval_data = data['test']
train_data = data["train"].select(range(100))
eval_data = data['test'].select(range(5))
train_dataset = RewardDataset(train_data, tokenizer, max_len)
eval_dataset = RewardDataset(eval_data, tokenizer, max_len)
# batch_size here is expected to be C(k,2), k means # response of each prompt
# be limited with the format of dataset 'Dahoas/rm-static', we'd better use batch_size as 1
trainer = RewardModelTrainer(model=model,
strategy=strategy,
optim=optim,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
batch_size=args.batch_size,
num_epochs=args.max_epochs)
max_epochs=args.max_epochs)
trainer.fit(use_lora=args.lora_rank)
@ -43,6 +65,9 @@ def train(args):
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--strategy',
choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'],
default='naive')
parser.add_argument('--pretrain', type=str, default=None)
parser.add_argument('--dataset', type=str, default='Dahoas/rm-static')
parser.add_argument('--save_path', type=str, default='rm_ckpt.pth')

View File

@ -13,6 +13,6 @@ set_n_least_used_CUDA_VISIBLE_DEVICES() {
echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
}
set_n_least_used_CUDA_VISIBLE_DEVICES 1
set_n_least_used_CUDA_VISIBLE_DEVICES 2
python train_reward_model.py --pretrain '/data2/users/lczht/bloom-560m' --lora_rank 16
torchrun --standalone --nproc_per_node=2 train_reward_model.py --pretrain '/data2/users/lczht/bloom-560m' --strategy colossalai_zero2