ColossalAI/docs/source/zh-Hans/features/shardformer.md

18 KiB
Raw Blame History

Shardformer

Author: Baizhou Zhang, Bin Jia

预备知识

示例代码

相关论文

简介

在训练LLaMa-2 70B或OPT 175B等大型Transformer模型时为了满足GPU内存的限制将大型模型划分为更小的分片的模型并行方法包括张量并行以及流水线并行是必不可少的。然而对于不熟悉分布式训练的用户来说手动剪切模型并重写其前向/反向逻辑可能很困难。与此同时Huggingface transformers开源库正在逐渐成为用户模型来源的首选大部分主流大型模型都已在Huggingface transformers模型库中开源。

出于这种动机ColossalAI团队开发了Shardformer该功能可以自动为HuggingFace中主流的Transformer模型进行封装用于张量并行以及流水线并行的训练策略。如此一来对系统了解不多的用户也可以轻松地在transformers模型上进行并行训练只需几行代码用户就可以将模型转变为并行训练的状态。此外Shardformer也包括了多种优化工具用于在前向/后向的传递过程中实现加速和节省内存。

支持信息

模型/功能 兼容性矩阵:

Model/Feature Tensor
Parallel
Pipeline
Parallel
Lazy
Initialization
xFormers Flash
Attention 2
JIT Fused
Operators
Fused
LayerNorm
Sequence
Parallel
Sequence
Overlap
Llama V1/V2 ✔️ ✔️ ✔️ ✔️ ✔️ ✔️ ✔️
OPT ✔️ ✔️ ✔️ ✔️ ✔️ ✔️ ✔️
BLOOM ✔️ ✔️ ✔️ ✔️ ✔️ ✔️ ✔️ ✔️ ✔️
ChatGLM 2 ✔️ ✔️ ✔️ ✔️ ✔️ ✔️ ✔️ ✔️ ✔️
BERT ✔️ ✔️ ✔️ ✔️ ✔️ ✔️ ✔️ ✔️ ✔️
GPT 2 ✔️ ✔️ ✔️ ✔️ ✔️ ✔️ ✔️ ✔️ ✔️
T5 ✔️ ✔️ ✔️ ✔️ ✔️ ✔️ ✔️
ViT ✔️ ✔️ ✔️ ✔️ ✔️ ✔️
Whisper ✔️ ✔️ ✔️ ✔️ ✔️ ✔️
SAM ✔️ ✔️ ✔️ ✔️ ✔️
Blip2 ✔️ ✔️ ✔️ ✔️ ✔️

我们计划在不久后为Shardformer支持的模型:

  • RoBERTa
  • ALBERT
  • ERNIE
  • GPT Neo
  • GPT-J
  • BEiT
  • SwinTransformer V1/V2
  • qwen

随着未来更多模型和优化工具的出现,我们支持的模型/优化工具将会变得越来越多。如果您对我们应该支持的模型/优化工具有任何建议,欢迎在项目的Issues板块参与讨论。

用法

Shardformer的参数配置

Shardformer的配置由类ShardConfig的参数控制:

{{ autodoc:colossalai.shardformer.ShardConfig }}

如果您想启用 Apex Fused Layernorm请安装 apex。如果您想启用 flash attention请安装 flash_attn。此外xFormers 的 cutlass_op 可以作为Flash Attention的补充优化方式。

启动Shardformer

1. 通过Booster启动Shardformer (推荐)

通过用HybridParallelPlugin初始化的Booster来启动Shardformer是最推荐的用法。其主要原因是如果不调用Boosterexecute_pipeline方法,流水线并行就无法正常工作。此外,HybridParallelPlugin提供了将Shardformer的功能与其他功能例如混合精度训练或Zero相结合的能力。

这里是一个通过HybridParallelPlugin启动Shardformer的示例。 移动到示例的根目录下,执行命令:

torchrun --standalone --nproc_per_node 4  finetune.py --target_f1 0.86 --plugin "hybrid_parallel" --model_type "bert"

你便可以微调一个被Shardformer封装过的Bert模型而封装的操作是由HybridParallelPlugin完成的。

接下来一起深入挖掘一下finetune.py里的代码:

main函数中,混合并行的插件通过以下的代码创建

...
elif args.plugin == "hybrid_parallel":
    # modify the param accordingly for finetuning test cases
    plugin = HybridParallelPlugin(
        tp_size=1,
        pp_size=2,
        num_microbatches=None,
        microbatch_size=1,
        enable_all_optimization=True,
        zero_stage=1,
        precision="fp16",
        initial_scale=1,
    )

在这里你可以通过设置不同的tp_size, pp_sizezero_stage来改变插件的配置。更多关于插件配置的信息可以在Booster 插件文档中被找到。

当流水并行不被启用的时候,训练的流程和其他的插件是一样的 先用Booster封装模型和优化器再用正常的方式做前向和后向传递。然而当流水线并行被启用的时候有几处不同于寻常情况的用法

  1. 在进行前向和后向之前criterion函数loss函数需要被处理以满足流水线并行的传参要求:

    def _criterion(outputs, inputs):
        outputs = output_transform_fn(outputs)
        loss = criterion(outputs)
        return loss
    
  2. train_epoch 函数中, dataloader 在进行流水线的前向后向操作之前需要被转换为 Iterator 类:

    train_dataloader_iter = iter(train_dataloader)
    
  3. 通过调用Booster.execute_pipeline 方法来执行前向和后向传递:

    outputs = booster.execute_pipeline(
        train_dataloader_iter, model, _criterion, optimizer, return_loss=True, return_outputs=True
    )
    

    该方法会自动执行后向传递,所以在执行该方法后不需要再调用 loss.backward()方法。 更多关于 Booster.execute_pipeline 的信息可以参考 Booster API 文档

2. 通过Shardformer API启动Shardformer (不推荐)

您还可以通过手动调用Shardformer API的方式启动Shardformer。然而我们并不推荐这种用法因为流水线并行在没有Booster的情况下无法正常运行。

这里 是一个通过调用Shardformer的API启动Shardformer的示例。 在示例代码的train函数中,模型被以下的几行代码进行封装:

...
if dist.get_world_size() > 1:
    tp_group = dist.new_group(backend="nccl")

    # First create configuration for Shardformer
    shard_config = ShardConfig(
        tensor_parallel_process_group=tp_group,
        enable_tensor_parallelism=True,
        enable_all_optimization=True
    )

    # Then create ShardFormer object with created config
    shard_former = ShardFormer(shard_config=shard_config)

    # Finally shard the model using ShardFormer.optimize method
    model, _ = shard_former.optimize(model)
...

注意事项

  1. 当启用流水线并行时,请不要用常规方式(model(input)loss.backward())进行前向/后向传递,这样会导致未知的错误。这种情形下请通过调用booster.execute_pipeline方法来进行前向/后向传递。

  2. 当使用Shardformer处理GPT2ForSequenceClassificationViTForImageClassification等分类模型时请确保labels的总数为张量并行度的整数倍否则Shardformer无法正确地处理classifier层。一个简单的修复方法就是在transformers的config中添加虚拟的标签。这一bug将在 Shardformer的未来版本中修复。

  3. 训练ChatGLM-2 6B的情况有点特殊由于Huggingface Transformers 目前尚未正式支持ChatGLM。在使用Shardformer训练ChatGLM-2时请通过以下方式导入config/model的类

    from colossalai.shardformer.modeling.chatglm2_6b.configuration_chatglm import ChatGLMConfig
    from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration, ChatGLMModel
    

    并且使用这些导入的类初始化模型。

Shardformer的工作原理

设计思想

通常来说Shardformer通过以下四种“替换”进行工作

  1. 用我们设计的分布式模块替换原始的PyTorch模块例如nn.Linearnn.Embedding)。 分布式模块保持与原始模块相同的属性但分布式模块会用新的参数替换原始模块的参数。新的前向函数将取代原来的前向函数用于执行分布式计算例如在张量并行下执行线性层的split/gather操作。每个分布式模块都应当实现其from_native_module静态方法以将PyTorch模块转换为其相应的分布式模块。

  2. 将原始Huggingface Transformers中间层的属性为适用于并行训练的属性。例如当使用并行度为2的张量并行训练LlaMa-2时,LlamaDecoderLayer 的属性num_heads(每一层注意力头的数量)应替换为model.config.num_attention_heads // 2

  3. 将原来Huggingface transformers库实现的前向函数替换为我们定制的前向函数。前向函数的替换对于流水线并行性至关重要因为流水线并行需要特殊的前向函数去在不同的流水线阶段之间传递中间的隐藏状态。此外可以通过我们定制的前向函数将例如flash attention或序列并行的优化方法注入到前向的过程中。

  4. 将完整的模型参数和优化器状态替换为只由当前设备控制的部分模型参数和优化器状态。通过执行ModelSharder.shard方法,当前设备仅会保留它应该处理的那部分模型参数。具体来说,这部分参数可以是使用张量并行时分配到当前机器的参数分片,或者使用流水线并行时当前流水线阶段的模型参数,或者兼而有之。除此之外的所有其他参数都被释放,用于节省内存的空间。 如此一来,优化器只会计算保留的部分参数对应的状态,从而进一步节省内存的使用。

所有这些替换都是通过手动编写的策略和前向函数来实现的。如果您想更深入地研究Shardformer的设计方案或者定制您自己的Shardformer策略请参考Shardformer 开发者文档流水并行设计方案以获得更多细节。

序列并行 Sequence Parallelism

序列并行是Shardformer支持的一种特殊的优化方法。在Shardformer中,序列并行与此处稍有不同后者侧重于ring attention。在Shardformer序列并行仅与1D张量并行一起使用以进一步减少计算中activation的内存占用。

  1. 在普通的1D张量并行中,有两个通信操作g\vec{g}g在反向传播中进行一次全局归约以获取来自所有设备的梯度,而\vec{g}在正向传播中进行一次All-Reduce以获取来自所有设备的输出。

  2. 当使用序列并行时,\vec{g}需要在正向传播过程中进行All-Gather以获取序列维度上的输入并在反向传播过程中进行Reduce-Scatter以分割梯度。\vec{g}需要进行Reduce-Scatter以将序列维度上的行线性层输出分割到所有设备上并进行All-Gather以获取完整的梯度。

  3. 使用NCCL的All-reduce实现采用了Ring All-Reduce方法由一次Reduce-Scatter和一次All-Gather组成两者的开销相等。因此与序列并行和张量并行相比它并不会引入额外的通信开销。

  4. 需要注意的一点是,在张量并行的 Column Linear 层中进行序列并行时,梯度的反向计算过程中需要获取完整的输入。在前向传播过程中,仅保留沿序列维度分割的输入部分,张量的形状例如(batch, sequence\_len/k, hidden\_states)。因此,需要进行额外的全局收集操作以获取完整的输入进行梯度计算。但是,在实现中,可以将梯度计算与全局收集通信操作重叠,这不会引入额外的通信开销(对应Shardformer中的enable_sequence_overlap参数)。