# Mirror of https://github.com/hpcaitech/ColossalAI (Python, 123 lines, 4.8 KiB)
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
from torch import Tensor
|
|
from torch.optim import Optimizer
|
|
|
|
from ..interface import OptimizerWrapper
|
|
from .mixed_precision_base import MixedPrecision
|
|
|
|
# Public API of this module; every entry must match a name defined below,
# otherwise ``from <module> import *`` raises AttributeError.
__all__ = ['FP16TorchMixedPrecision', 'TorchAMPOptimizer', 'TorchAMPModule']
|
|
|
|
|
|
class TorchAMPOptimizer(OptimizerWrapper):
    """
    Optimizer wrapper for mixed precision training in FP16 using PyTorch AMP.

    The loss is scaled by a ``torch.cuda.amp.GradScaler`` before backward.
    ``step`` runs the wrapped optimizer through the scaler and then calls
    ``GradScaler.update`` so the scale can be adjusted for the next iteration.

    Args:
        optim (Optimizer): Optimizer to wrap.
        init_scale (float): Initial scale factor. Default: 2**16.
        growth_factor (float): Factor by which the scale is multiplied during
            :meth:`torch.cuda.amp.GradScaler.update` if gradients were found to be finite
            for ``growth_interval`` consecutive iterations. Default: 2.0.
        backoff_factor (float): Factor by which the scale is multiplied during
            :meth:`torch.cuda.amp.GradScaler.update` if gradients were found to be infinite
            this iteration. Default: 0.5.
        growth_interval (int): Number of consecutive non-skipped iterations required
            before the scale may increase. Default: 2000.
    """

    def __init__(self,
                 optim: Optimizer,
                 init_scale: float = 2.**16,
                 growth_factor: float = 2.0,
                 backoff_factor: float = 0.5,
                 growth_interval: int = 2000) -> None:
        super().__init__(optim)
        self.scaler = torch.cuda.amp.GradScaler(init_scale=init_scale,
                                                growth_factor=growth_factor,
                                                backoff_factor=backoff_factor,
                                                growth_interval=growth_interval)

    def backward(self, loss: Tensor, *args, **kwargs) -> None:
        """Scale ``loss`` with the grad scaler and backpropagate it.

        Extra positional/keyword arguments are forwarded to ``Tensor.backward``.
        """
        scaled_loss = self.scale_loss(loss)
        scaled_loss.backward(*args, **kwargs)

    def step(self, *args, **kwargs) -> Optional[float]:
        """Step the wrapped optimizer through the scaler, then update the scale.

        Extra arguments are forwarded to the wrapped optimizer's ``step``.

        Returns:
            The value returned by ``GradScaler.step`` (the wrapped optimizer's
            ``step`` result, or ``None`` if the scaler skipped the step).
        """
        out = self.scaler.step(self.optim, *args, **kwargs)
        # GradScaler requires update() after each step(): it grows or backs off
        # the scale and resets per-optimizer state. Without it, the scale never
        # changes and the next step()/unscale_() raises RuntimeError.
        self.scaler.update()
        return out

    def scale_loss(self, loss: Tensor) -> Tensor:
        """Return ``loss`` multiplied by the scaler's current scale factor."""
        return self.scaler.scale(loss)

    def unscale_grad(self) -> None:
        """Divide the wrapped optimizer's gradients by the scale, in place.

        Note: GradScaler allows ``unscale_`` only once per optimizer between
        ``update`` calls.
        """
        self.scaler.unscale_(self.optim)

    def clip_grad_by_value(self, clip_value: float, *args, **kwargs) -> None:
        """Unscale gradients, then clip them element-wise to ``clip_value``."""
        # Clip the true (unscaled) gradients, not the loss-scaled ones.
        self.unscale_grad()
        super().clip_grad_by_value(clip_value, *args, **kwargs)

    def clip_grad_by_norm(self,
                          max_norm: Union[float, int],
                          norm_type: Union[float, int] = 2.0,
                          error_if_nonfinite: bool = False,
                          *args,
                          **kwargs) -> None:
        """Unscale gradients, then clip their total ``norm_type`` norm to ``max_norm``."""
        # Clip the true (unscaled) gradients so ``max_norm`` has its intended meaning.
        self.unscale_grad()
        super().clip_grad_by_norm(max_norm, norm_type, error_if_nonfinite, *args, **kwargs)
|
|
|
|
|
|
class TorchAMPModule(nn.Module):
    """
    Wrap a module so that its forward pass runs under PyTorch CUDA autocast.

    Args:
        module (nn.Module): The module whose calls are executed inside
            ``torch.cuda.amp.autocast()``.
    """

    def __init__(self, module: nn.Module):
        super().__init__()
        self.module = module

    def forward(self, *args, **kwargs):
        """Call the wrapped module with all arguments, inside an autocast region."""
        autocast_region = torch.cuda.amp.autocast()
        with autocast_region:
            result = self.module(*args, **kwargs)
        return result
|
|
|
|
|
|
class FP16TorchMixedPrecision(MixedPrecision):
    """
    Precision for mixed precision training in FP16 using PyTorch AMP.

    The constructor arguments are stored and forwarded unchanged to the
    ``torch.cuda.amp.GradScaler`` that :class:`TorchAMPOptimizer` creates
    in :meth:`configure`.

    Args:
        init_scale (float): Initial scale factor. Default: 2**16.
        growth_factor (float): Factor by which the scale is multiplied during
            :meth:`torch.cuda.amp.GradScaler.update` if gradients were found to be finite
            for ``growth_interval`` consecutive iterations. Default: 2.0.
        backoff_factor (float): Factor by which the scale is multiplied during
            :meth:`torch.cuda.amp.GradScaler.update` if gradients were found to be infinite
            this iteration. Default: 0.5.
        growth_interval (int): Number of consecutive non-skipped iterations required
            before the scale may increase. Default: 2000.
    """

    def __init__(self,
                 init_scale: float = 2.**16,
                 growth_factor: float = 2.0,
                 backoff_factor: float = 0.5,
                 growth_interval: int = 2000) -> None:
        super().__init__()
        # Stored so that every optimizer passed to configure() gets its own
        # scaler built from the same settings.
        self.torch_amp_kwargs = dict(init_scale=init_scale,
                                     growth_factor=growth_factor,
                                     backoff_factor=backoff_factor,
                                     growth_interval=growth_interval)

    def configure(self,
                  model: nn.Module,
                  optimizer: Optimizer,
                  criterion: Optional[Callable] = None) -> Tuple[nn.Module, OptimizerWrapper, Optional[Callable]]:
        """Wrap the model, optimizer and optional criterion for torch AMP training.

        Args:
            model (nn.Module): Model whose forward is wrapped to run under autocast.
            optimizer (Optimizer): Optimizer wrapped in a :class:`TorchAMPOptimizer`
                configured with this precision's scaler settings.
            criterion (Callable, optional): Loss callable. If given, it is wrapped
                so that it also runs under autocast.

        Returns:
            Tuple of (wrapped model, wrapped optimizer, wrapped criterion or ``None``).
        """
        model = TorchAMPModule(model)
        optimizer = TorchAMPOptimizer(optimizer, **self.torch_amp_kwargs)
        if criterion is not None:
            criterion = TorchAMPModule(criterion)
        return model, optimizer, criterion
|