ColossalAI/colossalai/amp/naive_amp/grad_scaler/base_grad_scaler.py

83 lines
2.0 KiB
Python
Raw Normal View History

2022-03-09 03:52:43 +00:00
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import math
from abc import ABC, abstractmethod
from typing import Dict

import torch
from torch import Tensor

from colossalai.logging import get_dist_logger
2022-03-09 03:52:43 +00:00
# Public API of this module: only the abstract base class is exported.
__all__ = ['BaseGradScaler']
class BaseGradScaler(ABC):
    """A base class for the gradient scaler.

    The loss scale is stored as a one-element ``float32`` tensor on the current
    CUDA device. Subclasses implement :meth:`update` to decide how the scale
    changes after each step.

    Args:
        initial_scale (float): the initial loss scale; must be positive and finite
        verbose (bool): whether to log messages

    Raises:
        ValueError: if ``initial_scale`` is not a positive, finite number.
    """

    def __init__(self, initial_scale: float, verbose: bool):
        # Explicit check rather than ``assert`` (which is stripped under
        # ``python -O``). ``math.isfinite`` additionally rejects ``inf``, which
        # ``> 0`` alone accepts and whose reciprocal (see ``inv_scale``) is 0;
        # NaN already fails the ``> 0`` comparison.
        if not (initial_scale > 0 and math.isfinite(initial_scale)):
            raise ValueError(f'initial_scale must be a positive finite number, got {initial_scale!r}')

        # Same tensor as the legacy ``torch.cuda.FloatTensor([initial_scale])``
        # constructor, whose use recent PyTorch releases discourage.
        self._scale = torch.tensor([initial_scale], dtype=torch.float32, device='cuda')
        self._verbose = verbose

        # The logger is only created (and only used by ``log``) when verbose.
        if self._verbose:
            self._logger = get_dist_logger()

    @property
    def scale(self) -> Tensor:
        """Returns the loss scale.

        Returns:
            Tensor: the scaler's internal scale tensor itself (not a copy).
        """
        return self._scale

    @property
    def inv_scale(self) -> Tensor:
        """Returns the inverse of the loss scale.

        The reciprocal is computed in ``float64`` and cast back to ``float32``
        (presumably to limit rounding error in the reciprocal).

        Returns:
            Tensor: a new ``float32`` tensor holding ``1 / scale``.
        """
        return self._scale.double().reciprocal().float()

    def state_dict(self) -> Dict:
        """Returns the states of the gradient scaler as a dict object.

        Returns:
            dict: ``{'scale': scale}``, where ``scale`` is the live scale
            tensor (not a copy).
        """
        return {'scale': self.scale}

    def load_state_dict(self, state_dict: Dict) -> None:
        """Load the states of the gradient scaler from a dict object.

        The stored scale is converted to the dtype, device and shape of the
        current scale and then copied. So a checkpoint loaded with
        ``map_location='cpu'``, a ``float64`` tensor or a plain Python number
        does not change the scaler's dtype or device, and the scaler never
        aliases a tensor owned by the caller.

        Args:
            state_dict (dict): the states of the gradient scaler, as produced
                by :meth:`state_dict`

        Raises:
            KeyError: if ``state_dict`` has no ``'scale'`` entry.
            RuntimeError: if the stored scale does not have the same number of
                elements as the current scale.
        """
        loaded = torch.as_tensor(state_dict['scale'], dtype=self._scale.dtype, device=self._scale.device)
        # ``as_tensor`` returns the very same object when no conversion is
        # needed, so clone to avoid sharing storage with the caller's tensor.
        self._scale = loaded.detach().clone().reshape(self._scale.shape)

    @abstractmethod
    def update(self, overflow: bool) -> None:
        """Update the loss scale.

        Args:
            overflow (bool): whether overflow occurs
        """
        pass

    def log(self, message, *args, **kwargs):
        """Log messages.

        Messages are forwarded to the distributed logger's ``info`` method only
        when the scaler was created with ``verbose=True``; otherwise this is a
        no-op.

        Args:
            message (str): the message to log
            *args: positional arguments for :class:`colossalai.logging.DistributedLogger`
            **kwargs: key-word arguments for :class:`colossalai.logging.DistributedLogger`
        """
        if self._verbose:
            self._logger.info(message, *args, **kwargs)