
# Distributed Optimizers

Author: Wenxuan Tan, Junwen Duan, Renjie Mao

**Related Papers**

- Adafactor: Adaptive Learning Rates with Sublinear Memory Cost
- CAME: Confidence-guided Adaptive Memory Efficient Optimization
- GaLore: Memory-Efficient LLM Training by Gradient Low-Rank Projection
- Large Batch Optimization for Deep Learning: Training BERT in 76 minutes

## Introduction

Apart from the widely adopted Adam and SGD, many modern optimizers require layer-wise statistics to update parameters effectively, and therefore cannot be applied directly in parallel settings where model layers are sharded across multiple devices. We provide optimized distributed implementations of these optimizers, which integrate seamlessly with Tensor Parallel, DDP, and ZeRO through plugins.
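To see why sharding gets in the way, consider a per-row statistic of the squared gradient: under tensor parallelism each rank only holds a slice of the weight, so the statistic must be combined across the tensor-parallel group. The snippet below is only an illustrative sketch, not the library's implementation; it assumes a `torch.distributed` process group `tp_group` over which the weight is sharded along its last dimension.

```python
import torch
import torch.distributed as dist


def sharded_row_stat(local_grad: torch.Tensor, tp_group) -> torch.Tensor:
    """Row-wise sum of squared gradients for a weight sharded along dim -1."""
    # Each rank only sees an (n, m / tp_size) slice, so this sum is partial.
    row_sum = (local_grad ** 2).sum(dim=-1)
    # Combine the partial sums so every rank holds the statistic for full rows.
    dist.all_reduce(row_sum, op=dist.ReduceOp.SUM, group=tp_group)
    return row_sum
```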

## Optimizers

Adafactor is a variant of Adam that was the first to apply non-negative matrix factorization (NMF) to reduce memory footprint. CAME improves on NMF by introducing a confidence matrix. GaLore further reduces memory by projecting gradients into a low-rank space and applying 8-bit block-wise quantization. Lamb enables huge batch sizes without loss of accuracy through layer-wise adaptive updates bounded by the inverse of its Lipschitz constant.
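As a rough illustration of the memory saving in Adafactor, the second-moment estimate of a 2D gradient can be kept as one row statistic and one column statistic instead of a full matrix. The snippet below is a simplified, stateless sketch of that factorization (no running averages, bias correction, or update clipping); it is not the actual `DistributedAdaFactor` code.

```python
import torch


def adafactor_style_update(grad: torch.Tensor, eps: float = 1e-30) -> torch.Tensor:
    """Scale a 2D gradient by a rank-1 approximation of its second moment."""
    sq = grad ** 2 + eps
    row = sq.mean(dim=-1)  # O(n) row statistic
    col = sq.mean(dim=-2)  # O(m) column statistic
    # Rank-1 reconstruction of the full (n, m) second-moment matrix.
    approx = torch.outer(row, col) / row.mean()
    return grad / approx.sqrt()
```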

## API Reference

{{ autodoc:colossalai.nn.optimizer.distributed_adafactor.DistributedAdaFactor }}

{{ autodoc:colossalai.nn.optimizer.distributed_lamb.DistributedLamb }}

{{ autodoc:colossalai.nn.optimizer.distributed_galore.DistGaloreAwamW }}

{{ autodoc:colossalai.nn.optimizer.distributed_came.DistributedCAME }}

## Usage

We now demonstrate how to use Distributed Adafactor with the booster API, combining Tensor Parallel and ZeRO 2 on 4 GPUs.

### Step 1. Import libraries

```python
from transformers import LlamaModel, LlamaConfig
from colossalai.nn.optimizer.distributed_adafactor import DistributedAdaFactor
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
import colossalai
import torch
```

### Step 2. Initialize the distributed environment

We need to initialize the distributed environment. For demonstration purposes, we launch with `colossalai run --nproc_per_node 4`. You can refer to Launch Colossal-AI for other launch methods.

```python
colossalai.launch_from_torch()
```

### Step 3. Initialize the model and optimizer

Build the model and the optimizer. Here we create a `LlamaModel` and wrap its parameters with `DistributedAdaFactor`.

```python
configuration = LlamaConfig()
model = LlamaModel(configuration).cuda()
criterion = lambda x: x.mean()
dist_optim = DistributedAdaFactor(model.parameters())
```

### Step 4. Initialize the booster and plugin

```python
plugin = HybridParallelPlugin(tp_size=2, zero_stage=2, pp_size=1, enable_all_optimization=True)
booster = Booster(plugin=plugin)
# You should also pass in your own dataset.
model, dist_optim, criterion, dataloader, _ = booster.boost(model, dist_optim, criterion)
```

### Step 5. Train

```python
steps = 10
for step in range(steps):
    input_ids = torch.ones(1, 100, device="cuda", dtype=torch.int)
    attention_mask = input_ids.clone()
    outputs = model(input_ids.cuda(), attention_mask.cuda())
    loss = criterion(outputs.last_hidden_state)
    booster.backward(loss, dist_optim)
    dist_optim.step()
    dist_optim.zero_grad()
```

### Special initialization for GaLore

For GaLore, we need to specify the projection rank for each parameter group, along with the quantization and paged-optimizer arguments. For details on quantization, please refer to bitsandbytes.

```python
from colossalai.nn.optimizer.galore import get_galore_param_groups
from colossalai.nn.optimizer import DistGaloreAwamW

# Example hyperparameters; tune them for your own workload.
lr = 1e-3
beta1, beta2 = 0.9, 0.999
eps = 1e-8

optim = DistGaloreAwamW(
    get_galore_param_groups(model, decay=1e-2, rank=8),
    lr=lr,
    betas=(beta1, beta2),
    eps=eps,
    nbits=8,
    percentile_clipping=100,
    block_wise=True,
    min_8bit_size=4096,
)
```

## Compatibility

| Model/Feature          | Lamb | GaLore | Adafactor | CAME |
| :--------------------- | :--: | :----: | :-------: | :--: |
| Hybrid Parallel Plugin |  ✔️  |   ✔️   |    ✔️     |  ✔️  |
| Low Level Zero Plugin  |  ✔️  |   ✔️   |    ✔️     |      |
| Torch DDP Plugin       |  ✔️  |   ✔️   |    ✔️     |  ✔️  |
| Gemini Plugin          |      |        |           |      |
| Moe Hybrid Plugin      |      |        |           |      |