#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import torch.nn as nn
import torch.cuda.amp as torch_amp

from torch import Tensor
from torch.nn.modules.loss import _Loss
from torch.optim import Optimizer

from ._grad_scaler import GradScaler
from colossalai.nn.optimizer import ColossalaiOptimizer
from colossalai.utils import clip_grad_norm_fp32


class TorchAMPOptimizer(ColossalaiOptimizer):
    """A wrapper class which integrates PyTorch AMP with an optimizer.

    :param optim: A normal optimizer like Adam or SGD
    :param args: Args used to initialize the gradient scaler
    :param kwargs: Kwargs used to initialize the gradient scaler

    :type optim: torch.optim.Optimizer
    """

    def __init__(self, optim: Optimizer, *args, **kwargs):
        super().__init__(optim)
        self.scaler = GradScaler(*args, **kwargs)

    def backward(self, loss: Tensor):
        """Backward with the torch amp gradient scaler.

        :param loss: Loss computed by a loss function
        :type loss: torch.Tensor
        """
        self.scaler.scale(loss).backward()

    def step(self):
        """Update the parameters of the model."""
        self.scaler.step(self.optim)
        self.scaler.update()

    def clip_grad_norm(self, model: nn.Module, max_norm: float):
        """Apply gradient clipping to the model parameters.

        :param model: Your model object
        :type model: torch.nn.Module
        :param max_norm: Max norm value for gradient clipping
        :type max_norm: float
        """
        if max_norm > 0.0:
            self.scaler.unscale_(self.optim)
            clip_grad_norm_fp32(model.parameters(), max_norm)
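
# --- Illustrative usage sketch (not part of the original module) -----------
# A minimal sketch of how TorchAMPOptimizer is typically driven inside one
# training step. The names ``model``, ``criterion``, ``data`` and ``target``
# are assumed placeholders; ``max_norm`` is an assumed clipping threshold.
def _example_training_step(model, criterion, optimizer, data, target, max_norm=1.0):
    # zero_grad is assumed to be delegated to the wrapped optimizer
    optimizer.zero_grad()
    output = model(data)
    loss = criterion(output, target)
    optimizer.backward(loss)                    # scale the loss, then backward
    optimizer.clip_grad_norm(model, max_norm)   # unscale grads before clipping
    optimizer.step()                            # scaler.step + scaler.update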


class TorchAMPModel(nn.Module):
    """A wrapper class for a model object whose forward pass runs in an autocast
    context, with eligible operations automatically cast to fp16.
    """

    def __init__(self, model: nn.Module) -> None:
        super().__init__()
        self.model = model

    @torch_amp.autocast()
    def forward(self, *args, **kwargs):
        return self.model(*args, **kwargs)


class TorchAMPLoss(nn.Module):
    """A wrapper class for a criterion object which computes the loss in a
    mixed-precision (autocast) context.

    :param loss: A loss function object
    :type loss: torch.nn.modules.loss._Loss
    """

    def __init__(self, loss: _Loss):
        super().__init__()
        self.loss = loss

    @torch_amp.autocast()
    def forward(self, *args, **kwargs):
        return self.loss(*args, **kwargs)
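
# --- Illustrative end-to-end sketch (not part of the original module) -------
# A minimal sketch of how the three wrappers can be combined to convert a plain
# model / criterion / optimizer into AMP counterparts. Assumes a CUDA device;
# ``init_scale`` is assumed to be a standard GradScaler argument forwarded
# through TorchAMPOptimizer, and the tensor shapes are placeholders.
if __name__ == "__main__":
    import torch

    model = TorchAMPModel(nn.Linear(16, 4).cuda())
    criterion = TorchAMPLoss(nn.CrossEntropyLoss())
    optimizer = TorchAMPOptimizer(torch.optim.SGD(model.parameters(), lr=1e-3),
                                  init_scale=2.**16)

    data = torch.randn(8, 16, device="cuda")
    target = torch.randint(0, 4, (8,), device="cuda")

    output = model(data)               # forward runs under torch.cuda.amp.autocast
    loss = criterion(output, target)   # loss is also computed under autocast
    optimizer.backward(loss)           # scaled backward
    optimizer.clip_grad_norm(model, max_norm=1.0)
    optimizer.step()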