ColossalAI/colossalai/amp/apex_amp/apex_amp.py

40 lines
1.1 KiB
Python

#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch.nn as nn
try:
import apex.amp as apex_amp
except ImportError:
pass
from torch import Tensor
from colossalai.nn.optimizer import ColossalaiOptimizer
from colossalai.utils import clip_grad_norm_fp32
class ApexAMPOptimizer(ColossalaiOptimizer):
""" A wrapper class for APEX optimizer and it implements apex-specific backward and clip_grad_norm
methods
"""
def backward(self, loss: Tensor):
"""Backward pass to get all gradients
:param loss: Loss computed by a loss function
:type loss: torch.Tensor
"""
with apex_amp.scale_loss(loss, self.optim) as scaled_loss:
scaled_loss.backward()
def clip_grad_norm(self, model: nn.Module, max_norm: float):
"""Clip gradients' norm
:param model: Your model object
:type model: torch.nn.Module
:param max_norm: The max norm value for gradient clipping
:type max_norm: float
"""
if max_norm > 0:
clip_grad_norm_fp32(apex_amp.master_params(self.optim), max_norm)