mirror of https://github.com/hpcaitech/ColossalAI
refactored grad scaler (#338)
parent
6a3188167c
commit
3d5d64bd10
|
@ -0,0 +1,5 @@
|
|||
from .base_grad_scaler import BaseGradScaler
|
||||
from .constant_grad_scaler import ConstantGradScaler
|
||||
from .dynamic_grad_scaler import DynamicGradScaler
|
||||
|
||||
__all__ = ['BaseGradScaler', 'ConstantGradScaler', 'DynamicGradScaler']
|
|
@ -0,0 +1,46 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
import torch
|
||||
from abc import ABC, abstractmethod
|
||||
from colossalai.logging import get_dist_logger
|
||||
from torch import Tensor
|
||||
from typing import Dict
|
||||
|
||||
__all__ = ['BaseGradScaler']
|
||||
|
||||
|
||||
class BaseGradScaler(ABC):
|
||||
|
||||
def __init__(self, initial_scale: int, verbose: bool):
|
||||
assert initial_scale > 0
|
||||
self._scale = torch.cuda.FloatTensor([initial_scale])
|
||||
self._verbose = verbose
|
||||
|
||||
if self._verbose:
|
||||
self._logger = get_dist_logger()
|
||||
|
||||
@property
|
||||
def scale(self) -> Tensor:
|
||||
return self._scale
|
||||
|
||||
@property
|
||||
def inv_scale(self) -> Tensor:
|
||||
return self._scale.double().reciprocal().float()
|
||||
|
||||
@abstractmethod
|
||||
def state_dict(self) -> Dict:
|
||||
state_dict = dict()
|
||||
state_dict['scale'] = self.scale
|
||||
|
||||
@abstractmethod
|
||||
def load_state_dict(self, state_dict: Dict) -> None:
|
||||
self._scale = state_dict['scale']
|
||||
|
||||
@abstractmethod
|
||||
def update(self, overflow: bool) -> None:
|
||||
pass
|
||||
|
||||
def log(self, message, *args, **kwargs):
|
||||
if self._verbose:
|
||||
self._logger.info(message, *args, **kwargs)
|
|
@ -0,0 +1,16 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
from .base_grad_scaler import BaseGradScaler
|
||||
|
||||
__all__ = ['ConstantGradScaler']
|
||||
|
||||
|
||||
class ConstantGradScaler(BaseGradScaler):
|
||||
|
||||
def __init__(self, initial_scale: int, verbose: bool):
|
||||
super().__init__(initial_scale, verbose)
|
||||
self.log(f"Constant Gradient Scaler is initialized with scale {self.scale}", ranks=[0])
|
||||
|
||||
def update(self, overflow: bool) -> None:
|
||||
# do nothing to maintain the current scale value
|
||||
pass
|
|
@ -0,0 +1,68 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
import torch
|
||||
from .base_grad_scaler import BaseGradScaler
|
||||
|
||||
__all__ = ['DynamicGradScaler']
|
||||
|
||||
|
||||
class DynamicGradScaler(BaseGradScaler):
|
||||
|
||||
def __init__(self,
|
||||
initial_scale: int = 2**16,
|
||||
growth_factor: int = 2,
|
||||
backoff_factor: float = 0.5,
|
||||
growth_interval: int = 1000,
|
||||
min_scale: int = None,
|
||||
max_scale: int = None,
|
||||
hysteresis: int = None,
|
||||
verbose: bool = False):
|
||||
super().__init__(initial_scale, verbose)
|
||||
self._min_scale = min_scale
|
||||
self._max_scale = max_scale
|
||||
self._growth_factor = growth_factor
|
||||
self._backoff_factor = backoff_factor
|
||||
self._growth_interval = growth_interval
|
||||
self._growth_step = 0
|
||||
self._hysteresis = hysteresis
|
||||
self._hysteresis_step = 0
|
||||
self._sanity_checks()
|
||||
|
||||
def _sanity_checks(self) -> None:
|
||||
if self._min_scale:
|
||||
assert self._min_scale > 0, 'The minimum gradient scale cannot be zero or negative'
|
||||
if self._max_scale:
|
||||
assert self._min_scale > 0, 'The maximum gradient scale cannot be zero or negative'
|
||||
assert self._growth_factor > 1, 'The growth factor cannot be equal or smaller than 1'
|
||||
assert self._backoff_factor < 1 and self._backoff_factor > 0, 'The backoff factor must be between 0 and 1'
|
||||
assert self._hysteresis >= 0, 'The hysteresis cannot be negative'
|
||||
|
||||
def update(self, overflow: bool) -> None:
|
||||
if overflow:
|
||||
self._hysteresis_step += 1
|
||||
self._growth_step = 0
|
||||
|
||||
if self._hysteresis_step >= self._hysteresis:
|
||||
self._backoff_scale()
|
||||
self.log(f"Overflow occurs, the loss scale is adjusted to {self.scale.item()}", ranks=[0])
|
||||
else:
|
||||
self._growth_step += 1
|
||||
if self._growth_step == self._growth_interval:
|
||||
self._growth_step = 0
|
||||
self._hysteresis_step = 0
|
||||
self._grow_scale()
|
||||
self.log(
|
||||
f"No overflow for consecutive {self._growth_interval} steps, "
|
||||
f"the loss scale is adjusted to {self.scale.item()}",
|
||||
ranks=[0])
|
||||
|
||||
def _backoff_scale(self) -> None:
|
||||
self._scale = self._scale * self._backoff_factor
|
||||
if self._min_scale:
|
||||
self._scale = torch.max(self._scale, self._min_scale)
|
||||
|
||||
def _grow_scale(self) -> None:
|
||||
self._scale = self._scale * self._growth_factor
|
||||
if self._max_scale:
|
||||
self._scale = torch.min(self._scale, self._max_scale)
|
Loading…
Reference in New Issue