7.9 KiB
零气泡流水线并行
作者: Junwen Duan, Hongxin Liu
相关论文
介绍
零气泡(V Schedule): 与早期工作中的1F1B方案相比,零气泡流水线并行将B分成两个阶段(也称为激活梯度和权重梯度),形如1F1B1W这样的方案可以进一步减少气泡。
使用
我们将演示如何在 4 个 GPU 上使用带有 booster API 的 ZeroBubble
step 1. 引用仓库
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.testing import assert_close
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaModel
import colossalai
from colossalai.booster.booster import Booster
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import HybridParallelPlugin
from colossalai.pipeline.schedule.zero_bubble_pp import ZeroBubbleVPipeScheduler
step 2. 初始化分布式环境
colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
step 3. 初始化模型优化器
建立我们的模型和优化器 我们创建了一个带有8层Decoder-Layer的 Llama。然后,使用get_v_schedule()函数创建PipelineGraph和Pipeline schedule。
# Global Param
NUM_BATCH = 8
NUM_TOK_PER_BATCH = 4
NUM_LAYERS = 8
HIDDEN_SIZE_PER_HEAD = 4
NUM_HEADS = 4
# Init Llama from huggingface
configuration = LlamaConfig(
hidden_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS,
intermediate_size=HIDDEN_SIZE_PER_HEAD * NUM_HEADS * 2,
num_hidden_layers=NUM_LAYERS,
num_attention_heads=NUM_HEADS,
num_key_value_heads=NUM_HEADS,
attn_implementation="flash_attention_2",
)
model = LlamaModel(configuration).cuda()
optimizer = torch.optim.Adam(torch_model.parameters(), lr=1)
step 4.初始化流水线Schedule
然后,我们需要使用 get_v_schedule() 函数创建 PipelineGraph 和 PipelineSchedule。我们需要用以下参数初始化 PipelineGraph。 x_cost 表示每个模型块的操作 x 所消耗的运行时间。 x_mem 表示每个模型块的操作 x 所消耗的内存量。 这些参数都是在流水线启动前估算并填入的。事实上,在模型的实际计算过程中,根据运行时间和内存成本可以获得更好的结果。 在下面的例子中,我们假设模型的正向、反向 B 和反向 W 的计算时间分别为 1、1、1,p2p 通信时间为 1。
# Init schedule
h, a, s = config.hidden_size, config.num_attention_heads, 1024
mem_f = 34 * h + 5 * a * s
mem_w = -32 * h
mem_b = -mem_w - mem_f
graph = PipelineGraph(
n_stage=pp_size,
n_micro=num_microbatches,
f_cost=1,
b_cost=1,
w_cost=1,
c_cost=1,
f_mem=mem_f,
b_mem=mem_b,
w_mem=mem_w,
)
zbv_schedule = graph.get_v_schedule()
step 5.初始化Booster
在初始化Plugin时输入pp_style="zbv",以使用ZeroBubble流水线并行。
plugin = HybridParallelPlugin(
pp_size=4,
num_microbatches=4,
tp_size=1,
sp_size=1,
zero_stage=1,
initial_scale=1,
find_unused_parameters=True,
pp_style="zbv",
scheduler_nodes=zbv_schedule,
num_model_chunks=2,
)
dp_size = plugin.dp_size
booster = Booster(plugin=plugin)
step 6.训练模型
steps = 10
for step in range(steps):
input_embeddings = torch.rand(
NUM_BATCH, NUM_TOK_PER_BATCH, HIDDEN_SIZE_PER_HEAD * NUM_HEADS, requires_grad=True
).cuda()
dist.all_reduce(
input_embeddings, group=plugin.pp_group
)
data_iter = iter([{"inputs_embeds": input_embeddings}])
output = booster.execute_pipeline(
data_iter,
model,
lambda x, y: x.last_hidden_state.mean(),
optimizer,
return_loss=True,
return_outputs=True,
)
optimizer.step()
optimizer.zero_grad()
进阶使用技巧
在 ColossalAI 中,通过使用MetaCache和混合并行的ZeroBubble,可以获得更好的训练性能。
1.在ZeroBubble中使用元数据缓存
在初始化Plugin时输入 "enable_metadata_cache=True",以便在ZeroBubble管道中使用元数据缓存。
plugin = HybridParallelPlugin(
pp_size=2,
num_microbatches=4,
tp_size=2,
sp_size=2,
zero_stage=1,
initial_scale=1,
enable_metadata_cache=True,
find_unused_parameters=True,
pp_style="zbv",
scheduler_nodes=zbv_schedule,
num_model_chunks=2,
)
2.同时使用ZeroBubble和混合并行
在初始化插件时传递 pp_size, tp_size, sp_size, 以便使用零气泡混合并行管道(HybridParallel with ZeroBubble Pipeline)。
plugin = HybridParallelPlugin(
pp_size=2,
num_microbatches=2,
tp_size=2,
sp_size=2,
zero_stage=1,
initial_scale=1,
find_unused_parameters=True,
pp_style="zbv",
scheduler_nodes=zbv_schedule,
num_model_chunks=2,
)
性能指标
HybridParallel Strategy | Pipeline Parallel | Sequence Parallel + Pipeline Parallel | Data Parallel + Pipeline Parallel | |||||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
With 1F1B | 15.27 samples/sec | 17.22 samples/sec | 14.06 samples/sec | |||||||||||||||||||||||||||||||||||
With Zero Bubble | 17.36 samples/sec | 18.38 samples/sec | 14.44 samples/sec | |||||||||||||||||||||||||||||||||||
模型兼容性
Shardformer/Model | Bert | Blip2 | Bloom | Chatglm2 | Command | Deepseek | Falcon | GPT2 | Gptj | Llama | Mistral | Opt | Qwen2 | Sam | T5 | Vit | Whisper | |||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
ZeroBubble | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | ✔️ | |||||||||||||||||||||
API 参考
{{ autodoc:colossalai.pipeline.schedule.zero_bubble_pp.ZeroBubbleVPipeScheduler }}