from typing import Callable, Optional, Tuple, Union

import torch
import torch.nn as nn
from torch import Tensor
from torch.optim import Optimizer

from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.utils.device import autocast

from .mixed_precision_base import MixedPrecision

__all__ = ["FP16_Torch_MixedPrecision", "TorchAMPOptimizer", "TorchAMPModule"] |
|
|
|
|
|
class TorchAMPOptimizer(OptimizerWrapper):
    """
    Optimizer wrapper for mixed precision training in FP16 using PyTorch AMP.

    Args:
        optim (Optimizer): Optimizer to wrap.
        init_scale (float): Initial scale factor. Default: 2**16.
        growth_factor (float): Factor by which the scale is multiplied during
            :meth:`torch.cuda.amp.GradScaler.step` if gradients were found to be finite
            this iteration. Default: 2.0.
        backoff_factor (float): Factor by which the scale is multiplied during
            :meth:`torch.cuda.amp.GradScaler.step` if gradients were found to be infinite
            this iteration. Default: 0.5.
        growth_interval (int): Number of iterations between :meth:`torch.cuda.amp.GradScaler.step`
            calls that may cause the scale to increase. Default: 2000.
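
    Example:
        A minimal illustrative sketch; ``model`` and ``inputs`` are assumed to be
        defined by the caller and already placed on a CUDA device.

        >>> optim = TorchAMPOptimizer(torch.optim.SGD(model.parameters(), lr=1e-3))
        >>> loss = model(inputs).sum()
        >>> optim.backward(loss)  # scale the loss, then run backward
        >>> optim.step()          # step through GradScaler and update the loss scale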
""" |
|
|
|
def __init__( |
|
self, |
|
optim: Optimizer, |
|
init_scale: float = 2.0**16, |
|
growth_factor: float = 2.0, |
|
backoff_factor: float = 0.5, |
|
growth_interval: int = 2000, |
|
) -> None: |
|
super().__init__(optim) |
|
self.scaler = torch.cuda.amp.GradScaler( |
|
init_scale=init_scale, |
|
growth_factor=growth_factor, |
|
backoff_factor=backoff_factor, |
|
growth_interval=growth_interval, |
|
) |
|
|
|
def backward(self, loss: Tensor, *args, **kwargs) -> None: |
|
scaled_loss = self.scale_loss(loss) |
|
scaled_loss.backward(*args, **kwargs) |
|
|
|
def step(self, *args, **kwargs) -> Optional[float]: |
|
out = self.scaler.step(self.optim, *args, **kwargs) |
|
self.scaler.update() |
|
return out |
|
|
|
def scale_loss(self, loss: Tensor) -> Tensor: |
|
return self.scaler.scale(loss) |
|
|
|
def unscale_grad(self) -> None: |
|
self.scaler.unscale_(self.optim) |
|
|
|
def clip_grad_by_value(self, clip_value: float, *args, **kwargs) -> None: |
|
self.unscale_grad() |
|
super().clip_grad_by_value(clip_value, *args, **kwargs) |
|
|
|
def clip_grad_by_norm( |
|
self, |
|
max_norm: Union[float, int], |
|
norm_type: Union[float, int] = 2.0, |
|
error_if_nonfinite: bool = False, |
|
*args, |
|
**kwargs, |
|
) -> None: |
|
self.unscale_grad() |
|
super().clip_grad_by_norm(max_norm, norm_type, error_if_nonfinite, *args, **kwargs) |
|
|
|
|
|
class TorchAMPModule(ModelWrapper):
    """
    Module wrapper for mixed precision training in FP16 using PyTorch AMP.

    Args:
        module (nn.Module): Module to wrap.
    """

    def __init__(self, module: nn.Module):
        super().__init__(module)

    def forward(self, *args, **kwargs):
        with autocast():
            return self.module(*args, **kwargs)


class FP16TorchMixedPrecision(MixedPrecision):
    """
    Precision for mixed precision training in FP16 using PyTorch AMP.

    Args:
        init_scale (float): Initial scale factor. Default: 2**16.
        growth_factor (float): Factor by which the scale is multiplied during
            :meth:`torch.cuda.amp.GradScaler.step` if gradients were found to be finite
            this iteration. Default: 2.0.
        backoff_factor (float): Factor by which the scale is multiplied during
            :meth:`torch.cuda.amp.GradScaler.step` if gradients were found to be infinite
            this iteration. Default: 0.5.
        growth_interval (int): Number of iterations between :meth:`torch.cuda.amp.GradScaler.step`
            calls that may cause the scale to increase. Default: 2000.
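
    Example:
        A minimal illustrative sketch; ``model``, ``optimizer`` and ``criterion`` are
        assumed to be an unwrapped ``nn.Module``, ``Optimizer`` and loss module built
        by the caller.

        >>> precision = FP16TorchMixedPrecision()
        >>> model, optimizer, criterion = precision.configure(model, optimizer, criterion)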
""" |
|
|
|
def __init__( |
|
self, |
|
init_scale: float = 2.0**16, |
|
growth_factor: float = 2.0, |
|
backoff_factor: float = 0.5, |
|
growth_interval: int = 2000, |
|
) -> None: |
|
super().__init__() |
|
self.torch_amp_kwargs = dict( |
|
init_scale=init_scale, |
|
growth_factor=growth_factor, |
|
backoff_factor=backoff_factor, |
|
growth_interval=growth_interval, |
|
) |
|
|
|
def configure( |
|
self, |
|
model: nn.Module, |
|
optimizer: Optional[Optimizer] = None, |
|
criterion: Optional[Callable] = None, |
|
) -> Tuple[nn.Module, OptimizerWrapper, Callable]: |
|
model = TorchAMPModule(model) |
|
if optimizer is not None: |
|
optimizer = TorchAMPOptimizer(optimizer, **self.torch_amp_kwargs) |
|
if criterion is not None: |
|
criterion = TorchAMPModule(criterion) |
|
return model, optimizer, criterion
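

# Illustrative usage sketch: wires a toy model, optimizer and loss through
# FP16TorchMixedPrecision.configure and runs a single AMP training step.
# A minimal, self-contained sketch with hypothetical toy shapes; it assumes a
# CUDA device is available so that autocast and GradScaler are active.
if __name__ == "__main__":
    device = torch.device("cuda")
    model = nn.Linear(16, 4).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    criterion = nn.MSELoss()

    # Wrap model, optimizer and criterion for FP16 autocast + loss scaling.
    precision = FP16TorchMixedPrecision()
    model, optimizer, criterion = precision.configure(model, optimizer, criterion)

    inputs = torch.randn(8, 16, device=device)
    targets = torch.randn(8, 4, device=device)

    loss = criterion(model(inputs), targets)   # forward passes run under autocast
    optimizer.backward(loss)                   # scale the loss, then run backward
    optimizer.clip_grad_by_norm(max_norm=1.0)  # unscale gradients before clipping
    optimizer.step()                           # GradScaler.step + scale update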