You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
ColossalAI/colossalai/amp/torch_amp/torch_amp.py

99 lines
3.3 KiB

#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch.cuda.amp as torch_amp
import torch.nn as nn
from torch import Tensor
from torch.nn.modules.loss import _Loss
from torch.optim import Optimizer
from colossalai.nn.optimizer import ColossalaiOptimizer
from colossalai.utils import clip_grad_norm_fp32
from ._grad_scaler import GradScaler
class TorchAMPOptimizer(ColossalaiOptimizer):
    """Optimizer wrapper that pairs an optimizer with a PyTorch AMP gradient scaler.

    Everything passed after ``optim`` (positional or keyword) is forwarded
    unchanged to :class:`GradScaler`.

    Args:
        optim (torch.optim.Optimizer): A regular optimizer such as Adam or SGD.
        init_scale (float, optional, default=2.**16): Starting scale factor.
        growth_factor (float, optional, default=2.0): Multiplier applied to the scale during
            :meth:`update` when no inf/NaN gradients were seen for ``growth_interval``
            consecutive iterations.
        backoff_factor (float, optional, default=0.5): Multiplier applied to the scale during
            :meth:`update` when inf/NaN gradients were seen in an iteration.
        growth_interval (int, optional, default=2000): How many consecutive iterations without
            inf/NaN gradients are required before the scale is multiplied by ``growth_factor``.
        enabled (bool, optional, default=True): When ``False``, gradient scaling is disabled;
            :meth:`step` just calls the underlying ``optimizer.step()`` and the other
            methods become no-ops.
    """

    def __init__(self, optim: Optimizer, *args, **kwargs):
        super().__init__(optim)
        # All remaining arguments configure the gradient scaler, not the optimizer.
        self.scaler = GradScaler(*args, **kwargs)

    def backward(self, loss: Tensor):
        """Run the backward pass on the loss after scaling it with the gradient scaler.

        Args:
            loss (torch.Tensor): Loss value produced by a loss function.
        """
        scaled_loss = self.scaler.scale(loss)
        scaled_loss.backward()

    def step(self):
        """Step the wrapped optimizer through the scaler, then update the scaler."""
        scaler = self.scaler
        scaler.step(self.optim)
        scaler.update()

    def clip_grad_norm(self, model: nn.Module, max_norm: float):
        """Clip the gradients of the model parameters.

        Nothing happens unless ``max_norm`` is greater than zero.

        Args:
            model (torch.nn.Module): The model whose parameters are clipped.
            max_norm (float): Maximum norm allowed for the gradients.
        """
        # Written as `not >` so that a NaN max_norm is skipped as well.
        if not max_norm > 0.0:
            return
        # Unscale through the scaler before handing the parameters to the clipper.
        self.scaler.unscale_(self.optim)
        clip_grad_norm_fp32(model.parameters(), max_norm)
class TorchAMPModel(nn.Module):
    """Module wrapper that runs the wrapped model's forward pass under
    the torch AMP autocast context.

    Args:
        model (:class:`torch.nn.Module`): The torch model instance to wrap.
    """

    def __init__(self, model: nn.Module) -> None:
        super().__init__()
        # Registered as a submodule under the name ``model``.
        self.model = model

    @torch_amp.autocast()
    def forward(self, *args, **kwargs):
        """
        Call the wrapped model with the given arguments inside the torch amp
        context and return its output unchanged.
        """
        wrapped = self.model
        output = wrapped(*args, **kwargs)
        return output
class TorchAMPLoss(nn.Module):
    """Module wrapper that evaluates a criterion inside the torch AMP
    autocast (mixed-precision) context.

    Args:
        loss (torch.nn.modules.loss._Loss): The loss function object to wrap.
    """

    def __init__(self, loss: _Loss):
        super().__init__()
        # Registered as a submodule under the name ``loss``.
        self.loss = loss

    @torch_amp.autocast()
    def forward(self, *args, **kwargs):
        """
        Call the wrapped criterion with the given arguments inside the torch
        amp context and return the resulting loss.
        """
        criterion = self.loss
        value = criterion(*args, **kwargs)
        return value