
AMP docstring/markdown update (#160)

pull/161/head
puck_WCR authored 3 years ago, committed by GitHub
commit 9473a1b9c8
  1. colossalai/amp/apex_amp/__init__.py (2 changes)
  2. colossalai/amp/naive_amp/__init__.py (2 changes)
  3. colossalai/amp/torch_amp/torch_amp.py (12 changes)
  4. docs/amp.md (2 changes)

colossalai/amp/apex_amp/__init__.py

@@ -6,7 +6,7 @@ from torch.optim import Optimizer
def convert_to_apex_amp(model: nn.Module,
                        optimizer: Optimizer,
                        amp_config):
-    """A helper function to wrap training components with Torch AMP modules
+    """A helper function to wrap training components with Apex AMP modules
    :param model: your model object
    :type model: :class:`torch.nn.Module`

colossalai/amp/naive_amp/__init__.py

@@ -8,7 +8,7 @@ from .naive_amp import NaiveAMPOptimizer, NaiveAMPModel
def convert_to_naive_amp(model: nn.Module,
                         optimizer: Optimizer,
                         amp_config):
-    """A helper function to wrap training components with Torch AMP modules
+    """A helper function to wrap training components with naive AMP modules
    :param model: your model object
    :type model: :class:`torch.nn.Module`
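And the naive AMP counterpart, again as a hedged sketch: the hunk header shows `NaiveAMPModel` and `NaiveAMPOptimizer` being imported alongside the helper, so the return values are presumably instances of those wrappers; the keys accepted by `amp_config` (loss-scaling options such as `initial_scale`) are assumptions, not shown in this diff.

```python
import torch
import torch.nn as nn

from colossalai.amp.naive_amp import convert_to_naive_amp

model = nn.Linear(128, 10).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Assumption: amp_config holds loss-scaling settings; the exact keys are not part of this diff.
amp_config = dict(initial_scale=2 ** 16)
model, optimizer = convert_to_naive_amp(model, optimizer, amp_config)
```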

colossalai/amp/torch_amp/torch_amp.py

@@ -18,6 +18,17 @@ class TorchAMPOptimizer(ColossalaiOptimizer):
    :param optim: a normal optimizer like Adam or SGD
    :type optim: torch.optim.Optimizer
+    :param init_scale: Initial scale factor
+    :type init_scale: float, optional, default=2.**16
+    :param growth_factor: Factor by which the scale is multiplied during :meth:`update` if no inf/NaN gradients occur for ``growth_interval`` consecutive iterations.
+    :type growth_factor: float, optional, default=2.0
+    :param backoff_factor: Factor by which the scale is multiplied during :meth:`update` if inf/NaN gradients occur in an iteration.
+    :type backoff_factor: float, optional, default=0.5
+    :param growth_interval: Number of consecutive iterations without inf/NaN gradients that must occur for the scale to be multiplied by ``growth_factor``.
+    :type growth_interval: int, optional, default=2000
+    :param enabled: If ``False``, disables gradient scaling. :meth:`step` simply invokes the underlying ``optimizer.step()``, and other methods become no-ops.
+    :type enabled: bool, optional, default=True
    """

    def __init__(self, optim: Optimizer, *args, **kwargs):
@@ -68,6 +79,7 @@ class TorchAMPLoss(nn.Module):
    :param loss: a loss function object
    :type loss: torch.nn.modules.loss._Loss
    """

    def __init__(self, loss: _Loss):
        super().__init__()
        self.loss = loss
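The newly documented keyword arguments and their defaults (2.**16, 2.0, 0.5, 2000, True) match `torch.cuda.amp.GradScaler`, so they are presumably forwarded to an internal scaler. A hedged sketch of direct construction follows; in a real ColossalAI run these values would normally come from the `fp16` section of the config file instead.

```python
import torch
import torch.nn as nn

from colossalai.amp.torch_amp.torch_amp import TorchAMPLoss, TorchAMPOptimizer

model = nn.Linear(128, 10).cuda()

# The keyword arguments below are the documented defaults; they mirror torch.cuda.amp.GradScaler.
optimizer = TorchAMPOptimizer(torch.optim.Adam(model.parameters(), lr=1e-3),
                              init_scale=2.**16,
                              growth_factor=2.0,
                              backoff_factor=0.5,
                              growth_interval=2000,
                              enabled=True)

# TorchAMPLoss wraps an ordinary loss function object, as shown in the second hunk above.
criterion = TorchAMPLoss(nn.CrossEntropyLoss())
```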

docs/amp.md

@@ -82,7 +82,7 @@ fp16 = dict(
We leveraged the Megatron-LM implementation to achieve mixed precision training while maintaining compatibility with complex tensor
and pipeline parallelism. This AMP mode will cast all operations into fp16.
-The following conde block show a config file for this mode.
+The following code block shows a config file for this mode.
```python
from colossalai.amp import AMP_TYPE

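# The config snippet is truncated in this excerpt. A hedged completion based on the
# AMP_TYPE import above: the naive AMP mode would be selected with an `fp16` dict
# roughly like this (any extra loss-scaling options are not shown in the diff).
fp16 = dict(
    mode=AMP_TYPE.NAIVE
)
```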