You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
ColossalAI/colossalai/booster/mixed_precision/fp16_torch.py

139 lines
4.7 KiB

from typing import Callable, Optional, Tuple, Union
import torch
import torch.nn as nn
from torch import Tensor
from torch.optim import Optimizer
from colossalai.accelerator import get_accelerator
from colossalai.interface import ModelWrapper, OptimizerWrapper
from .mixed_precision_base import MixedPrecision
# Public API of this module. Every entry must match a name actually defined
# below, otherwise `from <module> import *` raises AttributeError.
__all__ = ["FP16TorchMixedPrecision", "TorchAMPOptimizer", "TorchAMPModule"]
class TorchAMPOptimizer(OptimizerWrapper):
    """
    Optimizer wrapper that routes backward/step through a PyTorch AMP ``GradScaler``.

    The four scaling arguments are forwarded unchanged to
    :class:`torch.cuda.amp.GradScaler`; see the PyTorch documentation for their
    exact semantics.

    Args:
        optim (Optimizer): Optimizer to wrap.
        init_scale (float): Initial loss-scale factor. Default: 2**16.
        growth_factor (float): Multiplier the scaler applies to the scale when it
            decides to grow it. Default: 2.0.
        backoff_factor (float): Multiplier the scaler applies to the scale when it
            decides to shrink it (non-finite gradients). Default: 0.5.
        growth_interval (int): Interval, in iterations, that the scaler uses when
            deciding whether to grow the scale. Default: 2000.
    """

    def __init__(
        self,
        optim: Optimizer,
        init_scale: float = 2.0**16,
        growth_factor: float = 2.0,
        backoff_factor: float = 0.5,
        growth_interval: int = 2000,
    ) -> None:
        super().__init__(optim)
        # Collect the scaler settings once so the constructor call stays compact.
        scaler_settings = {
            "init_scale": init_scale,
            "growth_factor": growth_factor,
            "backoff_factor": backoff_factor,
            "growth_interval": growth_interval,
        }
        self.scaler = torch.cuda.amp.GradScaler(**scaler_settings)

    def backward(self, loss: Tensor, *args, **kwargs) -> None:
        """Scale ``loss`` via :meth:`scale_loss` and back-propagate the scaled value.

        Extra positional/keyword arguments are passed to ``Tensor.backward``.
        """
        self.scale_loss(loss).backward(*args, **kwargs)

    def step(self, *args, **kwargs) -> Optional[float]:
        """Step the wrapped optimizer through the scaler, then update the scale.

        Returns:
            Whatever ``GradScaler.step`` returned for this call.
        """
        # self.optim is presumably set by OptimizerWrapper.__init__ (not visible here).
        step_result = self.scaler.step(self.optim, *args, **kwargs)
        # The scale is refreshed after every step attempt.
        self.scaler.update()
        return step_result

    def scale_loss(self, loss: Tensor) -> Tensor:
        """Return ``loss`` multiplied by the scaler's current scale factor."""
        grad_scaler = self.scaler
        return grad_scaler.scale(loss)

    def unscale_grad(self) -> None:
        """Ask the scaler to unscale the gradients owned by the wrapped optimizer.

        NOTE(review): PyTorch documents that ``GradScaler.unscale_`` should be
        called at most once per optimizer between ``update()`` calls, so invoking
        both clipping helpers in the same iteration is likely to raise -- confirm.
        """
        self.scaler.unscale_(self.optim)

    def clip_grad_by_value(self, clip_value: float, *args, **kwargs) -> None:
        """Unscale the gradients, then delegate value clipping to the parent class."""
        # Clipping thresholds are meant for real (unscaled) gradient magnitudes.
        self.unscale_grad()
        super().clip_grad_by_value(clip_value, *args, **kwargs)

    def clip_grad_by_norm(
        self,
        max_norm: Union[float, int],
        norm_type: Union[float, int] = 2.0,
        error_if_nonfinite: bool = False,
        *args,
        **kwargs,
    ) -> None:
        """Unscale the gradients, then delegate norm clipping to the parent class.

        All arguments are forwarded positionally, in order, to the parent
        implementation.
        """
        # Clipping thresholds are meant for real (unscaled) gradient magnitudes.
        self.unscale_grad()
        super().clip_grad_by_norm(
            max_norm,
            norm_type,
            error_if_nonfinite,
            *args,
            **kwargs,
        )
class TorchAMPModule(ModelWrapper):
    """
    Module wrapper that executes the wrapped module's forward pass under the
    current accelerator's autocast context (used here for FP16 PyTorch AMP).

    Args:
        module (nn.Module): Module to wrap.
    """

    def __init__(self, module: nn.Module):
        # No extra state; the parent wrapper stores the module.
        super().__init__(module)

    def forward(self, *args, **kwargs):
        """Call the wrapped module with the given arguments inside autocast.

        Returns:
            The wrapped module's output, unchanged.
        """
        # The autocast context comes from whichever accelerator is active.
        autocast_ctx = get_accelerator().autocast()
        with autocast_ctx:
            # self.module is presumably set by ModelWrapper.__init__ (not visible here).
            return self.module(*args, **kwargs)
class FP16TorchMixedPrecision(MixedPrecision):
    """
    Mixed-precision policy for FP16 training with PyTorch AMP.

    The constructor arguments are stored and later forwarded to every
    :class:`TorchAMPOptimizer` created by :meth:`configure`.

    Args:
        init_scale (float): Initial loss-scale factor. Default: 2**16.
        growth_factor (float): Multiplier the gradient scaler applies when it
            grows the scale. Default: 2.0.
        backoff_factor (float): Multiplier the gradient scaler applies when it
            shrinks the scale (non-finite gradients). Default: 0.5.
        growth_interval (int): Interval, in iterations, that the gradient scaler
            uses when deciding whether to grow the scale. Default: 2000.
    """

    def __init__(
        self,
        init_scale: float = 2.0**16,
        growth_factor: float = 2.0,
        backoff_factor: float = 0.5,
        growth_interval: int = 2000,
    ) -> None:
        super().__init__()
        # Kept as a dict so it can be splatted into TorchAMPOptimizer later.
        self.torch_amp_kwargs = {
            "init_scale": init_scale,
            "growth_factor": growth_factor,
            "backoff_factor": backoff_factor,
            "growth_interval": growth_interval,
        }

    def configure(
        self,
        model: nn.Module,
        optimizer: Optional[Optimizer] = None,
        criterion: Optional[Callable] = None,
    ) -> Tuple[nn.Module, OptimizerWrapper, Callable]:
        """Wrap the model, and the optimizer/criterion when given, for AMP.

        Args:
            model (nn.Module): Model to wrap in :class:`TorchAMPModule`.
            optimizer (Optional[Optimizer]): Wrapped in :class:`TorchAMPOptimizer`
                with the stored scaler settings; passed through as ``None`` if absent.
            criterion (Optional[Callable]): Wrapped in :class:`TorchAMPModule`;
                passed through as ``None`` if absent.

        Returns:
            A ``(model, optimizer, criterion)`` tuple of the wrapped objects.
        """
        amp_model = TorchAMPModule(model)
        amp_optimizer = (
            None if optimizer is None else TorchAMPOptimizer(optimizer, **self.torch_amp_kwargs)
        )
        # The criterion is wrapped like a model so that it, too, runs under autocast.
        amp_criterion = None if criterion is None else TorchAMPModule(criterion)
        return amp_model, amp_optimizer, amp_criterion