from abc import abstractmethod
from enum import Enum

import torch
import torch.distributed as dist
from torch import Tensor

from colossalai.accelerator import get_accelerator
from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler

from .base import MixedPrecisionMixin


class OptimState(Enum):
    """Tracks whether the gradients currently carry the loss-scale factor."""

    SCALED = 0
    UNSCALED = 1


class FP16MixedPrecisionMixin(MixedPrecisionMixin):
    """Mixed-precision mixin for fp16 training with dynamic loss scaling.

    - :meth:`pre_backward` multiplies the loss by the current loss scale.
    - :meth:`should_skip_step` checks for overflow across all ranks and
      updates the scaler.
    - :meth:`get_grad_div_scale` hands back the scale as the divisor used
      to unscale gradients.

    ``optim_state`` records whether the gradients are currently scaled.
    Subclasses must implement :meth:`check_local_overflow`.
    """

    dtype = torch.float16

    def __init__(
        self,
        initial_scale: float = 2**16,
        min_scale: float = 1,
        growth_factor: float = 2,
        backoff_factor: float = 0.5,
        growth_interval: int = 1000,
        hysteresis: int = 2,
        max_scale: float = 2**32,
    ) -> None:
        """Create the dynamic grad scaler and the overflow flag.

        All arguments are forwarded unchanged to ``DynamicGradScaler``; see
        that class for their exact semantics.
        """
        super().__init__()
        self.grad_scaler = DynamicGradScaler(
            initial_scale=initial_scale,
            min_scale=min_scale,
            growth_factor=growth_factor,
            backoff_factor=backoff_factor,
            growth_interval=growth_interval,
            hysteresis=hysteresis,
            max_scale=max_scale,
        )
        self.optim_state = OptimState.UNSCALED
        # One-element flag on the current accelerator device, reused by
        # check_overflow: 0.0 means no overflow, 1.0 means overflow.
        self.found_overflow = torch.zeros(1, dtype=torch.float, device=get_accelerator().get_current_device())

    @property
    def loss_scale(self) -> float:
        """Current loss scale of the grad scaler, as a Python float."""
        return self.grad_scaler.scale.item()

    @abstractmethod
    def check_local_overflow(self) -> bool:
        """Check whether there is overflow in the local process.

        This method should be implemented by subclasses.

        Returns:
            bool: Whether there is overflow in the local process.
        """

    def check_overflow(self) -> bool:
        """Return whether any rank detected an overflow.

        The local result from :meth:`check_local_overflow` is combined across
        ranks with an all-reduce MAX, so every rank reaches the same decision.

        When ``torch.distributed`` is unavailable or not initialized (a plain
        single-process run), the all-reduce is skipped. The local result is
        already the global one, and calling ``all_reduce`` without a process
        group would fail.

        Returns:
            bool: True if an overflow was found on any rank.
        """
        # Clear the previous overflow record before re-checking.
        self.found_overflow.fill_(0.0)
        if self.check_local_overflow():
            self.found_overflow.fill_(1.0)
        if dist.is_available() and dist.is_initialized():
            dist.all_reduce(self.found_overflow, op=dist.ReduceOp.MAX)
        return self.found_overflow.item() > 0

    def pre_backward(self, loss: Tensor) -> Tensor:
        """Scale ``loss`` by the current loss scale and mark grads as scaled.

        Args:
            loss: The loss tensor about to be back-propagated.

        Returns:
            Tensor: ``loss`` multiplied by :attr:`loss_scale`.
        """
        loss = self.loss_scale * loss
        self.optim_state = OptimState.SCALED
        return loss

    def pre_backward_by_grad(self, tensor: Tensor, grad: Tensor) -> Tensor:
        """Mark gradients as scaled and return ``grad`` unchanged.

        ``tensor`` is not used here.

        NOTE(review): ``grad`` is presumably already scaled upstream (for
        example by a scaled backward elsewhere), since no scaling is applied
        here. Confirm against callers.
        """
        self.optim_state = OptimState.SCALED
        return grad

    def should_skip_step(self) -> bool:
        """Check for global overflow and update the loss scaler.

        On overflow the state is reset to ``UNSCALED``, because these
        gradients will not go through the normal unscale path. Presumably the
        caller discards them.

        Returns:
            bool: True if the optimizer step should be skipped.
        """
        found_inf = self.check_overflow()
        self.grad_scaler.update(found_inf)
        if found_inf:
            self.optim_state = OptimState.UNSCALED
        return found_inf

    def pre_zero_grad(self) -> None:
        """Hook called before zeroing gradients; nothing to do for fp16."""
        pass

    def get_grad_div_scale(self) -> float:
        """Return the divisor that unscales the gradients and mark them unscaled.

        Raises:
            AssertionError: If the gradients are not currently scaled.
        """
        assert self.optim_state == OptimState.SCALED, "grads should be scaled before clipping"
        self.optim_state = OptimState.UNSCALED
        return self.loss_scale