ColossalAI/applications/ChatGPT/chatgpt/trainer/rm.py

78 lines
3.3 KiB
Python

from abc import ABC
from typing import Optional

import loralib as lora
import torch
from torch.optim import Adam
from torch.utils.data import DataLoader
from tqdm import tqdm

from chatgpt.dataset import RewardDataset
from chatgpt.nn import PairWiseLoss

class RewardModelTrainer(ABC):
    """
    Trainer to use while training reward model.

    Each epoch runs one optimization pass over ``train_dataset`` using
    ``PairWiseLoss`` on the rewards of the chosen vs. rejected sequences, then
    one evaluation pass over ``eval_dataset`` that reports the mean reward
    margin (chosen reward minus rejected reward) over all evaluated elements.
    Batches are moved to the GPU with ``.cuda()``, so CUDA is required.

    Args:
        model (torch.nn.Module): the model to train
        train_dataset (RewardDataset): the dataset to use for training
        eval_dataset (RewardDataset): the dataset to use for evaluation
        batch_size (int, defaults to 1): the batch size while training
        num_epochs (int, defaults to 2): the number of epochs to train
        optim_kwargs (dict, defaults to {'lr':1e-4}): the kwargs to use while initializing optimizer
    """

    def __init__(self,
                 model,
                 train_dataset: RewardDataset,
                 eval_dataset: RewardDataset,
                 batch_size: int = 1,
                 num_epochs: int = 2,
                 optim_kwargs: Optional[dict] = None) -> None:
        super().__init__()
        # A fresh dict per instance instead of a shared mutable default.
        if optim_kwargs is None:
            optim_kwargs = {'lr': 1e-4}
        self.model = model
        self.train_dataloader = DataLoader(train_dataset, batch_size=batch_size)
        self.eval_dataloader = DataLoader(eval_dataset, batch_size=batch_size)
        self.loss_fn = PairWiseLoss()
        self.optimizer = Adam(self.model.parameters(), **optim_kwargs)
        self.epochs = num_epochs

    def fit(self, use_lora) -> None:
        """
        Train for ``self.epochs`` epochs, evaluating after each epoch.

        Args:
            use_lora: when ``> 0``, restrict training to LoRA parameters via
                ``lora.mark_only_lora_as_trainable`` before the first epoch.
        """
        if use_lora > 0:
            # Only toggles which parameters are trainable, so once is enough.
            print("Using Lora")
            lora.mark_only_lora_as_trainable(self.model)

        epoch_bar = tqdm(range(self.epochs), desc='Train epoch')
        for epoch in range(self.epochs):
            step_bar = tqdm(range(len(self.train_dataloader)), desc='Train step of epoch %d' % epoch)
            last_loss = self._train_epoch(step_bar)
            dist_mean = self._evaluate()
            epoch_bar.update()

            # Either value may be missing when the corresponding dataset is empty.
            postfix = {}
            if last_loss is not None:
                postfix['loss'] = last_loss
            if dist_mean is not None:
                postfix['dist_mean'] = dist_mean
            step_bar.set_postfix(postfix)
            step_bar.close()
        epoch_bar.close()

    def _score_pair(self, chosen_ids, c_mask, reject_ids, r_mask):
        """Move one batch to the GPU and return (chosen_reward, reject_reward)."""
        # NOTE(review): assumes each tensor has a singleton dim 1 (e.g. a
        # (batch, 1, seq) tokenizer output) that squeeze(1) removes — confirm
        # against RewardDataset.
        chosen_ids = chosen_ids.squeeze(1).cuda()
        c_mask = c_mask.squeeze(1).cuda()
        reject_ids = reject_ids.squeeze(1).cuda()
        r_mask = r_mask.squeeze(1).cuda()
        chosen_reward = self.model(chosen_ids, attention_mask=c_mask)
        reject_reward = self.model(reject_ids, attention_mask=r_mask)
        return chosen_reward, reject_reward

    def _train_epoch(self, step_bar) -> Optional[float]:
        """Run one optimization pass; return the last batch loss, or None if there were no batches."""
        # Always re-enter train mode: _evaluate() leaves the model in eval mode,
        # and the original skipped train() entirely on the LoRA path.
        self.model.train()
        last_loss = None
        for chosen_ids, c_mask, reject_ids, r_mask in self.train_dataloader:
            chosen_reward, reject_reward = self._score_pair(chosen_ids, c_mask, reject_ids, r_mask)
            loss = self.loss_fn(chosen_reward, reject_reward)
            loss.backward()
            self.optimizer.step()
            self.optimizer.zero_grad()
            last_loss = loss.item()
            step_bar.update()
            step_bar.set_postfix({'loss': last_loss})
        return last_loss

    def _evaluate(self) -> Optional[float]:
        """Return the mean (chosen - rejected) reward margin over the eval set, or None if it is empty."""
        self.model.eval()
        margin_sum = 0.0
        margin_count = 0
        # No gradients are needed for evaluation; avoid building autograd graphs.
        with torch.no_grad():
            for chosen_ids, c_mask, reject_ids, r_mask in self.eval_dataloader:
                chosen_reward, reject_reward = self._score_pair(chosen_ids, c_mask, reject_ids, r_mask)
                margin = chosen_reward - reject_reward
                # Accumulate across ALL batches (the original reset its sum per batch)
                # and reduce to a scalar so batch_size > 1 works.
                margin_sum += margin.sum().item()
                margin_count += margin.numel()
        if margin_count == 0:
            return None
        return margin_sum / margin_count