# ColossalAI/colossalai/booster/mixed_precision/fp16_torch.py

from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
from torch import Tensor
from torch.optim import Optimizer
from ..interface import OptimizerWrapper
from .mixed_precision_base import MixedPrecision
# Public API of this module. Every entry must name an object actually defined in this
# file, otherwise ``from <module> import *`` fails with AttributeError.
__all__ = ['FP16TorchMixedPrecision', 'TorchAMPOptimizer', 'TorchAMPModule']
class TorchAMPOptimizer(OptimizerWrapper):
    """
    Optimizer wrapper for mixed precision training in FP16 using PyTorch AMP.

    Losses are scaled by a ``torch.cuda.amp.GradScaler`` before backpropagation.
    :meth:`step` delegates to ``GradScaler.step`` and then calls ``GradScaler.update``
    so the scale factor is adjusted and per-iteration scaler state is reset.

    Args:
        optim (Optimizer): Optimizer to wrap.
        init_scale (float): Initial scale factor. Default: 2**16.
        growth_factor (float): Factor by which the scale is multiplied during
            :meth:`torch.cuda.amp.GradScaler.update` once gradients have been finite for
            ``growth_interval`` consecutive iterations. Default: 2.0.
        backoff_factor (float): Factor by which the scale is multiplied during
            :meth:`torch.cuda.amp.GradScaler.update` if gradients were found to be infinite/NaN
            this iteration. Default: 0.5.
        growth_interval (int): Number of consecutive iterations with finite gradients required
            before the scale is multiplied by ``growth_factor``. Default: 2000.
    """

    def __init__(self,
                 optim: Optimizer,
                 init_scale: float = 2.**16,
                 growth_factor: float = 2.0,
                 backoff_factor: float = 0.5,
                 growth_interval: int = 2000) -> None:
        super().__init__(optim)
        self.scaler = torch.cuda.amp.GradScaler(init_scale=init_scale,
                                                growth_factor=growth_factor,
                                                backoff_factor=backoff_factor,
                                                growth_interval=growth_interval)

    def backward(self, loss: Tensor, *args, **kwargs) -> None:
        """Scale ``loss`` and backpropagate it; extra arguments go to ``Tensor.backward``."""
        scaled_loss = self.scale_loss(loss)
        scaled_loss.backward(*args, **kwargs)

    def step(self, *args, **kwargs) -> Optional[float]:
        """
        Run the wrapped optimizer's step through the grad scaler, then update the scale.

        ``GradScaler.step`` unscales the gradients (unless :meth:`unscale_grad` already did
        so this iteration) and skips ``optim.step`` when inf/NaN gradients are found.
        ``GradScaler.update`` must follow every ``step``: it adjusts the scale factor and
        resets the scaler's per-iteration optimizer state. Without it, the scale never
        changes and the next ``step``/``unscale_`` call raises ``RuntimeError``.

        Returns:
            Whatever ``optim.step`` returned, or ``None`` if the step was skipped.
        """
        out = self.scaler.step(self.optim, *args, **kwargs)
        self.scaler.update()
        return out

    def scale_loss(self, loss: Tensor) -> Tensor:
        """Return ``loss`` multiplied by the current scale factor."""
        return self.scaler.scale(loss)

    def unscale_grad(self) -> None:
        """
        Divide the wrapped optimizer's gradients by the current scale factor in place.

        The scaler allows this at most once per optimizer between two :meth:`step` calls.
        """
        self.scaler.unscale_(self.optim)

    def clip_grad_by_value(self, clip_value: float, *args, **kwargs) -> None:
        """Unscale gradients first so ``clip_value`` applies to true (unscaled) gradients."""
        self.unscale_grad()
        super().clip_grad_by_value(clip_value, *args, **kwargs)

    def clip_grad_by_norm(self,
                          max_norm: Union[float, int],
                          norm_type: Union[float, int] = 2.0,
                          error_if_nonfinite: bool = False,
                          *args,
                          **kwargs) -> None:
        """Unscale gradients first so ``max_norm`` applies to true (unscaled) gradients."""
        self.unscale_grad()
        super().clip_grad_by_norm(max_norm, norm_type, error_if_nonfinite, *args, **kwargs)
class TorchAMPModule(nn.Module):
    """
    Wrap a module (or any callable) so that its forward call runs under PyTorch AMP autocast.

    Args:
        module (nn.Module): Module to wrap. It is stored as ``self.module``; when it is an
            ``nn.Module`` it is registered as a submodule, so its parameters are visible
            through the wrapper.
    """

    def __init__(self, module: nn.Module):
        super().__init__()
        self.module = module

    def forward(self, *args, **kwargs):
        # Arguments are passed through untouched; only the call itself is placed
        # inside the CUDA autocast region.
        autocast_region = torch.cuda.amp.autocast()
        with autocast_region:
            result = self.module(*args, **kwargs)
        return result
class FP16TorchMixedPrecision(MixedPrecision):
    """
    Precision for mixed precision training in FP16 using PyTorch AMP.

    :meth:`configure` wraps the model (and criterion, if given) in :class:`TorchAMPModule`
    and the optimizer in :class:`TorchAMPOptimizer`, which owns a ``torch.cuda.amp.GradScaler``
    built from the arguments below.

    Args:
        init_scale (float): Initial scale factor. Default: 2**16.
        growth_factor (float): Factor by which the scale is multiplied during
            :meth:`torch.cuda.amp.GradScaler.update` once gradients have been finite for
            ``growth_interval`` consecutive iterations. Default: 2.0.
        backoff_factor (float): Factor by which the scale is multiplied during
            :meth:`torch.cuda.amp.GradScaler.update` if gradients were found to be infinite/NaN
            this iteration. Default: 0.5.
        growth_interval (int): Number of consecutive iterations with finite gradients required
            before the scale is multiplied by ``growth_factor``. Default: 2000.
    """

    def __init__(self,
                 init_scale: float = 2.**16,
                 growth_factor: float = 2.0,
                 backoff_factor: float = 0.5,
                 growth_interval: int = 2000) -> None:
        super().__init__()
        # Forwarded verbatim to TorchAMPOptimizer (and from there to GradScaler).
        self.torch_amp_kwargs = dict(init_scale=init_scale,
                                     growth_factor=growth_factor,
                                     backoff_factor=backoff_factor,
                                     growth_interval=growth_interval)

    def configure(self,
                  model: nn.Module,
                  optimizer: Optimizer,
                  criterion: Optional[Callable] = None
                  ) -> Tuple[nn.Module, OptimizerWrapper, Optional[Callable]]:
        """
        Wrap the training objects for FP16 AMP training.

        Args:
            model (nn.Module): Model whose forward pass should run under autocast.
            optimizer (Optimizer): Optimizer to wrap with a fresh grad scaler.
            criterion (Callable, optional): Loss function; wrapped for autocast when given.

        Returns:
            ``(wrapped_model, wrapped_optimizer, wrapped_criterion_or_None)``.
        """
        model = TorchAMPModule(model)
        optimizer = TorchAMPOptimizer(optimizer, **self.torch_amp_kwargs)
        if criterion is not None:
            criterion = TorchAMPModule(criterion)
        return model, optimizer, criterion