mirror of https://github.com/hpcaitech/ColossalAI
131 lines
4.8 KiB
Python
131 lines
4.8 KiB
Python
from typing import Optional
|
|
|
|
import torch
|
|
import torch.distributed as dist
|
|
import tqdm
|
|
from torch.optim import Optimizer
|
|
from torch.optim.lr_scheduler import _LRScheduler
|
|
from torch.utils.data import DataLoader
|
|
|
|
from colossalai.logging import DistributedLogger
|
|
|
|
from .base import SLTrainer
|
|
from .strategies import GeminiStrategy, Strategy
|
|
from .utils import is_rank_0, to_device
|
|
|
|
|
|
class SFTTrainer(SLTrainer):
    """
    Trainer for supervised fine-tuning (SFT).

    Every batch is indexed as a mapping with ``"input_ids"``, ``"attention_mask"`` and
    ``"labels"`` entries. The model is called with them and the ``.loss`` attribute of its
    output is optimized.

    Args:
        model (torch.nn.Module): the model to train
        strategy (Strategy): the strategy to use for training
        optim (Optimizer): the optimizer to use for training
        lr_scheduler (_LRScheduler): the lr scheduler to use for training; it is stepped once
            per optimizer step (not once per micro-batch)
        max_epochs (int, defaults to 2): the number of epochs to train
        accumulation_steps (int, defaults to 8): the number of micro-batches whose gradients are
            accumulated before each optimizer step. Must be >= 1; values > 1 are not supported
            together with ``GeminiStrategy``.
    """

    def __init__(
        self,
        model,
        strategy: Strategy,
        optim: Optimizer,
        lr_scheduler: _LRScheduler,
        max_epochs: int = 2,
        accumulation_steps: int = 8,
    ) -> None:
        # 0 would raise ZeroDivisionError inside ``_train``, and a negative value would flip the
        # sign of the scaled loss (gradient ascent); reject both up front.
        if accumulation_steps < 1:
            raise ValueError(f"accumulation_steps must be >= 1, got {accumulation_steps}")
        if accumulation_steps > 1:
            assert not isinstance(
                strategy, GeminiStrategy
            ), "Accumulation steps are not supported in stage 3 of ColossalAI"

        super().__init__(strategy, max_epochs, model, optim)

        self.accumulation_steps = accumulation_steps
        self.scheduler = lr_scheduler

        # Monotonic counters used as the x-axis of the TensorBoard scalars.
        self.num_train_step = 0
        self.num_eval_step = 0

    def _train(self, epoch: int):
        """Run one training epoch over ``self.train_dataloader``.

        Gradients of ``accumulation_steps`` consecutive micro-batches are accumulated before each
        optimizer/scheduler step. An incomplete trailing window (fewer than ``accumulation_steps``
        micro-batches) is discarded at the end of the epoch. Therefore each epoch performs exactly
        ``len(train_dataloader) // accumulation_steps`` optimizer steps, matching the progress-bar
        total.

        Args:
            epoch: zero-based epoch index (displayed as ``epoch + 1``)
        """
        self.model.train()
        step_bar = tqdm.trange(
            len(self.train_dataloader) // self.accumulation_steps,
            desc=f"Epoch {epoch + 1}/{self.max_epochs}",
            disable=not is_rank_0(),
        )
        num_batches = 0
        for num_batches, batch in enumerate(self.train_dataloader, start=1):
            batch = to_device(batch, torch.cuda.current_device())
            outputs = self.model(batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"])
            # Scale so that the gradient summed over one window is the mean over that window.
            loss = outputs.loss / self.accumulation_steps
            self.total_loss += loss.item()
            self.strategy.backward(loss, self.model, self.optimizer)
            # gradient accumulation: one optimizer step per ``accumulation_steps`` micro-batches
            if num_batches % self.accumulation_steps == 0:
                self.strategy.optimizer_step(self.optimizer)
                self.optimizer.zero_grad()
                self.scheduler.step()
                if self.writer:
                    self.writer.add_scalar("train/loss", self.total_loss, self.num_train_step)
                    self.writer.add_scalar("train/lr", self.scheduler.get_last_lr()[0], self.num_train_step)
                self.num_train_step += 1
                self.total_loss = 0
                step_bar.update()
        # Drop the gradients and loss of an incomplete final window. Otherwise they would be added
        # into the first optimizer step of the next epoch, because the batch counter restarts.
        if num_batches % self.accumulation_steps != 0:
            self.optimizer.zero_grad()
            self.total_loss = 0
        step_bar.close()

    def _eval(self, epoch: int):
        """Compute and report the mean loss over ``self.eval_dataloader``.

        Does nothing when no eval dataloader was given or it yields no samples. Each batch's
        ``outputs.loss`` is weighted by the batch's sample count, so a smaller final batch does
        not skew the average.

        Args:
            epoch: zero-based epoch index (displayed as ``epoch + 1``)
        """
        if self.eval_dataloader is None:
            return
        self.model.eval()
        with torch.no_grad():
            loss_sum, num_seen = 0.0, 0
            for batch in self.eval_dataloader:
                batch = to_device(batch, torch.cuda.current_device())
                outputs = self.model(
                    batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"]
                )
                batch_size = batch["input_ids"].size(0)
                # NOTE(review): assumes ``outputs.loss`` is a scalar mean over the batch (as for
                # HF causal-LM heads). Weighting it by ``batch_size`` keeps the division by the
                # total sample count below consistent. Summing raw means and dividing by samples
                # would under-report the loss by about the batch size.
                loss_sum += outputs.loss.item() * batch_size
                num_seen += batch_size
            if num_seen == 0:
                # Empty eval set: nothing to report, and dividing would raise ZeroDivisionError.
                return
            loss_mean = loss_sum / num_seen
            # ``logger`` is optional in ``_before_fit``; skip console logging when it was not given.
            if is_rank_0() and self.logger is not None:
                self.logger.info(f"Eval Epoch {epoch + 1}/{self.max_epochs} loss {loss_mean}")
            if self.writer:
                self.writer.add_scalar("eval/loss", loss_mean, self.num_eval_step)
            self.num_eval_step += 1

    def _before_fit(
        self,
        train_dataloader: DataLoader,
        eval_dataloader: Optional[DataLoader] = None,
        logger: Optional[DistributedLogger] = None,
        log_dir: Optional[str] = None,
        use_wandb: bool = False,
    ):
        """
        Args:
            train_dataloader: the dataloader to use for training
            eval_dataloader: the dataloader to use for evaluation; evaluation is skipped when ``None``
            logger: logger used for evaluation messages on rank 0; those messages are skipped
                when ``None``
            log_dir: when given, rank 0 writes TensorBoard scalars under ``<log_dir>/sft/<timestamp>``
            use_wandb: on rank 0, start a Weights & Biases run with ``sync_tensorboard=True``;
                requires ``log_dir``
        """
        self.train_dataloader = train_dataloader
        self.eval_dataloader = eval_dataloader

        # With fewer micro-batches than ``accumulation_steps`` per epoch the step condition in
        # ``_train`` is never met, so the model would silently never be updated.
        if is_rank_0() and len(train_dataloader) < self.accumulation_steps:
            import warnings

            warnings.warn(
                f"train_dataloader has {len(train_dataloader)} batches, fewer than "
                f"accumulation_steps={self.accumulation_steps}; no optimizer step will be taken."
            )

        self.logger = logger
        self.writer = None
        if use_wandb and is_rank_0():
            assert log_dir is not None, "log_dir must be provided when use_wandb is True"
            import wandb

            wandb.init(project="Coati-sft", sync_tensorboard=True)
        if log_dir is not None and is_rank_0():
            import os
            import time

            from torch.utils.tensorboard import SummaryWriter

            log_dir = os.path.join(log_dir, "sft")
            log_dir = os.path.join(log_dir, time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime()))
            self.writer = SummaryWriter(log_dir=log_dir)

        # Running (scaled) loss of the current accumulation window; reset after each optimizer step.
        self.total_loss = 0
|