ColossalAI/colossalai/amp/naive_amp/mixed_precision_mixin/fp16.py

85 lines
2.8 KiB
Python
Raw Normal View History

from abc import abstractmethod
from enum import Enum
import torch
import torch.distributed as dist
from torch import Tensor
from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
from colossalai.utils import get_current_device
from .base import MixedPrecisionMixin
class OptimState(Enum):
    """Whether the gradients currently include the loss scale.

    ``SCALED`` (0) means the loss scale is applied to the gradients.
    ``UNSCALED`` (1) means it is not.
    """

    # Both members are bound in one statement; declaration order is preserved.
    SCALED, UNSCALED = 0, 1
class FP16MixedPrecisionMixin(MixedPrecisionMixin):
    """FP16 mixed-precision mixin built on a ``DynamicGradScaler``.

    The mixin keeps a small state machine in ``optim_state`` that records
    whether the gradients currently carry the loss scale. It also checks for
    overflow before each optimizer step and combines the result across ranks.
    Subclasses must implement :meth:`check_local_overflow`.

    All constructor arguments are forwarded unchanged to ``DynamicGradScaler``;
    see that class for their exact semantics.
    """

    dtype = torch.float16

    def __init__(self,
                 initial_scale: float = 2**16,
                 min_scale: float = 1,
                 growth_factor: float = 2,
                 backoff_factor: float = 0.5,
                 growth_interval: int = 1000,
                 hysteresis: int = 2,
                 max_scale: float = 2**32) -> None:
        super().__init__()
        self.grad_scaler = DynamicGradScaler(initial_scale=initial_scale,
                                             min_scale=min_scale,
                                             growth_factor=growth_factor,
                                             backoff_factor=backoff_factor,
                                             growth_interval=growth_interval,
                                             hysteresis=hysteresis,
                                             max_scale=max_scale)
        # No backward pass has run yet, so nothing is scaled.
        self.optim_state = OptimState.UNSCALED
        # One-element flag on the current device, reused by check_overflow():
        # 0.0 means no overflow, 1.0 means overflow.
        self.found_overflow = torch.zeros(1, dtype=torch.float, device=get_current_device())

    @property
    def loss_scale(self) -> float:
        """Current loss scale of the grad scaler, as a Python float (via ``.item()``)."""
        return self.grad_scaler.scale.item()

    @abstractmethod
    def check_local_overflow(self) -> bool:
        """Check whether there is overflow in the local process. This method should be implemented by subclasses.

        Returns:
            bool: Whether there is overflow in the local process.
        """
        pass

    def check_overflow(self) -> bool:
        """Return whether an overflow was found on any process.

        The result of :meth:`check_local_overflow` is written into
        ``found_overflow`` and then reduced with ``ReduceOp.MAX`` over the
        default process group. As a result, one overflowing rank makes every
        rank return ``True``.

        If ``torch.distributed`` is unavailable or not initialized, the local
        result is returned directly. Calling ``all_reduce`` in that state
        would raise.

        Returns:
            bool: Whether any participating process overflowed.
        """
        # clear previous overflow record
        self.found_overflow.fill_(0.0)
        if self.check_local_overflow():
            self.found_overflow.fill_(1.0)
        # In a single process without a process group, the local result is
        # already the global one.
        if dist.is_available() and dist.is_initialized():
            dist.all_reduce(self.found_overflow, op=dist.ReduceOp.MAX)
        return self.found_overflow.item() > 0

    def pre_backward(self, loss: Tensor) -> Tensor:
        """Multiply ``loss`` by the current loss scale and mark the state as SCALED.

        Args:
            loss: The loss tensor about to be back-propagated.

        Returns:
            Tensor: ``loss * self.loss_scale``.
        """
        loss = self.loss_scale * loss
        self.optim_state = OptimState.SCALED
        return loss

    def pre_backward_by_grad(self, tensor: Tensor, grad: Tensor) -> Tensor:
        """Mark the state as SCALED and return ``grad`` unchanged.

        NOTE(review): ``grad`` is not multiplied by the loss scale here.
        Presumably it already carries the scale (e.g. it was produced from a
        scaled loss elsewhere); confirm against callers. ``tensor`` is unused.

        Returns:
            Tensor: ``grad``, unmodified.
        """
        self.optim_state = OptimState.SCALED
        return grad

    def should_skip_step(self) -> bool:
        """Decide whether the optimizer step must be skipped because of overflow.

        The (possibly cross-rank) overflow flag is passed to the grad scaler's
        ``update()``. On overflow the state is reset to UNSCALED, so the
        skipped step does not leave a stale SCALED state behind.

        Returns:
            bool: ``True`` if an overflow was detected and the step should be skipped.
        """
        found_inf = self.check_overflow()
        self.grad_scaler.update(found_inf)
        if found_inf:
            self.optim_state = OptimState.UNSCALED
        return found_inf

    def pre_zero_grad(self) -> None:
        """No-op hook; FP16 needs no work before gradients are zeroed."""
        pass

    def get_grad_div_scale(self) -> float:
        """Return the loss scale and mark gradients as UNSCALED.

        The returned value is the factor callers are presumably expected to
        divide gradients by, e.g. before clipping (per the assertion message).
        The state flips to UNSCALED here because the caller takes over the
        unscaling.

        Raises:
            AssertionError: If the gradients are not currently marked SCALED.

        Returns:
            float: The current loss scale.
        """
        assert self.optim_state == OptimState.SCALED, 'grads should be scaled before clipping'
        self.optim_state = OptimState.UNSCALED
        return self.loss_scale