mirror of https://github.com/hpcaitech/ColossalAI
81 lines
2.0 KiB
Python
81 lines
2.0 KiB
Python
#!/usr/bin/env python
|
|
# -*- encoding: utf-8 -*-
|
|
|
|
import torch
|
|
from abc import ABC, abstractmethod
|
|
from colossalai.logging import get_dist_logger
|
|
from torch import Tensor
|
|
from typing import Dict
|
|
|
|
__all__ = ['BaseGradScaler']
|
|
|
|
|
|
class BaseGradScaler(ABC):
    """A base class for the gradient scaler.

    The loss scale is stored as a single-element float tensor. Subclasses
    implement :meth:`update` to decide how the scale changes when an overflow
    is (or is not) observed.

    Args:
        initial_scale (float): the initial loss scale, must be greater than 0
        verbose (bool): whether to log messages

    Raises:
        AssertionError: if ``initial_scale`` is not greater than 0
    """

    def __init__(self, initial_scale: float, verbose: bool):
        # Raise explicitly instead of using an ``assert`` statement so that the
        # check is not stripped when Python runs with ``-O``. The exception
        # type stays ``AssertionError`` for backward compatibility.
        if not initial_scale > 0:
            raise AssertionError(f'initial_scale must be greater than 0, got {initial_scale}')

        # ``torch.tensor`` replaces the legacy ``torch.cuda.FloatTensor``
        # constructor. The tensor still lives on the current CUDA device when
        # one is available; otherwise it falls back to the CPU instead of
        # failing at construction time.
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self._scale = torch.tensor([initial_scale], dtype=torch.float, device=device)
        self._verbose = verbose

        # Always define ``_logger`` so the attribute exists even when logging
        # is disabled; the distributed logger is only fetched when needed.
        self._logger = get_dist_logger() if self._verbose else None

    @property
    def scale(self) -> Tensor:
        """Returns the loss scale.
        """

        return self._scale

    @property
    def inv_scale(self) -> Tensor:
        """Returns the inverse of the loss scale.

        The reciprocal is computed in double precision and then cast back to
        single precision.
        """

        return self._scale.double().reciprocal().float()

    def state_dict(self) -> Dict:
        """Returns the states of the gradient scaler as a dict object.

        The returned dict has a single key, ``'scale'``, holding the scale
        tensor itself (not a copy).
        """

        state_dict = dict()
        state_dict['scale'] = self.scale
        return state_dict

    def load_state_dict(self, state_dict: Dict) -> None:
        """Load the states of the gradient scaler from a dict object.

        Args:
            state_dict (dict): the states of the gradient scaler; its
                ``'scale'`` entry replaces the current loss scale
        """

        self._scale = state_dict['scale']

    @abstractmethod
    def update(self, overflow: bool) -> None:
        """Update the loss scale.

        Args:
            overflow (bool): whether overflow occurs
        """

        pass

    def log(self, message, *args, **kwargs):
        """Log messages.

        Nothing is logged unless the scaler was created with ``verbose`` set.

        Args:
            message (str): the message to log
            *args: positional arguments for :class:`colossalai.logging.DistributedLogger`
            **kwargs: key-word arguments for :class:`colossalai.logging.DistributedLogger`
        """

        if self._verbose:
            self._logger.info(message, *args, **kwargs)
|