2021-10-28 16:21:23 +00:00
|
|
|
# modified from https://github.com/NVIDIA/apex/blob/master/apex/optimizers/fused_sgd.py
|
|
|
|
import torch
|
|
|
|
from torch.optim.optimizer import Optimizer, required
|
|
|
|
|
|
|
|
from colossalai.registry import OPTIMIZERS
|
Develop/experiments (#59)
* Add gradient accumulation, fix lr scheduler
* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)
* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes
* fixed trainer
* Revert "fixed trainer"
This reverts commit 2e0b0b76990e8d4e337add483d878c0f61cf5097.
* improved consistency between trainer, engine and schedule (#23)
Co-authored-by: 1SAA <c2h214748@gmail.com>
* Split conv2d, class token, positional embedding in 2d, Fix random number in ddp
Fix convergence in cifar10, Imagenet1000
* Integrate 1d tensor parallel in Colossal-AI (#39)
* fixed 1D and 2D convergence (#38)
* optimized 2D operations
* fixed 1D ViT convergence problem
* Feature/ddp (#49)
* remove redundancy func in setup (#19) (#20)
* use env to control the language of doc (#24) (#25)
* Support TP-compatible Torch AMP and Update trainer API (#27)
* Add gradient accumulation, fix lr scheduler
* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)
* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes
* fixed trainer
* Revert "fixed trainer"
This reverts commit 2e0b0b76990e8d4e337add483d878c0f61cf5097.
* improved consistency between trainer, engine and schedule (#23)
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
* add an example of ViT-B/16 and remove w_norm clipping in LAMB (#29)
* add explanation for ViT example (#35) (#36)
* support torch ddp
* fix loss accumulation
* add log for ddp
* change seed
* modify timing hook
Co-authored-by: Frank Lee <somerlee.9@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
* Feature/pipeline (#40)
* remove redundancy func in setup (#19) (#20)
* use env to control the language of doc (#24) (#25)
* Support TP-compatible Torch AMP and Update trainer API (#27)
* Add gradient accumulation, fix lr scheduler
* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)
* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes
* fixed trainer
* Revert "fixed trainer"
This reverts commit 2e0b0b76990e8d4e337add483d878c0f61cf5097.
* improved consistency between trainer, engine and schedule (#23)
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
* add an example of ViT-B/16 and remove w_norm clipping in LAMB (#29)
* add explanation for ViT example (#35) (#36)
* optimize communication of pipeline parallel
* fix grad clip for pipeline
Co-authored-by: Frank Lee <somerlee.9@gmail.com>
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
* optimized 3d layer to fix slow computation ; tested imagenet performance with 3d; reworked lr_scheduler config definition; fixed launch args; fixed some printing issues; simplified apis of 3d layers (#51)
* Update 2.5d layer code to get a similar accuracy on imagenet-1k dataset
* update api for better usability (#58)
update api for better usability
Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
Co-authored-by: puck_WCR <46049915+WANG-CR@users.noreply.github.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com>
Co-authored-by: BoxiangW <45734921+BoxiangW@users.noreply.github.com>
2021-12-09 07:08:29 +00:00
|
|
|
from colossalai.utils import multi_tensor_applier
|
2021-10-28 16:21:23 +00:00
|
|
|
|
|
|
|
|
|
|
|
@OPTIMIZERS.register_module
class FusedSGD(Optimizer):
    r"""Implements stochastic gradient descent (optionally with momentum).

    `FusedSGD` requires CUDA extensions which can be built during installation or runtime.

    This version of fused SGD implements 2 fusions.

      * Fusion of the SGD update's elementwise operations
      * A multi-tensor apply launch that batches the elementwise updates applied to all the model's parameters into one or a few kernel launches.

    :class:`colossalai.nn.optimizer.FusedSGD` may be used as a drop-in replacement for ``torch.optim.SGD``

    :class:`colossalai.nn.optimizer.FusedSGD` may be used with or without Amp.

    Nesterov momentum is based on the formula from
    `On the importance of initialization and momentum in deep learning`__.

    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float): learning rate
        momentum (float, optional): momentum factor (default: 0)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        dampening (float, optional): dampening for momentum (default: 0)
        nesterov (bool, optional): enables Nesterov momentum (default: False)
        wd_after_momentum (bool, optional): forwarded unchanged to the fused
            kernel on every step (default: False).
            NOTE(review): presumably selects applying weight decay after the
            momentum update instead of before it -- confirm against the kernel.

    Raises:
        ValueError: if ``lr``, ``momentum`` or ``weight_decay`` is negative, or
            if ``nesterov`` is requested without a positive momentum and zero
            dampening.
        RuntimeError: if the multi-tensor applier reports that it is unavailable.

    __ http://www.cs.toronto.edu/%7Ehinton/absps/momentum.pdf

    .. note::
        The implementation of SGD with Momentum/Nesterov subtly differs from
        Sutskever et al. and implementations in some other frameworks.

        Considering the specific case of Momentum, the update can be written as

        .. math::
                  v = \rho * v + g \\
                  p = p - lr * v

        where p, g, v and :math:`\rho` denote the parameters, gradient,
        velocity, and momentum respectively.

        This is in contrast to Sutskever et al. and
        other frameworks which employ an update of the form

        .. math::
             v = \rho * v + lr * g \\
             p = p - v

        The Nesterov version is analogously modified.
    """

    def __init__(self,
                 params,
                 lr=required,
                 momentum=0,
                 dampening=0,
                 weight_decay=0,
                 nesterov=False,
                 wd_after_momentum=False):
        # Validate hyper-parameters before touching the base class so that an
        # invalid configuration never yields a half-initialised optimizer.
        if lr is not required and lr < 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if momentum < 0.0:
            raise ValueError("Invalid momentum value: {}".format(momentum))
        if weight_decay < 0.0:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        if nesterov and (momentum <= 0 or dampening != 0):
            raise ValueError("Nesterov momentum requires a momentum and zero dampening")

        defaults = dict(lr=lr, momentum=momentum, dampening=dampening, weight_decay=weight_decay, nesterov=nesterov)
        super().__init__(params, defaults)

        self.wd_after_momentum = wd_after_momentum

        if not multi_tensor_applier.available:
            raise RuntimeError('FusedSGD requires cuda extensions')

        # Imported lazily so that the extension is only loaded (and possibly
        # built) when a FusedSGD instance is actually created.
        from colossalai.kernel.op_builder import FusedOptimBuilder
        fused_optim = FusedOptimBuilder().load()

        # Skip buffer: an int flag tensor handed to every kernel launch, placed
        # on the device of the first parameter of the first group.
        self._dummy_overflow_buf = torch.tensor([0],
                                                dtype=torch.int,
                                                device=self.param_groups[0]["params"][0].device)
        self.multi_tensor_sgd = fused_optim.multi_tensor_sgd

    def __setstate__(self, state):
        """Restore optimizer state, back-filling ``nesterov`` for groups that lack it."""
        super().__setstate__(state)
        for group in self.param_groups:
            group.setdefault('nesterov', False)

    def get_momentums(self, params):
        """Collect the momentum buffer of every parameter, creating missing ones.

        torch.optim.SGD initializes momentum in the main loop; here it has to be
        done up front, and whether it was just done is tracked so that momentum
        application can be skipped in the main kernel.

        Args:
            params (iterable): parameters whose momentum buffers are needed.

        Returns:
            tuple: ``(momentums, first_run)``. ``momentums`` holds one buffer per
            parameter, in order; missing buffers are created with
            ``torch.zeros_like``. ``first_run`` is ``True`` for an empty
            ``params``; otherwise it reflects only the LAST parameter inspected
            (``True`` if its buffer had to be created), so callers should pass
            parameters that are all new or all previously seen.
        """
        momentums = []
        first_run = True
        for p in params:
            param_state = self.state[p]
            first_run = 'momentum_buffer' not in param_state
            if first_run:
                param_state['momentum_buffer'] = torch.zeros_like(p)
            momentums.append(param_state['momentum_buffer'])
        return momentums, first_run

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by ``closure``, or ``None`` when no closure is given.

        Raises:
            RuntimeError: if any parameter has a sparse gradient.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            weight_decay = group['weight_decay']
            momentum = group['momentum']
            dampening = group['dampening']
            nesterov = group['nesterov']

            # (grad_type, param_to_update_type, momentum_type) combinations noted
            # by the original authors:
            # 1. fp16, fp16, fp16
            # 2. fp32, fp32, fp32
            # 3. fp16, fp32, fp32
            #
            # The kernel receives a single `first_run` flag per launch, so
            # parameters that already own a momentum buffer are launched
            # separately from those seen for the first time. Otherwise a mixed
            # group would apply the flag of its last parameter to all of them.
            new_params, seen_params = [], []
            for p in group['params']:
                if p.grad is None:
                    continue
                if p.grad.data.is_sparse:
                    raise RuntimeError('FusedSGD does not support sparse gradients')
                if 'momentum_buffer' in self.state[p]:
                    seen_params.append(p)
                else:
                    new_params.append(p)

            for p_l in (new_params, seen_params):
                # Nothing to update: never launch the kernel on empty tensor lists.
                if not p_l:
                    continue
                g_l = [p.grad for p in p_l]
                m_l, first_run = self.get_momentums(p_l)
                # NOTE(review): the trailing 1.0 is presumably the gradient scale
                # factor expected by the kernel -- confirm against its signature.
                multi_tensor_applier(self.multi_tensor_sgd, self._dummy_overflow_buf, [g_l, p_l, m_l], weight_decay,
                                     momentum, dampening, group['lr'], nesterov, first_run, self.wd_after_momentum,
                                     1.0)

        return loss
|