ColossalAI/colossalai/amp/naive_amp/grad_scaler/base_grad_scaler.py

79 lines
2.1 KiB
Python

#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from abc import ABC, abstractmethod
from typing import Dict
import torch
from torch import Tensor
from colossalai.accelerator import get_accelerator
from colossalai.logging import get_dist_logger
__all__ = ["BaseGradScaler"]
class BaseGradScaler(ABC):
"""A base class for the gradient scaler.
Args:
initial_scale (float): the initial loss scale
verbose (bool): whether to log messages
"""
def __init__(self, initial_scale: float, verbose: bool):
assert initial_scale > 0
self._scale = torch.tensor([initial_scale], device=get_accelerator().get_current_device(), dtype=torch.float)
self._verbose = verbose
if self._verbose:
self._logger = get_dist_logger()
@property
def scale(self) -> Tensor:
"""Returns the loss scale."""
return self._scale
@property
def inv_scale(self) -> Tensor:
"""Returns the inverse of the loss scale."""
return self._scale.double().reciprocal().float()
def state_dict(self) -> Dict:
"""Returns the states of the gradient scaler as a dict object."""
state_dict = dict()
state_dict["scale"] = self.scale
return state_dict
def load_state_dict(self, state_dict: Dict) -> None:
"""Load the states of the gradient scaler from a dict object.
Args:
state_dict (dict): the states of the gradient scaler
"""
self._scale = state_dict["scale"]
@abstractmethod
def update(self, overflow: bool) -> None:
"""Update the loss scale.
Args:
overflow (bool): whether overflow occurs
"""
def log(self, message, *args, **kwargs):
"""Log messages.
Args:
message (str): the message to log
*args: positional arguments for :class:`colossalai.logging.DistributedLogger`
**kwargs: key-word arguments for :class:`colossalai.logging.DistributedLogger`
"""
if self._verbose:
self._logger.info(message, *args, **kwargs)