2022-03-09 03:52:43 +00:00
|
|
|
#!/usr/bin/env python
|
|
|
|
# -*- encoding: utf-8 -*-
|
|
|
|
from .base_grad_scaler import BaseGradScaler
|
|
|
|
|
2023-09-19 06:20:26 +00:00
|
|
|
# Public API of this module: only the constant-scale gradient scaler is exported.
__all__ = [
    "ConstantGradScaler",
]
|
2022-03-09 03:52:43 +00:00
|
|
|
|
|
|
|
|
|
|
|
class ConstantGradScaler(BaseGradScaler):
    """A gradient scaler which uses a constant loss scale.

    The scale is set once from ``initial_scale`` at construction time; this
    class's :meth:`update` is a no-op, so the scale is never adjusted in
    response to overflows.

    Args:
        initial_scale (float): the initial loss scale, kept constant by this scaler
        verbose (bool): whether to log messages
    """

    def __init__(self, initial_scale: float, verbose: bool):
        # Annotation is ``float`` to agree with the documented type above; the
        # value is forwarded unchanged to ``BaseGradScaler``.
        super().__init__(initial_scale, verbose)
        # NOTE(review): ``ranks=[0]`` presumably restricts this message to rank 0;
        # confirm against ``BaseGradScaler.log``.
        self.log(f"Constant Gradient Scaler is initialized with scale {self.scale}", ranks=[0])

    def update(self, overflow: bool) -> None:
        """Do nothing to keep the loss scale constant.

        Args:
            overflow (bool): whether overflow occurs; intentionally ignored,
                since this scaler never changes its scale
        """
|