""" SFT trainer """ import os from typing import Optional import torch from coati.trainer.utils import all_reduce_mean from coati.utils import AccumulativeMeanMeter, save_checkpoint from torch.optim import Optimizer from torch.optim.lr_scheduler import _LRScheduler from torch.utils.data import DataLoader from tqdm import trange from colossalai.booster import Booster from colossalai.cluster import DistCoordinator from .base import SLTrainer from .utils import is_rank_0, to_device class SFTTrainer(SLTrainer): """ Trainer to use while training reward model. Args: model (torch.nn.Module): the model to train strategy (Strategy): the strategy to use for training optim(Optimizer): the optimizer to use for training lr_scheduler(_LRScheduler): the lr scheduler to use for training max_epochs (int, defaults to 2): the number of epochs to train accumulation_steps (int, defaults to 8): the number of steps to accumulate gradients """ def __init__( self, model, booster: Booster, optim: Optimizer, lr_scheduler: _LRScheduler, max_epochs: int = 2, accumulation_steps: int = 8, start_epoch=0, save_interval: int = None, save_dir: str = None, coordinator: Optional[DistCoordinator] = None, ) -> None: super().__init__(booster, max_epochs, model, optim, start_epoch=start_epoch) self.accumulation_steps = accumulation_steps self.scheduler = lr_scheduler self.save_interval = save_interval self.save_dir = save_dir self.coordinator = coordinator self.num_train_step = 0 self.num_eval_step = 0 self.accumulative_meter = AccumulativeMeanMeter() def _before_fit( self, train_dataloader: DataLoader, eval_dataloader: Optional[DataLoader] = None, log_dir: Optional[str] = None, use_wandb: bool = False, ): """ Args: train_dataloader: the dataloader to use for training eval_dataloader: the dataloader to use for evaluation log_dir: the directory to save logs use_wandb: whether to use wandb for logging """ self.train_dataloader = train_dataloader self.eval_dataloader = eval_dataloader self.writer = None if use_wandb and is_rank_0(): assert log_dir is not None, "log_dir must be provided when use_wandb is True" import wandb wandb.init(project="Coati-sft", sync_tensorboard=True) if log_dir is not None and is_rank_0(): import os import time from torch.utils.tensorboard import SummaryWriter log_dir = os.path.join(log_dir, "sft") log_dir = os.path.join(log_dir, time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime())) self.writer = SummaryWriter(log_dir=log_dir) def _train(self, epoch: int): self.model.train() step_bar = trange( len(self.train_dataloader) // self.accumulation_steps, desc=f"Epoch {epoch + 1}/{self.max_epochs}", disable=not is_rank_0(), ) for i, batch in enumerate(self.train_dataloader): batch = to_device(batch, torch.cuda.current_device()) batch_size = batch["input_ids"].size(0) outputs = self.model(batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"]) loss = outputs.loss self.booster.backward(loss=loss, optimizer=self.optimizer) loss_mean = all_reduce_mean(tensor=loss) self.accumulative_meter.add("loss", loss_mean.to(torch.float16).item()) # Gradient accumulation if (i + 1) % self.accumulation_steps == 0: self.optimizer.step() self.optimizer.zero_grad() self.scheduler.step() step_bar.set_postfix({"train/loss": self.accumulative_meter.get("loss")}) if self.writer: self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), self.num_train_step) self.writer.add_scalar("train/lr", self.scheduler.get_last_lr()[0], self.num_train_step) self.num_train_step += 1 
self.accumulative_meter.reset() step_bar.update() # Save checkpoint if ( self.save_dir is not None and self.save_interval is not None and (self.num_train_step + 1) % self.save_interval == 0 ): save_checkpoint( save_dir=self.save_dir, booster=self.booster, model=self.model, optimizer=self.optimizer, lr_scheduler=self.scheduler, epoch=epoch, step=self.num_train_step + 1, batch_size=batch_size, coordinator=self.coordinator, ) self.coordinator.print_on_master( f"Saved checkpoint at epoch {epoch} step {self.num_train_step} at folder {self.save_dir}" ) step_bar.close() def _eval(self, epoch: int): if self.eval_dataloader is None: self.coordinator.print_on_master("No eval dataloader is provided, skip evaluation") return self.accumulative_meter.reset() self.model.eval() with torch.no_grad(): step_bar = trange( len(self.eval_dataloader), desc=f"Epoch {epoch + 1}/{self.max_epochs}", disable=not is_rank_0(), ) for batch in self.eval_dataloader: batch = to_device(batch, torch.cuda.current_device()) outputs = self.model(batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"]) loss_mean = all_reduce_mean(tensor=outputs.loss) self.accumulative_meter.add("loss", loss_mean.item(), count_update=batch["input_ids"].size(0)) step_bar.update() loss_mean = self.accumulative_meter.get("loss") msg = "Evaluation Result:\n" for tag in ["loss"]: msg = msg + f"{tag}: {self.accumulative_meter.get(tag)}\n" self.coordinator.print_on_master(msg) os.makedirs(self.save_dir, exist_ok=True) with open(os.path.join(self.save_dir, f"eval_result_epoch{epoch}.txt"), "w") as f: f.write(msg) step_bar.close()
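

# ---------------------------------------------------------------------------
# Illustrative usage sketch. The plugin choice, hyperparameters, and the
# model/dataloader construction below are placeholders and assumptions, not
# part of this module; fit() is inherited from the SLTrainer base class and
# its keyword arguments here mirror _before_fit above.
#
#   import colossalai
#   from colossalai.booster import Booster
#   from colossalai.booster.plugin import TorchDDPPlugin
#   from colossalai.cluster import DistCoordinator
#   from torch.optim import AdamW
#   from torch.optim.lr_scheduler import CosineAnnealingLR
#
#   colossalai.launch_from_torch(...)           # initialize the distributed backend
#   coordinator = DistCoordinator()
#   booster = Booster(plugin=TorchDDPPlugin())  # plugin choice is an assumption
#
#   model = ...             # a causal LM that returns `.loss` when `labels` are passed
#   train_dataloader = ...  # yields dicts with input_ids / attention_mask / labels
#
#   optim = AdamW(model.parameters(), lr=2e-5)
#   scheduler = CosineAnnealingLR(optim, T_max=len(train_dataloader))
#   model, optim, _, train_dataloader, scheduler = booster.boost(
#       model, optim, dataloader=train_dataloader, lr_scheduler=scheduler
#   )
#
#   trainer = SFTTrainer(
#       model=model,
#       booster=booster,
#       optim=optim,
#       lr_scheduler=scheduler,
#       max_epochs=1,
#       accumulation_steps=8,
#       save_interval=1000,
#       save_dir="./checkpoints",
#       coordinator=coordinator,
#   )
#   trainer.fit(train_dataloader=train_dataloader, eval_dataloader=None, log_dir=None, use_wandb=False)
# ---------------------------------------------------------------------------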