# Distributed Optimizers
Author: Wenxuan Tan, Junwen Duan, Renjie Mao
**Related Papers**
- [Adafactor: Adaptive Learning Rates with Sublinear Memory Cost](https://arxiv.org/abs/1804.04235)
- [CAME: Confidence-guided Adaptive Memory Efficient Optimization](https://arxiv.org/abs/2307.02047)
- [GaLore: Memory-Efficient LLM Training by Gradient Low-Rank Projection](https://arxiv.org/abs/2403.03507)
- [Large Batch Optimization for Deep Learning: Training BERT in 76 minutes](https://arxiv.org/pdf/1904.00962)
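## Introduction
Apart from the widely adopted Adam and SGD, many modern optimizers require layer-wise statistics to update parameters effectively. They therefore cannot be applied directly in parallel settings where model layers are sharded across multiple devices. We provide optimized distributed implementations that integrate seamlessly with Tensor Parallel, DDP, and ZeRO through plugins.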
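
To see why layer-wise statistics break under sharding, consider Lamb's trust ratio in its simplest form, `‖w‖ / ‖update‖`, computed per layer. The sketch below simulates two tensor-parallel ranks in plain Python. There is no real process group, and `all_reduce_sum` is a hypothetical stand-in for a collective such as an all-reduce. It illustrates the idea only and is not the ColossalAI implementation.

```python
import math

def sq_norm(values):
    """Sum of squares of a flat list of numbers."""
    return sum(v * v for v in values)

def all_reduce_sum(per_rank_values):
    """Hypothetical stand-in for an all-reduce: every rank receives the global sum."""
    total = sum(per_rank_values)
    return [total for _ in per_rank_values]

# One layer's weights and its (already computed) update, stored flat.
weight = [3.0, 4.0, 0.0, 12.0]
update = [1.0, 2.0, 2.0, 0.0]

# Reference: trust ratio on the unsharded layer.
full_ratio = math.sqrt(sq_norm(weight)) / math.sqrt(sq_norm(update))

# Shard the layer evenly across two simulated ranks.
w_shards = [weight[:2], weight[2:]]
u_shards = [update[:2], update[2:]]

# Naive: each rank uses only its local shard, so the ranks disagree with the reference.
naive_ratios = [
    math.sqrt(sq_norm(w)) / math.sqrt(sq_norm(u))
    for w, u in zip(w_shards, u_shards)
]

# Distributed: all-reduce the squared norms first, then take the square root.
w_sq = all_reduce_sum([sq_norm(w) for w in w_shards])
u_sq = all_reduce_sum([sq_norm(u) for u in u_shards])
dist_ratios = [math.sqrt(w) / math.sqrt(u) for w, u in zip(w_sq, u_sq)]

print(full_ratio, naive_ratios, dist_ratios)
```

The reduction is applied to the squared norms rather than the norms themselves, because squared norms add across shards while norms do not.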
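## Optimizers
Adafactor is the first Adam variant to use Non-negative Matrix Factorization (NMF) to reduce memory footprint. CAME improves on NMF by introducing a confidence matrix. GaLore further reduces memory usage by projecting gradients into a low-rank space and applying 8-bit block-wise quantization. Lamb allows huge batch sizes without loss of accuracy, through layer-wise adaptive updates bounded by the inverse of the layer's Lipschitz constant.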
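
As a rough illustration of the factorization Adafactor relies on, the sketch below keeps only the row and column sums of a squared-gradient matrix and reconstructs an approximation from them, so an n × m matrix needs n + m stored values instead of n · m. It is plain Python with made-up numbers. It omits the exponential moving averages, update clipping, and relative step sizes of the real optimizer, and the function names are illustrative rather than part of any library.

```python
def factored_stats(sq_grad):
    """Return (row_sums, col_sums) of a non-negative matrix given as a list of rows."""
    row_sums = [sum(row) for row in sq_grad]
    col_sums = [sum(col) for col in zip(*sq_grad)]
    return row_sums, col_sums

def reconstruct(row_sums, col_sums):
    """Rank-1 approximation: V[i][j] ~= row_sums[i] * col_sums[j] / total."""
    total = sum(row_sums)
    return [[r * c / total for c in col_sums] for r in row_sums]

# A toy 2 x 3 gradient; its element-wise square happens to be rank 1.
grad = [[1.0, 2.0, 3.0],
        [2.0, 4.0, 6.0]]
sq_grad = [[g * g for g in row] for row in grad]

rows, cols = factored_stats(sq_grad)
approx = reconstruct(rows, cols)

stored_full = len(grad) * len(grad[0])   # n * m values for the full matrix
stored_factored = len(rows) + len(cols)  # n + m values for the factors
print(approx, stored_full, stored_factored)
```

The reconstruction is exact here only because the toy matrix is rank 1; in general it is an approximation. When a weight is sharded by columns across devices, each device sees only part of every row, so the row sums have to be combined across devices before they can be used.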
## API Reference
{{ autodoc:colossalai.nn.optimizer.distributed_adafactor.DistributedAdaFactor }}
{{ autodoc:colossalai.nn.optimizer.distributed_lamb.DistributedLamb }}
{{ autodoc:colossalai.nn.optimizer.distributed_galore.DistGaloreAwamW }}
{{ autodoc:colossalai.nn.optimizer.distributed_came.DistributedCAME }}
## Usage
We now demonstrate how to use Distributed Adafactor with the Booster API, combining Tensor Parallel and ZeRO 2 on 4 GPUs.
### step 1. Import libraries
```python
from transformers import LlamaModel, LlamaConfig
from colossalai.nn.optimizer.distributed_adafactor import DistributedAdaFactor
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
import colossalai
import torch
```
### step 2. Initialize the distributed environment
We need to initialize the distributed environment. For demo purposes, we launch with `colossalai run --nproc_per_node 4`. For more details, refer to [Launch Colossal-AI](../basics/launch_colossalai.md).
```python
colossalai.launch_from_torch()
```
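### step 3. Initialize the model and optimizer
Build the model. Here we create a Llama model from the default `LlamaConfig`, a placeholder criterion that averages the output, and the distributed optimizer.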
```python
configuration = LlamaConfig()
model = LlamaModel(configuration).cuda()
criterion = lambda x: x.mean()
dist_optim = DistributedAdaFactor(model.parameters())
```
### step 4. Initialize the booster and plugin
```python
plugin = HybridParallelPlugin(tp_size=2, zero_stage=2, pp_size=1, enable_all_optimization=True)
booster = Booster(plugin=plugin)
# You should also pass in your own dataset.
model, dist_optim, criterion, dataloader, _ = booster.boost(model, dist_optim, criterion)
```
### step 5. Train
```python
steps = 10
for step in range(steps):
    # Dummy batch: a single sequence of 100 token ids, all set to 1.
    input_ids = torch.ones(1, 100, device="cuda", dtype=torch.int)
    attention_mask = input_ids.clone()
    outputs = model(input_ids.cuda(), attention_mask.cuda())
    loss = criterion(outputs.last_hidden_state)
    # Backward must go through the booster so the plugin can handle gradients.
    booster.backward(loss, dist_optim)
    dist_optim.step()
    dist_optim.zero_grad()
```
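### Special initialization for GaLore
For GaLore, we need to specify the projection rank for each parameter group, as well as the quantization and paged-optimizer parameters. For details on quantization, please refer to bitsandbytes.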
```python
from colossalai.nn.optimizer.galore import get_galore_param_groups
from colossalai.nn.optimizer import DistGaloreAwamW
# lr, beta1, beta2 and eps are hyperparameters you define beforehand.
optim = DistGaloreAwamW(
    get_galore_param_groups(model, decay=1e-2, rank=8),
    lr=lr,
    betas=(beta1, beta2),
    eps=eps,
    nbits=8,
    percentile_clipping=100,
    block_wise=True,
    min_8bit_size=4096,
)
```
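
For intuition about what `nbits=8` and `block_wise=True` refer to, the sketch below shows a simplified block-wise absmax quantizer in plain Python: each block is scaled by its own maximum magnitude and rounded to a signed 8-bit code. This illustration rests on several assumptions (linear rounding, a tiny block size, hypothetical function names). bitsandbytes uses its own quantization scheme and kernels, which this does not reproduce.

```python
def quantize_blockwise(values, block_size):
    """Quantize a flat list to int8-range codes with one absmax scale per block."""
    codes, scales = [], []
    for start in range(0, len(values), block_size):
        block = values[start:start + block_size]
        scale = max(abs(v) for v in block) or 1.0  # all-zero block: avoid dividing by zero
        scales.append(scale)
        codes.extend(round(v / scale * 127) for v in block)
    return codes, scales

def dequantize_blockwise(codes, scales, block_size):
    """Invert quantize_blockwise up to rounding error."""
    return [
        code / 127 * scales[i // block_size]
        for i, code in enumerate(codes)
    ]

# Toy optimizer state: one block of small values, one block of large values.
state = [0.01, -0.02, 0.005, 0.0, 5.0, -2.5, 1.25, 0.0]

codes, scales = quantize_blockwise(state, block_size=4)
restored = dequantize_blockwise(codes, scales, block_size=4)

# For comparison: a single scale shared by all eight values.
global_codes, _ = quantize_blockwise(state, block_size=8)
print(codes, global_codes)
```

With one scale per block, the small values in the first block keep their resolution. With a single shared scale, the large entry in the second block dominates and `0.01` rounds to code 0.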
## Compatibility
<table>
<tr>
<th nowrap="nowrap">Model/Feature</th>
<th nowrap="nowrap" align="center" title="Lamb">Lamb</th>
<th nowrap="nowrap" align="center" title="GaLore">GaLore</th>
<th nowrap="nowrap" align="center" title="Adafactor">Adafactor</th>
<th nowrap="nowrap" align="center" title="CAME">CAME</th>
</tr>
<tr>
<td nowrap="nowrap">Hybrid Parallel<br />Plugin</td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td>
</tr>
<tr>
<td nowrap="nowrap">Low Level Zero<br />Plugin</td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center"></td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td>
</tr>
<tr>
<td nowrap="nowrap">Torch DDP<br />Plugin</td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td>
<td nowrap="nowrap" align="center">✔️</td>
</tr>
<tr>
<td nowrap="nowrap">Gemini<br />Plugin</td>
<td nowrap="nowrap" align="center"></td>
<td nowrap="nowrap" align="center"></td>
<td nowrap="nowrap" align="center"></td>
<td nowrap="nowrap" align="center"></td>
</tr>
<tr>
<td nowrap="nowrap">Moe Hybrid<br />Plugin</td>
<td nowrap="nowrap" align="center"></td>
<td nowrap="nowrap" align="center"></td>
<td nowrap="nowrap" align="center"></td>
<td nowrap="nowrap" align="center"></td>
</tr>
<tr>
<td colspan="39"></td>
</tr>
</table>
<!-- doc-test-command: colossalai run --nproc_per_node 4 distributed_optimizers.py -->