AMP docstring/markdown update (#160)

pull/161/head
puck_WCR 2022-01-18 18:33:36 +08:00 committed by GitHub
parent 2499faa2db
commit 9473a1b9c8
4 changed files with 15 additions and 3 deletions


@@ -6,7 +6,7 @@ from torch.optim import Optimizer
 def convert_to_apex_amp(model: nn.Module,
                         optimizer: Optimizer,
                         amp_config):
-    """A helper function to wrap training components with Torch AMP modules
+    """A helper function to wrap training components with Apex AMP modules
     :param model: your model object
     :type model: :class:`torch.nn.Module`
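For context, a minimal usage sketch of the helper this hunk documents. The import path and the exact return signature are assumptions, `amp_config` is simply forwarded to Apex, and the snippet presumes Apex and a CUDA device are available:

```python
import torch
import torch.nn as nn
from colossalai.amp.apex_amp import convert_to_apex_amp  # assumed import path

# toy training components
model = nn.Linear(16, 4).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

# amp_config is forwarded to Apex AMP initialization; "O1" is a common opt_level
amp_config = dict(opt_level='O1')
model, optimizer = convert_to_apex_amp(model, optimizer, amp_config)
```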


@@ -8,7 +8,7 @@ from .naive_amp import NaiveAMPOptimizer, NaiveAMPModel
 def convert_to_naive_amp(model: nn.Module,
                          optimizer: Optimizer,
                          amp_config):
-    """A helper function to wrap training components with Torch AMP modules
+    """A helper function to wrap training components with naive AMP modules
     :param model: your model object
     :type model: :class:`torch.nn.Module`
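Likewise, a minimal sketch of how the naive AMP helper might be called. The import path, the return signature, and the key inside `amp_config` are assumptions for illustration:

```python
import torch
import torch.nn as nn
from colossalai.amp.naive_amp import convert_to_naive_amp  # assumed import path

model = nn.Linear(16, 4).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

# the key below is illustrative; amp_config is passed through to the naive AMP wrappers
amp_config = dict(initial_scale=2 ** 15)
model, optimizer = convert_to_naive_amp(model, optimizer, amp_config)
```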


@@ -18,6 +18,17 @@ class TorchAMPOptimizer(ColossalaiOptimizer):
     :param optim: a normal optimizer like Adam or SGD
     :type optim: torch.optim.Optimizer
+    :param init_scale: Initial scale factor
+    :type init_scale: float, optional, default=2.**16
+    :param growth_factor: Factor by which the scale is multiplied during :meth:`update` if no inf/NaN gradients occur for ``growth_interval`` consecutive iterations.
+    :type growth_factor: float, optional, default=2.0
+    :param backoff_factor: Factor by which the scale is multiplied during :meth:`update` if inf/NaN gradients occur in an iteration.
+    :type backoff_factor: float, optional, default=0.5
+    :param growth_interval: Number of consecutive iterations without inf/NaN gradients that must occur for the scale to be multiplied by ``growth_factor``.
+    :type growth_interval: int, optional, default=2000
+    :param enabled: If ``False``, disables gradient scaling. :meth:`step` simply invokes the underlying ``optimizer.step()``, and other methods become no-ops.
+    :type enabled: bool, optional, default=True
     """

     def __init__(self, optim: Optimizer, *args, **kwargs):
@@ -68,6 +79,7 @@ class TorchAMPLoss(nn.Module):
     :param loss: a loss function object
     :type loss: torch.nn.modules.loss._Loss
     """

     def __init__(self, loss: _Loss):
         super().__init__()
         self.loss = loss
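To illustrate the parameters documented above, here is a minimal sketch of constructing the wrappers directly. The import path is an assumption, the keyword arguments mirror `torch.cuda.amp.GradScaler` as the new docstring states, and a CUDA device is presumed:

```python
import torch
import torch.nn as nn
from colossalai.amp.torch_amp import TorchAMPLoss, TorchAMPOptimizer  # assumed import path

model = nn.Linear(16, 4).cuda()

# the extra keyword arguments are the grad-scaler settings documented above
optimizer = TorchAMPOptimizer(torch.optim.Adam(model.parameters(), lr=1e-3),
                              init_scale=2.**16,
                              growth_factor=2.0,
                              backoff_factor=0.5,
                              growth_interval=2000,
                              enabled=True)

# the loss wrapper presumably evaluates the criterion under an autocast context
criterion = TorchAMPLoss(nn.CrossEntropyLoss())
```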


@@ -82,7 +82,7 @@ fp16 = dict(
 We leveraged the Megatron-LM implementation to achieve mixed precision training while maintaining compatibility with complex tensor
 and pipeline parallelism. This AMP mode will cast all operations into fp16.
-The following conde block show a config file for this mode.
+The following code block shows a config file for this mode.
 ```python
 from colossalai.amp import AMP_TYPE
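The config block in this hunk is cut off by the diff view; a minimal sketch of what such a config might look like follows. The `AMP_TYPE.NAIVE` mode value and the loss-scaling keys are assumptions for illustration:

```python
from colossalai.amp import AMP_TYPE

fp16 = dict(
    mode=AMP_TYPE.NAIVE,    # assumed enum member for this Megatron-LM style mode
    initial_scale=2 ** 32,  # illustrative loss-scale settings
    growth_factor=2,
    backoff_factor=0.5,
    growth_interval=1000,
)
```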