ColossalAI/colossalai/amp/naive_amp/grad_scaler/base_grad_scaler.py

#!/usr/bin/env python
# -*- encoding: utf-8 -*-

from abc import ABC, abstractmethod
from typing import Dict

import torch
from torch import Tensor

from colossalai.logging import get_dist_logger

__all__ = ['BaseGradScaler']


class BaseGradScaler(ABC):
    """A base class for the gradient scaler.

    Args:
        initial_scale (float): the initial loss scale
        verbose (bool): whether to log messages
    """

    def __init__(self, initial_scale: float, verbose: bool):
        assert initial_scale > 0
        self._scale = torch.cuda.FloatTensor([initial_scale])
        self._verbose = verbose

        if self._verbose:
            self._logger = get_dist_logger()

    @property
    def scale(self) -> Tensor:
        """Returns the loss scale."""
        return self._scale

    @property
    def inv_scale(self) -> Tensor:
        """Returns the inverse of the loss scale."""
        return self._scale.double().reciprocal().float()
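
    # NOTE: callers typically multiply gradients by `inv_scale` rather than
    # dividing by `scale`; the reciprocal is taken in double precision above,
    # presumably to limit rounding error before casting back to float32.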

    def state_dict(self) -> Dict:
        """Returns the states of the gradient scaler as a dict object."""
        state_dict = dict()
        state_dict['scale'] = self.scale
        return state_dict

    def load_state_dict(self, state_dict: Dict) -> None:
        """Load the states of the gradient scaler from a dict object.

        Args:
            state_dict (dict): the states of the gradient scaler
        """
        self._scale = state_dict['scale']
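
    # Illustrative note (not part of the original file): the state dict can be
    # round-tripped through a checkpoint alongside model/optimizer state, e.g.
    # torch.save({'grad_scaler': scaler.state_dict()}, path) followed by
    # scaler.load_state_dict(torch.load(path)['grad_scaler']).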

    @abstractmethod
    def update(self, overflow: bool) -> None:
        """Update the loss scale.

        Args:
            overflow (bool): whether overflow occurs
        """
        pass

    def log(self, message, *args, **kwargs):
        """Log messages.

        Args:
            message (str): the message to log
            *args: positional arguments for :class:`colossalai.logging.DistributedLogger`
            **kwargs: keyword arguments for :class:`colossalai.logging.DistributedLogger`
        """
        if self._verbose:
            self._logger.info(message, *args, **kwargs)
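

# ----------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original file). ColossalAI
# provides concrete scalers elsewhere in this package; the hypothetical
# subclass below merely shows how the abstract `update` hook is filled in and
# how a scaler is typically driven from a training loop.
# ----------------------------------------------------------------------------
class HalvingGradScaler(BaseGradScaler):
    """Hypothetical scaler that halves the loss scale whenever overflow occurs."""

    def update(self, overflow: bool) -> None:
        if overflow:
            self._scale = self._scale / 2
            self.log(f"overflow detected, loss scale reduced to {self.scale.item()}")


# Typical training-loop usage (assuming `loss`, `found_inf`, and an optimizer
# whose gradients are unscaled with `inv_scale` before stepping):
#
#   scaler = HalvingGradScaler(initial_scale=2.0 ** 16, verbose=True)
#   (loss * scaler.scale).backward()          # scale the loss before backward
#   ...                                       # multiply grads by scaler.inv_scale
#   scaler.update(overflow=found_inf)         # adjust the scale for the next step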