From c44d797072fd184b6035465a473f4cf9922ed134 Mon Sep 17 00:00:00 2001 From: LuGY <74758262+Gy-Lu@users.noreply.github.com> Date: Wed, 30 Mar 2022 18:14:59 +0800 Subject: [PATCH] [docs] updatad docs of hybrid adam and cpu adam (#552) --- colossalai/nn/optimizer/cpu_adam.py | 54 ++++++++++++++++--- colossalai/nn/optimizer/fused_adam.py | 4 +- colossalai/nn/optimizer/hybrid_adam.py | 53 ++++++++++++++++-- .../colossalai.nn.optimizer.hybrid_adam.rst | 5 ++ docs/colossalai/colossalai.nn.optimizer.rst | 1 + 5 files changed, 104 insertions(+), 13 deletions(-) create mode 100644 docs/colossalai/colossalai.nn.optimizer.hybrid_adam.rst diff --git a/colossalai/nn/optimizer/cpu_adam.py b/colossalai/nn/optimizer/cpu_adam.py index 616600d29..88cd1cddc 100644 --- a/colossalai/nn/optimizer/cpu_adam.py +++ b/colossalai/nn/optimizer/cpu_adam.py @@ -1,9 +1,56 @@ import math - import torch +from colossalai.registry import OPTIMIZERS + +@OPTIMIZERS.register_module class CPUAdam(torch.optim.Optimizer): + """Implements Adam algorithm. + + Supports parameters updating on both GPU and CPU, depanding on the device of paramters. + But the parameters and gradients should on the same device: + * Parameters on CPU and gradients on CPU is allowed. + * Parameters on GPU and gradients on GPU is allowed. + * Parameters on GPU and gradients on CPU is **not** allowed. + + Requires ColossalAI to be installed via ``pip install .``. + + This version of CPU Adam accelates parameters updating on CPU with SIMD. + Support of AVX2 or AVX512 is required. + + The GPU part is implemented in an naive way. + + CPU Adam also supports the hybrid precision calculation, eg. fp32 parameters and fp16 gradients. + + :class:`colossalai.nn.optimizer.CPUAdam` may be used as a drop-in replacement for ``torch.optim.AdamW``, + or ``torch.optim.Adam`` with ``adamw_mode=False`` + + Adam was been proposed in `Adam: A Method for Stochastic Optimization`_. + + Arguments: + model_params (iterable): iterable of parameters of dicts defining + parameter groups. + lr (float, optional): learning rate. (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square. (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability. (default: 1e-8) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + amsgrad (boolean, optional): whether to use the AMSGrad variant of this + algorithm from the paper `On the Convergence of Adam and Beyond`_ + (default: False) NOT SUPPORTED yet in CPUAdam! + adamw_mode (boolean, optional): Apply L2 regularization or weight decay + True for decoupled weight decay(also known as AdamW) (default: True) + simd_log (boolean, optional): whether to show if you are using SIMD to + accelerate. (default: False) + + .. _Adam: A Method for Stochastic Optimization: + https://arxiv.org/abs/1412.6980 + .. _On the Convergence of Adam and Beyond: + https://openreview.net/forum?id=ryQu7f-RZ + """ + optimizer_id = 0 # Number of fp32 shards for per parameter # Param weight, grad, momentum and variance @@ -18,11 +65,6 @@ class CPUAdam(torch.optim.Optimizer): weight_decay=0, adamw_mode=True, simd_log=False): - """ - An implementation equivalent to `torch.optim.Adam`. - The difference is that model_params are sharded parameters belonging to a ShardedModelV2 instance. - The sharded param of model_params can resident on both CPU and CUDA. - """ default_args = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, bias_correction=bias_correction) super(CPUAdam, self).__init__(model_params, default_args) diff --git a/colossalai/nn/optimizer/fused_adam.py b/colossalai/nn/optimizer/fused_adam.py index 465e000a1..89ca3a8c6 100644 --- a/colossalai/nn/optimizer/fused_adam.py +++ b/colossalai/nn/optimizer/fused_adam.py @@ -72,8 +72,8 @@ class FusedAdam(torch.optim.Optimizer): else: raise RuntimeError('FusedAdam requires cuda extensions') - def zero_grad(self): - if self.set_grad_none: + def zero_grad(self, set_to_none=False): + if set_to_none: for group in self.param_groups: for p in group['params']: p.grad = None diff --git a/colossalai/nn/optimizer/hybrid_adam.py b/colossalai/nn/optimizer/hybrid_adam.py index 1ff0544fd..47d690752 100644 --- a/colossalai/nn/optimizer/hybrid_adam.py +++ b/colossalai/nn/optimizer/hybrid_adam.py @@ -1,7 +1,55 @@ import torch + + from colossalai.utils import multi_tensor_applier +from colossalai.registry import OPTIMIZERS + +@OPTIMIZERS.register_module class HybridAdam(torch.optim.Optimizer): + """Implements Adam algorithm. + + Supports parameters updating on both GPU and CPU, depanding on the device of paramters. + But the parameters and gradients should on the same device: + * Parameters on CPU and gradients on CPU is allowed. + * Parameters on GPU and gradients on GPU is allowed. + * Parameters on GPU and gradients on CPU is **not** allowed. + + Requires ColossalAI to be installed via ``pip install .`` + + This version of Hybrid Adam is an hybrid of CPUAdam and FusedAdam. + * For parameters updating on CPU, it uses CPUAdam. + * For parameters updating on GPU, it uses FusedAdam. + * Hybird precision calculation of fp16 and fp32 is supported, eg fp32 parameters and fp16 gradients. + + :class:`colossalai.nn.optimizer.HybridAdam` may be used as a drop-in replacement for ``torch.optim.AdamW``, + or ``torch.optim.Adam`` with ``adamw_mode=False`` + + Adam was been proposed in `Adam: A Method for Stochastic Optimization`_. + + Arguments: + model_params (iterable): iterable of parameters of dicts defining + parameter groups. + lr (float, optional): learning rate. (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square. (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability. (default: 1e-8) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + amsgrad (boolean, optional): whether to use the AMSGrad variant of this + algorithm from the paper `On the Convergence of Adam and Beyond`_ + (default: False) NOT SUPPORTED yet in CPUAdam! + adamw_mode (boolean, optional): Apply L2 regularization or weight decay + True for decoupled weight decay(also known as AdamW) (default: True) + simd_log (boolean, optional): whether to show if you are using SIMD to + accelerate. (default: False) + + .. _Adam: A Method for Stochastic Optimization: + https://arxiv.org/abs/1412.6980 + .. _On the Convergence of Adam and Beyond: + https://openreview.net/forum?id=ryQu7f-RZ + """ + optimizer_id = 0 # Number of fp32 shards for per parameter # Param weight, grad, momentum and variance @@ -16,11 +64,6 @@ class HybridAdam(torch.optim.Optimizer): weight_decay=0, adamw_mode=True, simd_log=False): - """ - An implementation equivalent to `torch.optim.Adam`. - The difference is that model_params are sharded parameters belonging to a ShardedModelV2 instance. - The sharded param of model_params can resident on both CPU and CUDA(fused adam). - """ default_args = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, bias_correction=bias_correction) super(HybridAdam, self).__init__(model_params, default_args) diff --git a/docs/colossalai/colossalai.nn.optimizer.hybrid_adam.rst b/docs/colossalai/colossalai.nn.optimizer.hybrid_adam.rst new file mode 100644 index 000000000..20508d664 --- /dev/null +++ b/docs/colossalai/colossalai.nn.optimizer.hybrid_adam.rst @@ -0,0 +1,5 @@ +colossalai.nn.optimizer.hybrid\_adam +==================================== + +.. automodule:: colossalai.nn.optimizer.hybrid_adam + :members: diff --git a/docs/colossalai/colossalai.nn.optimizer.rst b/docs/colossalai/colossalai.nn.optimizer.rst index f1c4e722b..d0063179d 100644 --- a/docs/colossalai/colossalai.nn.optimizer.rst +++ b/docs/colossalai/colossalai.nn.optimizer.rst @@ -13,5 +13,6 @@ colossalai.nn.optimizer colossalai.nn.optimizer.fused_adam colossalai.nn.optimizer.fused_lamb colossalai.nn.optimizer.fused_sgd + colossalai.nn.optimizer.hybrid_adam colossalai.nn.optimizer.lamb colossalai.nn.optimizer.lars