#!/usr/bin/env python # -*- encoding: utf-8 -*- import torch.cuda.amp as torch_amp import torch.nn as nn from torch import Tensor from torch.nn.modules.loss import _Loss from torch.optim import Optimizer from colossalai.nn.optimizer import ColossalaiOptimizer from colossalai.utils import clip_grad_norm_fp32 from ._grad_scaler import GradScaler class TorchAMPOptimizer(ColossalaiOptimizer): """A wrapper class which integrate Pytorch AMP with an optimizer Args: optim (torch.optim.Optimizer): A normal optimizer like Adam or SGD. init_scale (float, optional, default=2.**16): Initial scale factor. growth_factor (float, optional, default=2.0): Factor by which the scale is multiplied during :meth:`update` if no inf/NaN gradients occur for ``growth_interval`` consecutive iterations. backoff_factor (float, optional, default=0.5): Factor by which the scale is multiplied during :meth:`update` if inf/NaN gradients occur in an iteration. growth_interval (int, optional, default=2000): Number of consecutive iterations without inf/NaN gradients that must occur for the scale to be multiplied by ``growth_factor``. enabled (bool, optional, default=True): If ``False``, disables gradient scaling. :meth:`step` simply invokes the underlying ``optimizer.step()``, and other methods become no-ops. """ def __init__(self, optim: Optimizer, *args, **kwargs): super().__init__(optim) self.scaler = GradScaler(*args, **kwargs) def backward(self, loss: Tensor): """Backward with torch amp gradient scaler Args: loss (torch.Tensor): Loss computed by a loss function """ self.scaler.scale(loss).backward() def step(self): """Update the parameters of the model """ self.scaler.step(self.optim) self.scaler.update() def clip_grad_norm(self, model: nn.Module, max_norm: float): """Apply gradient clipping to the model parameters Args: model (torch.nn.Module): Your model object max_norm (float): Max norm value for gradient clipping """ if max_norm > 0.0: self.scaler.unscale_(self.optim) clip_grad_norm_fp32(model.parameters(), max_norm) class TorchAMPModel(nn.Module): """A wrapper class for a model object which executes forward with values automatically cast to fp16 Args: model (:class:`torch.nn.Module`): a torch model instance """ def __init__(self, model: nn.Module) -> None: super().__init__() self.model = model @torch_amp.autocast() def forward(self, *args, **kwargs): """ Execute forward under the torch amp context """ return self.model(*args, **kwargs) class TorchAMPLoss(nn.Module): """A wrapper class for a criterion object which computes the loss in mixed-precision context Args: loss (torch.nn.modules.loss._Loss): A loss function object """ def __init__(self, loss: _Loss): super().__init__() self.loss = loss @torch_amp.autocast() def forward(self, *args, **kwargs): """ Execute forward under the torch amp context """ return self.loss(*args, **kwargs)