refactored grad scaler (#338)

pull/394/head
Frank Lee 2022-03-09 11:52:43 +08:00
parent 6a3188167c
commit 3d5d64bd10
4 changed files with 135 additions and 0 deletions

__init__.py

@@ -0,0 +1,5 @@
from .base_grad_scaler import BaseGradScaler
from .constant_grad_scaler import ConstantGradScaler
from .dynamic_grad_scaler import DynamicGradScaler

__all__ = ['BaseGradScaler', 'ConstantGradScaler', 'DynamicGradScaler']

base_grad_scaler.py

@@ -0,0 +1,46 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import torch
from abc import ABC, abstractmethod
from colossalai.logging import get_dist_logger
from torch import Tensor
from typing import Dict

__all__ = ['BaseGradScaler']


class BaseGradScaler(ABC):

    def __init__(self, initial_scale: int, verbose: bool):
        assert initial_scale > 0
        self._scale = torch.cuda.FloatTensor([initial_scale])
        self._verbose = verbose

        if self._verbose:
            self._logger = get_dist_logger()

    @property
    def scale(self) -> Tensor:
        return self._scale

    @property
    def inv_scale(self) -> Tensor:
        # compute the reciprocal in double precision to limit rounding error
        return self._scale.double().reciprocal().float()

    def state_dict(self) -> Dict:
        state_dict = dict()
        state_dict['scale'] = self.scale
        return state_dict

    def load_state_dict(self, state_dict: Dict) -> None:
        self._scale = state_dict['scale']

    @abstractmethod
    def update(self, overflow: bool) -> None:
        # subclasses decide how the scale reacts to an overflow flag
        pass

    def log(self, message, *args, **kwargs):
        if self._verbose:
            self._logger.info(message, *args, **kwargs)

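For context, `state_dict` and `load_state_dict` give every scaler a uniform checkpoint format: a dict with a single 'scale' entry holding the scale tensor. A minimal round-trip sketch, assuming a CUDA device and one of the concrete scalers defined in the files below; the checkpoint path is illustrative and not part of this commit:

    import torch

    scaler = ConstantGradScaler(2**16, verbose=False)
    torch.save(scaler.state_dict(), 'grad_scaler.pt')      # illustrative checkpoint path

    resumed = ConstantGradScaler(2**8, verbose=False)
    resumed.load_state_dict(torch.load('grad_scaler.pt'))  # the restored scale replaces the constructor value
    assert resumed.scale.item() == 2**16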
constant_grad_scaler.py

@@ -0,0 +1,16 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

from .base_grad_scaler import BaseGradScaler

__all__ = ['ConstantGradScaler']


class ConstantGradScaler(BaseGradScaler):

    def __init__(self, initial_scale: int, verbose: bool):
        super().__init__(initial_scale, verbose)
        self.log(f"Constant Gradient Scaler is initialized with scale {self.scale}", ranks=[0])

    def update(self, overflow: bool) -> None:
        # do nothing to maintain the current scale value
        pass

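Because `update` is a no-op, the scale reported to the training loop never changes, which is the intended contrast with the dynamic scaler in the next file. A small sketch of that behaviour (argument values are arbitrary, CUDA device assumed):

    scaler = ConstantGradScaler(128, verbose=False)
    for overflow in (True, False, True):
        scaler.update(overflow)            # no-op by design
    assert scaler.scale.item() == 128.0    # the scale is never adjusted
    assert scaler.inv_scale.item() == 1 / 128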
dynamic_grad_scaler.py

@@ -0,0 +1,68 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import torch

from .base_grad_scaler import BaseGradScaler

__all__ = ['DynamicGradScaler']


class DynamicGradScaler(BaseGradScaler):

    def __init__(self,
                 initial_scale: int = 2**16,
                 growth_factor: int = 2,
                 backoff_factor: float = 0.5,
                 growth_interval: int = 1000,
                 min_scale: int = None,
                 max_scale: int = None,
                 hysteresis: int = 2,
                 verbose: bool = False):
        super().__init__(initial_scale, verbose)
        # keep the optional bounds as CUDA tensors so that torch.max/torch.min
        # can clamp the scale tensor directly
        if min_scale:
            self._min_scale = torch.cuda.FloatTensor([min_scale])
        else:
            self._min_scale = None
        if max_scale:
            self._max_scale = torch.cuda.FloatTensor([max_scale])
        else:
            self._max_scale = None

        self._growth_factor = growth_factor
        self._backoff_factor = backoff_factor
        self._growth_interval = growth_interval
        self._growth_step = 0
        self._hysteresis = hysteresis
        self._hysteresis_step = 0
        self._sanity_checks()

    def _sanity_checks(self) -> None:
        if self._min_scale:
            assert self._min_scale > 0, 'The minimum gradient scale cannot be zero or negative'
        if self._max_scale:
            assert self._max_scale > 0, 'The maximum gradient scale cannot be zero or negative'
        assert self._growth_factor > 1, 'The growth factor cannot be equal or smaller than 1'
        assert 0 < self._backoff_factor < 1, 'The backoff factor must be between 0 and 1'
        assert self._hysteresis >= 0, 'The hysteresis cannot be negative'

    def update(self, overflow: bool) -> None:
        if overflow:
            self._hysteresis_step += 1
            self._growth_step = 0

            # only back off after `hysteresis` consecutive overflow steps
            if self._hysteresis_step >= self._hysteresis:
                self._backoff_scale()
                self.log(f"Overflow occurs, the loss scale is adjusted to {self.scale.item()}", ranks=[0])
        else:
            self._growth_step += 1

            if self._growth_step == self._growth_interval:
                self._growth_step = 0
                self._hysteresis_step = 0
                self._grow_scale()
                self.log(
                    f"No overflow for consecutive {self._growth_interval} steps, "
                    f"the loss scale is adjusted to {self.scale.item()}",
                    ranks=[0])

    def _backoff_scale(self) -> None:
        self._scale = self._scale * self._backoff_factor
        if self._min_scale:
            self._scale = torch.max(self._scale, self._min_scale)

    def _grow_scale(self) -> None:
        self._scale = self._scale * self._growth_factor
        if self._max_scale:
            self._scale = torch.min(self._scale, self._max_scale)
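
The call pattern suggested by this API (not shown in the diff itself) is: scale the loss by `scaler.scale` before backward, detect overflow in the gradients, report it via `scaler.update(overflow)`, and unscale gradients with `scaler.inv_scale` before the optimizer step. A sketch under those assumptions; the model, optimizer, overflow check, and import of `DynamicGradScaler` are illustrative and a CUDA device is required because the scale lives on the GPU:

    import torch
    import torch.nn.functional as F

    model = torch.nn.Linear(16, 16).cuda().half()      # illustrative fp16 model
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    scaler = DynamicGradScaler(initial_scale=2**16, growth_interval=1000, hysteresis=2)

    def train_step(data, target):
        optimizer.zero_grad()
        loss = F.mse_loss(model(data), target)
        (loss * scaler.scale).backward()               # scale the loss so fp16 gradients do not underflow

        # illustrative overflow check: any non-finite gradient counts as an overflow
        overflow = any(p.grad is not None and not torch.isfinite(p.grad).all()
                       for p in model.parameters())
        scaler.update(overflow)                        # back off after repeated overflows, grow after a quiet interval
        if overflow:
            return                                     # skip the optimizer step on overflow

        inv_scale = scaler.inv_scale.item()
        for p in model.parameters():                   # unscale gradients before applying them
            if p.grad is not None:
                p.grad.mul_(inv_scale)
        optimizer.step()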