[docs] updated docs of hybrid adam and cpu adam (#552)

pull/558/head
LuGY 3 years ago committed by GitHub
parent 014bac0c49
commit c44d797072

@@ -1,9 +1,56 @@
import math
import torch
from colossalai.registry import OPTIMIZERS
@OPTIMIZERS.register_module
class CPUAdam(torch.optim.Optimizer):
"""Implements Adam algorithm.
Supports parameters updating on both GPU and CPU, depanding on the device of paramters.
But the parameters and gradients should on the same device:
* Parameters on CPU and gradients on CPU is allowed.
* Parameters on GPU and gradients on GPU is allowed.
* Parameters on GPU and gradients on CPU is **not** allowed.
Requires ColossalAI to be installed via ``pip install .``.
This version of CPU Adam accelates parameters updating on CPU with SIMD.
Support of AVX2 or AVX512 is required.
The GPU part is implemented in an naive way.
CPU Adam also supports the hybrid precision calculation, eg. fp32 parameters and fp16 gradients.
:class:`colossalai.nn.optimizer.CPUAdam` may be used as a drop-in replacement for ``torch.optim.AdamW``,
or ``torch.optim.Adam`` with ``adamw_mode=False``
Adam was been proposed in `Adam: A Method for Stochastic Optimization`_.
Arguments:
model_params (iterable): iterable of parameters or dicts defining
parameter groups.
lr (float, optional): learning rate. (default: 1e-3)
betas (Tuple[float, float], optional): coefficients used for computing
running averages of gradient and its square. (default: (0.9, 0.999))
eps (float, optional): term added to the denominator to improve
numerical stability. (default: 1e-8)
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
amsgrad (boolean, optional): whether to use the AMSGrad variant of this
algorithm from the paper `On the Convergence of Adam and Beyond`_
(default: False) NOT SUPPORTED yet in CPUAdam!
adamw_mode (boolean, optional): whether to apply L2 regularization or decoupled weight decay;
True enables decoupled weight decay (also known as AdamW). (default: True)
simd_log (boolean, optional): whether to print a message indicating if SIMD is used for
acceleration. (default: False)
.. _Adam: A Method for Stochastic Optimization:
https://arxiv.org/abs/1412.6980
.. _On the Convergence of Adam and Beyond:
https://openreview.net/forum?id=ryQu7f-RZ
"""
optimizer_id = 0
# Number of fp32 shards per parameter
# Param weight, grad, momentum and variance
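A minimal usage sketch of the CPU path, assuming ColossalAI has been built from source (``pip install .``) so the SIMD extension is available; the model, data, and hyperparameters below are illustrative only:

import torch
from colossalai.nn.optimizer import CPUAdam

# Parameters stay on CPU, so their gradients will be on CPU as well.
model = torch.nn.Linear(16, 16)
optimizer = CPUAdam(model.parameters(), lr=1e-3, weight_decay=1e-2)

loss = model(torch.randn(4, 16)).sum()
loss.backward()        # gradients land on the same device as the parameters
optimizer.step()       # SIMD-accelerated update on CPU
optimizer.zero_grad()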
@@ -18,11 +65,6 @@ class CPUAdam(torch.optim.Optimizer):
weight_decay=0,
adamw_mode=True,
simd_log=False):
"""
An implementation equivalent to `torch.optim.Adam`.
The difference is that model_params are sharded parameters belonging to a ShardedModelV2 instance.
The sharded param of model_params can reside on both CPU and CUDA.
"""
default_args = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, bias_correction=bias_correction)
super(CPUAdam, self).__init__(model_params, default_args)

@@ -72,8 +72,8 @@ class FusedAdam(torch.optim.Optimizer):
else:
raise RuntimeError('FusedAdam requires cuda extensions')
- def zero_grad(self):
- if self.set_grad_none:
+ def zero_grad(self, set_to_none=False):
+ if set_to_none:
for group in self.param_groups:
for p in group['params']:
p.grad = None
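The updated signature mirrors the ``set_to_none`` flag of PyTorch's own ``Optimizer.zero_grad``. A small sketch of the two behaviours, shown here with a plain ``torch.optim.AdamW`` since ``FusedAdam`` needs the CUDA extensions; all names are illustrative only:

import torch

layer = torch.nn.Linear(4, 4)
opt = torch.optim.AdamW(layer.parameters())

layer(torch.randn(2, 4)).sum().backward()
opt.zero_grad(set_to_none=False)    # keeps gradient tensors, filled with zeros
print(layer.weight.grad is None)    # False

layer(torch.randn(2, 4)).sum().backward()
opt.zero_grad(set_to_none=True)     # drops gradient tensors entirely
print(layer.weight.grad is None)    # True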

@@ -1,7 +1,55 @@
import torch
from colossalai.utils import multi_tensor_applier
from colossalai.registry import OPTIMIZERS
@OPTIMIZERS.register_module
class HybridAdam(torch.optim.Optimizer):
"""Implements Adam algorithm.
Supports parameters updating on both GPU and CPU, depanding on the device of paramters.
But the parameters and gradients should on the same device:
* Parameters on CPU and gradients on CPU is allowed.
* Parameters on GPU and gradients on GPU is allowed.
* Parameters on GPU and gradients on CPU is **not** allowed.
Requires ColossalAI to be installed via ``pip install .``
This version of Hybrid Adam is an hybrid of CPUAdam and FusedAdam.
* For parameters updating on CPU, it uses CPUAdam.
* For parameters updating on GPU, it uses FusedAdam.
* Hybird precision calculation of fp16 and fp32 is supported, eg fp32 parameters and fp16 gradients.
:class:`colossalai.nn.optimizer.HybridAdam` may be used as a drop-in replacement for ``torch.optim.AdamW``,
or ``torch.optim.Adam`` with ``adamw_mode=False``
Adam was been proposed in `Adam: A Method for Stochastic Optimization`_.
Arguments:
model_params (iterable): iterable of parameters or dicts defining
parameter groups.
lr (float, optional): learning rate. (default: 1e-3)
betas (Tuple[float, float], optional): coefficients used for computing
running averages of gradient and its square. (default: (0.9, 0.999))
eps (float, optional): term added to the denominator to improve
numerical stability. (default: 1e-8)
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
amsgrad (boolean, optional): whether to use the AMSGrad variant of this
algorithm from the paper `On the Convergence of Adam and Beyond`_
(default: False) NOT SUPPORTED yet in HybridAdam!
adamw_mode (boolean, optional): whether to apply L2 regularization or decoupled weight decay;
True enables decoupled weight decay (also known as AdamW). (default: True)
simd_log (boolean, optional): whether to print a message indicating if SIMD is used for
acceleration. (default: False)
.. _Adam: A Method for Stochastic Optimization:
https://arxiv.org/abs/1412.6980
.. _On the Convergence of Adam and Beyond:
https://openreview.net/forum?id=ryQu7f-RZ
"""
optimizer_id = 0
# Number of fp32 shards per parameter
# Param weight, grad, momentum and variance
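A minimal usage sketch, assuming ColossalAI is built from source and a CUDA device is available; the mixed CPU/GPU parameter list is only meant to illustrate the per-device dispatch described in the docstring, and the layer sizes are illustrative only:

import torch
from colossalai.nn.optimizer import HybridAdam

# Parameters on GPU are updated by the FusedAdam path,
# parameters on CPU by the SIMD CPUAdam path.
gpu_layer = torch.nn.Linear(16, 16).cuda()
cpu_layer = torch.nn.Linear(16, 16)

optimizer = HybridAdam(list(gpu_layer.parameters()) + list(cpu_layer.parameters()), lr=1e-3)

# Run a forward/backward on each device; gradients stay with their parameters.
gpu_layer(torch.randn(4, 16, device='cuda')).sum().backward()
cpu_layer(torch.randn(4, 16)).sum().backward()

optimizer.step()       # per-parameter dispatch to the CUDA kernel or the SIMD CPU kernel
optimizer.zero_grad()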
@@ -16,11 +64,6 @@ class HybridAdam(torch.optim.Optimizer):
weight_decay=0,
adamw_mode=True,
simd_log=False):
"""
An implementation equivalent to `torch.optim.Adam`.
The difference is that model_params are sharded parameters belonging to a ShardedModelV2 instance.
The sharded param of model_params can reside on both CPU and CUDA (fused adam).
"""
default_args = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, bias_correction=bias_correction)
super(HybridAdam, self).__init__(model_params, default_args)

@@ -0,0 +1,5 @@
colossalai.nn.optimizer.hybrid\_adam
====================================
.. automodule:: colossalai.nn.optimizer.hybrid_adam
:members:

@@ -13,5 +13,6 @@ colossalai.nn.optimizer
colossalai.nn.optimizer.fused_adam
colossalai.nn.optimizer.fused_lamb
colossalai.nn.optimizer.fused_sgd
colossalai.nn.optimizer.hybrid_adam
colossalai.nn.optimizer.lamb
colossalai.nn.optimizer.lars
