[doc] improved docstring in the amp module (#857)

pull/867/head
Frank Lee 2022-04-25 13:42:17 +08:00 committed by GitHub
parent b862d89d00
commit 9fdebadd69
9 changed files with 147 additions and 10 deletions

View File

@ -11,6 +11,9 @@ def convert_to_apex_amp(model: nn.Module, optimizer: Optimizer, amp_config):
optimizer (:class:`torch.optim.Optimizer`): your optimizer object.
amp_config (Union[:class:`colossalai.context.Config`, dict]): configuration for initializing apex_amp.
Returns:
Tuple: A tuple (model, optimizer).
The ``amp_config`` should include the parameters below:
::
@ -27,9 +30,6 @@ def convert_to_apex_amp(model: nn.Module, optimizer: Optimizer, amp_config):
min_loss_scale (float, default=None)
max_loss_scale (float, default=2.**24)
Returns:
Tuples: A tuple (model, optimizer).
For more details about ``amp_config``, refer to `amp_config <https://nvidia.github.io/apex/amp.html?highlight=apex%20amp>`_.
"""
import apex.amp as apex_amp
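
For reference, a minimal usage sketch of this converter (the import path and the config values below are assumptions for illustration, not taken from this commit)::

    import torch
    import torch.nn as nn
    from colossalai.amp.apex_amp import convert_to_apex_amp

    model = nn.Linear(16, 16).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

    # keys mirror the parameters documented above; values are illustrative only
    amp_config = dict(opt_level='O1', loss_scale='dynamic', max_loss_scale=2.**24)
    model, optimizer = convert_to_apex_amp(model, optimizer, amp_config)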

View File

@ -28,7 +28,7 @@ class ApexAMPOptimizer(ColossalaiOptimizer):
scaled_loss.backward()
def clip_grad_norm(self, model: nn.Module, max_norm: float):
"""Clip gradients' norm
"""Clip gradients by norm
Args:
model (torch.nn.Module): Your model object

View File

@ -17,6 +17,8 @@ def convert_to_naive_amp(model: nn.Module, optimizer: Optimizer, amp_config):
optimizer (:class:`torch.optim.Optimizer`): your optimizer object
amp_config (:class:`colossalai.context.Config` or dict): configuration for naive mode amp.
Returns:
Tuple: A tuple (model, optimizer)
The ``amp_config`` should contain the parameters below::
@ -24,9 +26,6 @@ def convert_to_naive_amp(model: nn.Module, optimizer: Optimizer, amp_config):
clip_grad_norm (float, optional): clip gradients with this global L2 norm (Default 0).
Note that clipping is ignored if clip_grad == 0.
dynamic_grad_scale (bool): whether to use dynamic grad scaler.
Returns:
Tuples: A tuple (model, optimizer)
"""
if isinstance(model, nn.ModuleList):
# interleaved pipeline
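
As an illustration, this converter could be called roughly as follows (the import path is an assumption; only config keys listed above are used and their values are placeholders)::

    import torch
    import torch.nn as nn
    from colossalai.amp.naive_amp import convert_to_naive_amp

    model = nn.Linear(16, 16).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

    amp_config = dict(clip_grad_norm=1.0, dynamic_grad_scale=True)
    model, optimizer = convert_to_naive_amp(model, optimizer, amp_config)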

View File

@ -152,18 +152,39 @@ class FP16Optimizer(Optimizer):
@property
def grad_scaler(self):
"""Returns the gradient scaler.
Returns:
:class:`BaseGradScaler`: gradient scaler.
"""
return self._grad_scaler
@property
def loss_scale(self):
"""Returns the loss scale.
Returns:
:class:`torch.Tensor`: the current loss scale.
"""
return self._grad_scaler.scale
@property
def optimizer(self):
"""Returns the optimizer.
Returns:
:class:`torch.optim.Optimizer`: the wrapped optimizer object.
"""
return self._optimizer
@property
def defaults(self):
"""Returns the default arguments of optimizer.
Returns:
dict: optimizer arguments saved in defaults of the optimizer wrapped.
"""
return self._defaults
def _check_overflow(self):
@ -188,6 +209,12 @@ class FP16Optimizer(Optimizer):
return self._found_overflow.item() > 0
def zero_grad(self, set_to_none=True):
"""Set gradient to zero.
Args:
set_to_none (bool): Whether to set the gradient to None.
"""
# set_to_none = True can save some memory space
for param_group in self._optimizer.param_groups:
zero_gard_by_list(param_group['params'], set_to_none=set_to_none)
@ -222,6 +249,9 @@ class FP16Optimizer(Optimizer):
overflow_buf=self._dummy_overflow_buf)
def step(self):
"""Update the model parameters.
"""
# Copy gradients from model params to main params.
self._assign_grad_to_fp32_master_param()
self._unscale_grads()
@ -248,10 +278,19 @@ class FP16Optimizer(Optimizer):
return True, grad_norm
def backward(self, loss):
"""Execute backward pass.
Args:
loss (:class:`torch.Tensor`): the loss value.
"""
scaled_loss = loss * self.grad_scaler.scale
scaled_loss.backward()
def state_dict(self):
"""Returns the states of the fp16 optimizer as a dict object.
"""
state_dict = {}
state_dict['optimizer'] = self._optimizer.state_dict()
if self.grad_scaler:
@ -260,6 +299,12 @@ class FP16Optimizer(Optimizer):
return state_dict
def load_state_dict(self, state_dict):
"""Load the states of the fp16 optimizer from a dict object.
Args:
state_dict (dict): the states of the fp16 optimizer
"""
# Optimizer.
self._optimizer.load_state_dict(state_dict['optimizer'])
@ -275,6 +320,11 @@ class FP16Optimizer(Optimizer):
current_param.data.copy_(ckpt_param.data)
def clip_grad_norm(self, clip_grad):
"""Clip gradients by norm.
Args:
clip_grad (float): the max norm for clipping
"""
params = []
for param_group in self._optimizer.param_groups:
for param in param_group['params']:
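
Taken together, a typical training step with ``FP16Optimizer`` might look like the sketch below, assuming ``optim`` is the wrapped optimizer returned by ``convert_to_naive_amp`` and that ``model``, ``criterion`` and ``dataloader`` are placeholders::

    for data, target in dataloader:
        optim.zero_grad()                  # set_to_none=True frees gradient memory
        loss = criterion(model(data), target)
        optim.backward(loss)               # scales the loss before calling backward
        optim.step()                       # unscale grads, check overflow, update params

    # checkpointing mirrors torch.optim.Optimizer
    ckpt = optim.state_dict()
    optim.load_state_dict(ckpt)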

View File

@ -3,6 +3,14 @@ from torch import Tensor
def has_inf_or_nan(tensor):
"""Check if tensor has inf or nan values.
Args:
tensor (:class:`torch.Tensor`): a torch tensor object
Returns:
bool: True if the tensor contains inf or nan values, otherwise False.
"""
try:
# if tensor is half, the .float() incurs an additional deep copy, but it's necessary if
# Pytorch's .sum() creates a one-element tensor of the same type as tensor
@ -24,8 +32,8 @@ def has_inf_or_nan(tensor):
def zero_gard_by_list(tensor_list: List[Tensor], set_to_none: bool = True) -> None:
"""
Clear the gradient of a list of tensors,
"""Clear the gradient of a list of tensors,
Note: copied from torch.optim.optimizer.
"""
for param in tensor_list:
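
A small behavioural sketch of these two helpers (the module path in the import is an assumption, since the file name is not shown here)::

    import torch
    from colossalai.amp.naive_amp._utils import has_inf_or_nan, zero_gard_by_list

    t = torch.tensor([1.0, float('inf')])
    print(has_inf_or_nan(t))      # True: the tensor contains an inf value

    p = torch.nn.Parameter(torch.ones(4))
    p.grad = torch.ones(4)
    zero_gard_by_list([p], set_to_none=True)
    print(p.grad)                 # None: the gradient storage was released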

View File

@ -11,6 +11,12 @@ __all__ = ['BaseGradScaler']
class BaseGradScaler(ABC):
"""A base class for the gradient scaler.
Args:
initial_scale (float): the initial loss scale
verbose (bool): whether to log messages
"""
def __init__(self, initial_scale: float, verbose: bool):
assert initial_scale > 0
@ -22,24 +28,53 @@ class BaseGradScaler(ABC):
@property
def scale(self) -> Tensor:
"""Returns the loss scale.
"""
return self._scale
@property
def inv_scale(self) -> Tensor:
"""Returns the inverse of the loss scale.
"""
return self._scale.double().reciprocal().float()
def state_dict(self) -> Dict:
"""Returns the states of the gradient scaler as a dict object.
"""
state_dict = dict()
state_dict['scale'] = self.scale
return state_dict
def load_state_dict(self, state_dict: Dict) -> None:
"""Load the states of the gradient scaler from a dict object.
Args:
state_dict (dict): the states of the gradient scaler
"""
self._scale = state_dict['scale']
@abstractmethod
def update(self, overflow: bool) -> None:
"""Update the loss scale.
Args:
overflow (bool): whether overflow occurs
"""
pass
def log(self, message, *args, **kwargs):
"""Log messages.
Args:
message (str): the message to log
*args: positional arguments for :class:`colossalai.logging.DistributedLogger`
**kwargs: key-word arguments for :class:`colossalai.logging.DistributedLogger`
"""
if self._verbose:
self._logger.info(message, *args, **kwargs)
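
Because ``update`` is abstract, concrete scalers subclass ``BaseGradScaler`` and supply an update rule. The snippet below is an illustrative subclass written for this note, not one shipped with the library, and the import path is assumed::

    from colossalai.amp.naive_amp.grad_scaler import BaseGradScaler

    class HalvingGradScaler(BaseGradScaler):
        """Halve the loss scale on overflow, otherwise keep it (illustration only)."""

        def update(self, overflow: bool) -> None:
            if overflow:
                self._scale = self._scale * 0.5
                self.log(f'loss scale reduced to {self.scale.item()}', ranks=[0])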

View File

@ -6,11 +6,21 @@ __all__ = ['ConstantGradScaler']
class ConstantGradScaler(BaseGradScaler):
"""A gradient scaler which uses constant loss scale
Args:
initial_scale (float): the initial loss scale
verbose (bool): whether to log messages
"""
def __init__(self, initial_scale: int, verbose: bool):
super().__init__(initial_scale, verbose)
self.log(f"Constant Gradient Scaler is initialized with scale {self.scale}", ranks=[0])
def update(self, overflow: bool) -> None:
# do nothing to maintain the current scale value
"""Do nothing to keep the loss scale constant.
Args:
overflow (bool): whether overflow occurs
"""
pass
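
A minimal sketch of the resulting behaviour, assuming the class is importable from the grad scaler package: calling ``update`` never changes the scale::

    from colossalai.amp.naive_amp.grad_scaler import ConstantGradScaler

    scaler = ConstantGradScaler(initial_scale=2**16, verbose=False)
    scaler.update(overflow=True)
    scaler.update(overflow=False)
    print(scaler.scale)   # still 2**16 (stored internally as a tensor)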

View File

@ -9,6 +9,18 @@ __all__ = ['DynamicGradScaler']
class DynamicGradScaler(BaseGradScaler):
"""A gradient scaler which uses dynamic loss scale
Args:
initial_scale (float): the initial loss scale, defaults to 2**16
growth_factor (float): the multiplication factor for increasing loss scale, defaults to 2
backoff_factor (float): the multiplication factor for decreasing loss scale, defaults to 0.5
growth_interval (int): the number of steps to increase loss scale when no overflow occurs, defaults to 1000
min_scale (float): the minimum loss scale, defaults to None
max_scale (float): the maximum loss scale, defaults to None
hysteresis (int): the number of overflows before decreasing loss scale, defaults to 2
verbose (bool): whether to log messages, defaults to False
"""
def __init__(self,
initial_scale: float = 2**16,
@ -39,6 +51,9 @@ class DynamicGradScaler(BaseGradScaler):
self._sanity_checks()
def _sanity_checks(self) -> None:
"""Check if the arguments are correct.
"""
if self._min_scale:
assert self._min_scale > 0, 'The minimum gradient scale cannot be zero or negative'
if self._max_scale:
@ -48,6 +63,11 @@ class DynamicGradScaler(BaseGradScaler):
assert self._hysteresis >= 0, 'The hysteresis cannot be negative'
def update(self, overflow: bool) -> None:
"""Update the loss scale.
Args:
overflow (bool): whether overflow occurs
"""
if overflow:
self._hysteresis_step += 1
self._growth_step = 0
@ -67,11 +87,17 @@ class DynamicGradScaler(BaseGradScaler):
ranks=[0])
def _backoff_scale(self) -> None:
"""Decrease the loss scale
"""
self._scale = self._scale * self._backoff_factor
if self._min_scale:
self._scale = torch.max(self._scale, self._min_scale)
def _grow_scale(self) -> None:
"""Increase the loss scale
"""
self._scale = self._scale * self._growth_factor
if self._max_scale:
self._scale = torch.min(self._scale, self._max_scale)
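
As a rough behavioural sketch (import path assumed, values illustrative): with ``hysteresis=2`` the scale only backs off after two overflows, and grows again once ``growth_interval`` overflow-free steps have passed::

    from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler

    scaler = DynamicGradScaler(initial_scale=2**16,
                               growth_factor=2,
                               backoff_factor=0.5,
                               growth_interval=1000,
                               hysteresis=2,
                               verbose=False)

    scaler.update(overflow=True)    # 1st overflow: within hysteresis, scale kept
    scaler.update(overflow=True)    # 2nd overflow: scale backs off to 2**15
    scaler.update(overflow=False)   # no overflow: growth counter restarts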

View File

@ -62,6 +62,9 @@ class TorchAMPOptimizer(ColossalaiOptimizer):
class TorchAMPModel(nn.Module):
"""A wrapper class for a model object which executes forward with values automatically
cast to fp16
Args:
model (:class:`torch.nn.Module`): a torch model instance
"""
def __init__(self, model: nn.Module) -> None:
@ -70,6 +73,9 @@ class TorchAMPModel(nn.Module):
@torch_amp.autocast()
def forward(self, *args, **kwargs):
"""
Execute forward under the torch amp context
"""
return self.model(*args, **kwargs)
@ -86,4 +92,7 @@ class TorchAMPLoss(nn.Module):
@torch_amp.autocast()
def forward(self, *args, **kwargs):
"""
Execute forward under the torch amp context
"""
return self.loss(*args, **kwargs)
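
For illustration, the two wrappers above could be combined as in the sketch below (import path, model and data are assumptions); both forwards then run inside ``torch.cuda.amp.autocast``::

    import torch
    import torch.nn as nn
    from colossalai.amp.torch_amp import TorchAMPModel, TorchAMPLoss

    model = TorchAMPModel(nn.Linear(16, 4).cuda())
    criterion = TorchAMPLoss(nn.CrossEntropyLoss())

    x = torch.randn(8, 16, device='cuda')
    y = torch.randint(0, 4, (8,), device='cuda')

    loss = criterion(model(x), y)   # computed under autocast, typically in fp16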