@@ -18,6 +18,17 @@ class TorchAMPOptimizer(ColossalaiOptimizer):

    :param optim: a normal optimizer like Adam or SGD
    :type optim: torch.optim.Optimizer
    :param init_scale: Initial scale factor
    :type init_scale: float, optional, default=2.**16
    :param growth_factor: Factor by which the scale is multiplied during :meth:`update` if no inf/NaN gradients occur for ``growth_interval`` consecutive iterations.
    :type growth_factor: float, optional, default=2.0
    :param backoff_factor: Factor by which the scale is multiplied during :meth:`update` if inf/NaN gradients occur in an iteration.
    :type backoff_factor: float, optional, default=0.5
    :param growth_interval: Number of consecutive iterations without inf/NaN gradients that must occur for the scale to be multiplied by ``growth_factor``.
    :type growth_interval: int, optional, default=2000
    :param enabled: If ``False``, disables gradient scaling. :meth:`step` simply invokes the underlying ``optimizer.step()``, and other methods become no-ops.
    :type enabled: bool, optional, default=True
    """

    def __init__(self, optim: Optimizer, *args, **kwargs):
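The keyword arguments documented above match those of ``torch.cuda.amp.GradScaler``, which drives PyTorch's dynamic loss scaling. Below is a minimal sketch of that scaling loop using the PyTorch scaler directly; treating ``TorchAMPOptimizer`` as a thin wrapper that forwards these arguments to a ``GradScaler`` is an assumption made only for illustration.

    import torch
    import torch.nn as nn

    # Minimal sketch: the documented parameters applied to torch.cuda.amp.GradScaler.
    model = nn.Linear(16, 4).cuda()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    scaler = torch.cuda.amp.GradScaler(
        init_scale=2.**16,     # starting loss-scale factor
        growth_factor=2.0,     # scale grows by this factor after growth_interval clean steps
        backoff_factor=0.5,    # scale shrinks by this factor when inf/NaN gradients appear
        growth_interval=2000,  # consecutive clean iterations required before growing the scale
        enabled=True,          # False turns every scaler call into a no-op
    )

    data = torch.randn(8, 16, device='cuda')
    target = torch.randn(8, 4, device='cuda')
    with torch.cuda.amp.autocast():
        loss = nn.functional.mse_loss(model(data), target)

    scaler.scale(loss).backward()  # backward pass on the scaled loss
    scaler.step(optimizer)         # unscales gradients; skips the step on inf/NaN
    scaler.update()                # adjusts the scale for the next iteration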
@@ -68,6 +79,7 @@ class TorchAMPLoss(nn.Module):

    :param loss: a loss function object
    :type loss: torch.nn.modules.loss._Loss
    """

    def __init__(self, loss: _Loss):
        super().__init__()
        self.loss = loss
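For context on the second hunk, here is a hedged sketch of how a wrapper with this constructor could evaluate the stored criterion under autocast. The class name and the forward logic are assumptions for illustration, not a claim about what ``TorchAMPLoss`` actually does.

    import torch
    import torch.nn as nn
    from torch.nn.modules.loss import _Loss


    class AMPLossSketch(nn.Module):
        """Hypothetical stand-in for TorchAMPLoss: stores a criterion and runs it under autocast."""

        def __init__(self, loss: _Loss):
            super().__init__()
            self.loss = loss

        def forward(self, *args, **kwargs):
            # Assumed behaviour: evaluate the wrapped loss in mixed precision.
            with torch.cuda.amp.autocast():
                return self.loss(*args, **kwargs)


    # usage: wrap any standard criterion
    criterion = AMPLossSketch(nn.CrossEntropyLoss())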