#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import torch
from abc import ABC, abstractmethod
from colossalai.logging import get_dist_logger
from torch import Tensor
from typing import Dict

__all__ = ['BaseGradScaler']


class BaseGradScaler(ABC):
    """A base class for the gradient scaler.

    The loss scale is held as a one-element float32 tensor on the current
    CUDA device. Subclasses decide how the scale evolves by implementing
    :meth:`update`.

    Args:
        initial_scale (float): the initial loss scale; must be positive
        verbose (bool): whether to log messages
    """

    def __init__(self, initial_scale: float, verbose: bool):
        assert initial_scale > 0, f'initial_scale must be positive, got {initial_scale}'
        # torch.cuda.FloatTensor(...) is a legacy constructor that recent
        # PyTorch releases warn about; this builds the same one-element
        # float32 tensor on the current CUDA device.
        self._scale = torch.tensor([initial_scale], dtype=torch.float32, device='cuda')
        self._verbose = verbose

        # The logger only exists when verbose; log() checks self._verbose
        # before using it.
        if self._verbose:
            self._logger = get_dist_logger()

    @property
    def scale(self) -> Tensor:
        """Returns the loss scale.
        """
        return self._scale

    @property
    def inv_scale(self) -> Tensor:
        """Returns the inverse of the loss scale.
        """
        # The reciprocal is taken in float64 and cast back to float32.
        return self._scale.double().reciprocal().float()

    def state_dict(self) -> Dict:
        """Returns the states of the gradient scaler as a dict object.

        The ``'scale'`` entry is the scaler's own scale tensor, not a copy.
        """
        return {'scale': self.scale}

    def load_state_dict(self, state_dict: Dict) -> None:
        """Load the states of the gradient scaler from a dict object.

        The stored scale is converted to the dtype and device of the current
        scale, so a checkpoint loaded with ``map_location='cpu'`` (or holding
        a plain number) still yields a float32 CUDA scale tensor.

        Args:
            state_dict (dict): the states of the gradient scaler
        """
        self._scale = torch.as_tensor(state_dict['scale'],
                                      dtype=self._scale.dtype,
                                      device=self._scale.device)

    @abstractmethod
    def update(self, overflow: bool) -> None:
        """Update the loss scale.

        Args:
            overflow (bool): whether overflow occurs
        """
        pass

    def log(self, message, *args, **kwargs):
        """Log messages.

        Messages are only emitted when the scaler was created with
        ``verbose=True``; otherwise this is a no-op.

        Args:
            message (str): the message to log
            *args: positional arguments for :class:`colossalai.logging.DistributedLogger`
            **kwargs: key-word arguments for :class:`colossalai.logging.DistributedLogger`
        """
        if self._verbose:
            self._logger.info(message, *args, **kwargs)