ColossalAI/colossalai/booster/mixed_precision/fp16_torch.py

139 lines
4.7 KiB
Python
Raw Normal View History

from typing import Callable, Optional, Tuple, Union
import torch
import torch.nn as nn
from torch import Tensor
from torch.optim import Optimizer
from colossalai.accelerator import get_accelerator
from colossalai.interface import ModelWrapper, OptimizerWrapper
from .mixed_precision_base import MixedPrecision
# Public API for ``from <module> import *``; every entry must name a class defined in
# this module, otherwise the star-import raises AttributeError.
__all__ = ["FP16TorchMixedPrecision", "TorchAMPOptimizer", "TorchAMPModule"]
class TorchAMPOptimizer(OptimizerWrapper):
    """
    Optimizer wrapper for mixed precision training in FP16 using PyTorch AMP.

    Loss scaling, gradient unscaling and the loss-scale schedule are delegated to a
    :class:`torch.cuda.amp.GradScaler` owned by this wrapper (``self.scaler``).

    Args:
        optim (Optimizer): Optimizer to wrap.
        init_scale (float): Initial scale factor. Default: 2**16.
        growth_factor (float): Factor by which the scale is multiplied during
            :meth:`torch.cuda.amp.GradScaler.update` if no inf/NaN gradients occur for
            ``growth_interval`` consecutive iterations. Default: 2.0.
        backoff_factor (float): Factor by which the scale is multiplied during
            :meth:`torch.cuda.amp.GradScaler.update` if inf/NaN gradients occur in an
            iteration. Default: 0.5.
        growth_interval (int): Number of consecutive iterations without inf/NaN gradients
            that must occur for the scale to be multiplied by ``growth_factor``.
            Default: 2000.
    """

    def __init__(
        self,
        optim: Optimizer,
        init_scale: float = 2.0**16,
        growth_factor: float = 2.0,
        backoff_factor: float = 0.5,
        growth_interval: int = 2000,
    ) -> None:
        super().__init__(optim)
        self.scaler = torch.cuda.amp.GradScaler(
            init_scale=init_scale,
            growth_factor=growth_factor,
            backoff_factor=backoff_factor,
            growth_interval=growth_interval,
        )

    def backward(self, loss: Tensor, *args, **kwargs) -> None:
        """Scale ``loss`` and backpropagate it; extra arguments are forwarded to ``Tensor.backward``."""
        scaled_loss = self.scale_loss(loss)
        scaled_loss.backward(*args, **kwargs)

    def step(self, *args, **kwargs) -> Optional[float]:
        """
        Step the wrapped optimizer through the scaler, then update the loss scale.

        ``GradScaler.step`` unscales the gradients (unless :meth:`unscale_grad` already did
        so this iteration) and skips the optimizer step if they contain inf/NaN.

        Returns:
            The value returned by ``GradScaler.step``: the wrapped optimizer's ``step``
            return value, or ``None`` when the step was skipped.
        """
        out = self.scaler.step(self.optim, *args, **kwargs)
        # update() must follow every step() so the scale can grow or back off.
        self.scaler.update()
        return out

    def scale_loss(self, loss: Tensor) -> Tensor:
        """Return ``loss`` multiplied by the scaler's current scale factor."""
        return self.scaler.scale(loss)

    def unscale_grad(self) -> None:
        """
        Divide the wrapped optimizer's gradients by the current scale factor in place.

        Note:
            ``GradScaler.unscale_`` raises ``RuntimeError`` if called twice for the same
            optimizer within one iteration. Call this method (directly or through the
            ``clip_grad_*`` methods) at most once between consecutive :meth:`step` calls.
        """
        self.scaler.unscale_(self.optim)

    def clip_grad_by_value(self, clip_value: float, *args, **kwargs) -> None:
        """Unscale gradients, then clip them element-wise using the base-class implementation."""
        # Clip thresholds refer to true gradient magnitudes, so unscale first.
        self.unscale_grad()
        super().clip_grad_by_value(clip_value, *args, **kwargs)

    def clip_grad_by_norm(
        self,
        max_norm: Union[float, int],
        norm_type: Union[float, int] = 2.0,
        error_if_nonfinite: bool = False,
        *args,
        **kwargs,
    ) -> Optional[Tensor]:
        """
        Unscale gradients, then clip their global norm using the base-class implementation.

        Unscaling first makes ``max_norm`` apply to the true gradient norm rather than the
        loss-scaled one.

        Returns:
            Whatever the base-class ``clip_grad_by_norm`` returns. This is presumably the
            total gradient norm. NOTE(review): confirm against ``OptimizerWrapper``.
        """
        self.unscale_grad()
        return super().clip_grad_by_norm(max_norm, norm_type, error_if_nonfinite, *args, **kwargs)
class TorchAMPModule(ModelWrapper):
    """
    Wrapper that runs every call of the inner module inside the accelerator's autocast context.

    Args:
        module (nn.Module): The module whose calls should execute under autocast.
    """

    def __init__(self, module: nn.Module):
        super().__init__(module)

    def forward(self, *args, **kwargs):
        """Invoke the wrapped module with all given arguments under ``get_accelerator().autocast()``."""
        autocast_ctx = get_accelerator().autocast()
        with autocast_ctx:
            result = self.module(*args, **kwargs)
            return result
class FP16TorchMixedPrecision(MixedPrecision):
    """
    Precision for mixed precision training in FP16 using PyTorch AMP.

    The constructor only records the ``GradScaler`` settings. They are applied when
    :meth:`configure` wraps an optimizer in :class:`TorchAMPOptimizer`.

    Args:
        init_scale (float): Initial scale factor. Default: 2**16.
        growth_factor (float): Factor by which the scale is multiplied during
            :meth:`torch.cuda.amp.GradScaler.update` if no inf/NaN gradients occur for
            ``growth_interval`` consecutive iterations. Default: 2.0.
        backoff_factor (float): Factor by which the scale is multiplied during
            :meth:`torch.cuda.amp.GradScaler.update` if inf/NaN gradients occur in an
            iteration. Default: 0.5.
        growth_interval (int): Number of consecutive iterations without inf/NaN gradients
            that must occur for the scale to be multiplied by ``growth_factor``.
            Default: 2000.
    """

    def __init__(
        self,
        init_scale: float = 2.0**16,
        growth_factor: float = 2.0,
        backoff_factor: float = 0.5,
        growth_interval: int = 2000,
    ) -> None:
        super().__init__()
        # Stored for later: each optimizer passed to configure() gets its own scaler
        # built from these settings.
        self.torch_amp_kwargs = dict(
            init_scale=init_scale,
            growth_factor=growth_factor,
            backoff_factor=backoff_factor,
            growth_interval=growth_interval,
        )

    def configure(
        self,
        model: nn.Module,
        optimizer: Optional[Optimizer] = None,
        criterion: Optional[Callable] = None,
    ) -> Tuple[nn.Module, Optional[OptimizerWrapper], Optional[Callable]]:
        """
        Wrap the training components for FP16 autocast training.

        Args:
            model (nn.Module): Model to wrap in ``TorchAMPModule`` so its calls run under
                autocast.
            optimizer (Optional[Optimizer]): Optimizer to wrap in a loss-scaling
                ``TorchAMPOptimizer``. Returned unchanged if ``None``.
            criterion (Optional[Callable]): Loss function to wrap in ``TorchAMPModule`` so
                it also runs under autocast. Returned unchanged if ``None``.

        Returns:
            Tuple of ``(model, optimizer, criterion)``. The optimizer and criterion are
            wrapped if they were given, and ``None`` otherwise.
        """
        model = TorchAMPModule(model)
        if optimizer is not None:
            optimizer = TorchAMPOptimizer(optimizer, **self.torch_amp_kwargs)
        if criterion is not None:
            criterion = TorchAMPModule(criterion)
        return model, optimizer, criterion