# ColossalAI/applications/Chat/coati/trainer/sft.py
# (source listing metadata: 136 lines, 5.5 KiB, Python)

import math
import time
from typing import List, Optional
import torch
import torch.distributed as dist
import wandb
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from transformers.trainer import get_scheduler
from .base import Trainer
from .callbacks import Callback
from .strategies import ColossalAIStrategy, Strategy
from .utils import is_rank_0, to_device
class SFTTrainer(Trainer):
    """
    Trainer for supervised fine-tuning (SFT) of a language model.

    Every batch is expected to be a mapping with ``input_ids``, ``attention_mask`` and
    ``labels`` entries. The model is called with them and the ``.loss`` attribute of its
    output is the training objective. Gradients are accumulated over ``accumulation_steps``
    micro-batches before each optimizer/scheduler step. The learning rate follows a
    ``"cosine"`` schedule (via ``transformers.get_scheduler``) with warmup over the first
    3% of optimizer steps.

    Args:
        model (torch.nn.Module): the model to train
        strategy (Strategy): the strategy to use for training
        optim (Optimizer): the optimizer to use for training
        train_dataloader (DataLoader): the dataloader to use for training
        eval_dataloader (DataLoader, optional): the dataloader used for evaluation after
            every epoch; evaluation is skipped when ``None``
        max_epochs (int, defaults to 2): the number of epochs to train
        accumulation_steps (int, defaults to 8): number of micro-batches whose gradients
            are accumulated before each optimizer step
        callbacks (List[Callback], optional): the callbacks passed to the base trainer;
            defaults to a new empty list

    Raises:
        ValueError: if ``accumulation_steps < 1``, or if ``accumulation_steps > 1`` is
            combined with a ``ColossalAIStrategy`` whose ``stage`` is 3.
    """

    def __init__(
        self,
        model,
        strategy: Strategy,
        optim: Optimizer,
        train_dataloader: DataLoader,
        eval_dataloader: Optional[DataLoader] = None,
        max_epochs: int = 2,
        accumulation_steps: int = 8,
        callbacks: Optional[List[Callback]] = None,
    ) -> None:
        if accumulation_steps < 1:
            raise ValueError(f"accumulation_steps must be >= 1, got {accumulation_steps}")
        if accumulation_steps > 1 and isinstance(strategy, ColossalAIStrategy) and strategy.stage == 3:
            raise ValueError("Accumulation steps are not supported in stage 3 of ColossalAI")
        # Build a fresh list per trainer: a shared `[]` default would be mutated across instances.
        if callbacks is None:
            callbacks = []
        super().__init__(strategy, max_epochs, callbacks=callbacks)
        self.train_dataloader = train_dataloader
        self.eval_dataloader = eval_dataloader
        self.model = model
        self.optimizer = optim
        self.accumulation_steps = accumulation_steps

        # Only complete accumulation windows trigger an optimizer step (see `_train_epoch`),
        # so the number of updates per epoch is rounded down.
        # NOTE: if the dataloader has fewer batches than `accumulation_steps`, this is 0 and
        # no optimizer step is ever taken.
        num_update_steps_per_epoch = len(train_dataloader) // self.accumulation_steps
        max_steps = math.ceil(self.max_epochs * num_update_steps_per_epoch)
        self.scheduler = get_scheduler("cosine",
                                       self.optimizer,
                                       num_warmup_steps=math.ceil(max_steps * 0.03),
                                       num_training_steps=max_steps)

    def fit(self, logger, use_wandb: bool = False):
        """
        Train for ``max_epochs`` epochs, evaluating after each epoch if an eval dataloader is set.

        Args:
            logger: receives abnormal-loss warnings and evaluation results via its
                ``warning``/``info`` methods (presumably a ``logging.Logger``; confirm at caller)
            use_wandb (bool, defaults to False): if True, rank 0 starts a wandb run, watches
                the model and logs loss/lr after every optimizer step
        """
        # Only rank 0 logs to wandb, so only rank 0 creates a run; initializing on every rank
        # would create one extra (empty) run per process.
        if use_wandb and is_rank_0():
            wandb.init(project="Coati", name=time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
            wandb.watch(self.model)

        # One tick per optimizer step, i.e. per complete accumulation window.
        step_bar = tqdm(range(len(self.train_dataloader) // self.accumulation_steps * self.max_epochs),
                        desc='steps',
                        disable=not is_rank_0())
        try:
            for epoch in range(self.max_epochs):
                self._train_epoch(epoch, logger, use_wandb, step_bar)
                if self.eval_dataloader is not None:
                    self._evaluate_epoch(epoch, logger)
        finally:
            step_bar.close()

    def _train_epoch(self, epoch: int, logger, use_wandb: bool, step_bar: tqdm) -> None:
        """Run one training epoch, stepping the optimizer once per complete accumulation window."""
        self.model.train()
        # Drop gradients left over from a trailing partial window of the previous epoch.
        # Otherwise they would be folded into this epoch's first update, which would then sum
        # more than `accumulation_steps` micro-batches.
        self.optimizer.zero_grad()
        # Sum of the scaled micro-batch losses in the current window, i.e. the window mean.
        total_loss = 0
        for batch_id, batch in enumerate(self.train_dataloader):
            batch = to_device(batch, torch.cuda.current_device())
            outputs = self.model(batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"])
            loss = outputs.loss
            # Hard-coded heuristic threshold for flagging suspicious batches.
            if loss >= 2.5 and is_rank_0():
                logger.warning(f"batch_id:{batch_id}, abnormal loss: {loss}")

            # Scale so the accumulated gradient is the mean over the window.
            loss = loss / self.accumulation_steps
            self.strategy.backward(loss, self.model, self.optimizer)
            total_loss += loss.item()

            if (batch_id + 1) % self.accumulation_steps == 0:
                self.strategy.optimizer_step(self.optimizer)
                self.optimizer.zero_grad()
                self.scheduler.step()
                if is_rank_0() and use_wandb:
                    wandb.log({
                        # Already the window mean: each term was divided by accumulation_steps.
                        "loss": total_loss,
                        "lr": self.scheduler.get_last_lr()[0],
                        "epoch": epoch,
                        "batch_id": batch_id
                    })
                total_loss = 0
                step_bar.update()

    def _evaluate_epoch(self, epoch: int, logger) -> None:
        """Compute the sample-weighted mean eval loss for one epoch and log it on rank 0."""
        self.model.eval()
        loss_sum = 0.0
        num_seen = 0
        with torch.no_grad():
            for batch in self.eval_dataloader:
                batch = to_device(batch, torch.cuda.current_device())
                outputs = self.model(batch["input_ids"],
                                     attention_mask=batch["attention_mask"],
                                     labels=batch["labels"])
                batch_size = batch["input_ids"].size(0)
                # `outputs.loss` is presumably a per-batch mean, so weight it by the batch size
                # before dividing by the total number of samples.
                loss_sum += outputs.loss.item() * batch_size
                num_seen += batch_size

        if num_seen == 0:
            if is_rank_0():
                logger.warning(f'Eval Epoch {epoch}/{self.max_epochs}: eval dataloader is empty')
            return
        # NOTE: this is this rank's local value; it is not all-reduced across ranks.
        loss_mean = loss_sum / num_seen
        if is_rank_0():
            logger.info(f'Eval Epoch {epoch}/{self.max_epochs} loss {loss_mean}')