Merge pull request #4757 from ppt0011/main

[doc] explain suitable use case for each plugin
pull/4766/head
ppt0011 2023-09-20 11:57:43 +08:00 committed by GitHub
commit 07c2e3d09c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 56 additions and 43 deletions

View File

@ -1,6 +1,6 @@
# Booster Plugins
Author: [Hongxin Liu](https://github.com/ver217), [Baizhou Zhang](https://github.com/Fridge003)
Author: [Hongxin Liu](https://github.com/ver217), [Baizhou Zhang](https://github.com/Fridge003), [Pengtai Xu](https://github.com/ppt0011)
**Prerequisite:**
- [Booster API](./booster_api.md)
@ -11,14 +11,23 @@ As mentioned in [Booster API](./booster_api.md), we can use booster plugins to c
We currently provide the following plugins:
- [Low Level Zero Plugin](#low-level-zero-plugin): It wraps the `colossalai.zero.low_level.LowLevelZeroOptimizer` and can be used to train models with zero-dp. It only supports zero stage-1 and stage-2.
- [Gemini Plugin](#gemini-plugin): It wraps the [Gemini](../features/zero_with_chunk.md) which implements Zero-3 with chunk-based and heterogeneous memory management.
- [Torch DDP Plugin](#torch-ddp-plugin): It is a wrapper of `torch.nn.parallel.DistributedDataParallel` and can be used to train models with data parallelism.
- [Torch FSDP Plugin](#torch-fsdp-plugin): It is a wrapper of `torch.distributed.fsdp.FullyShardedDataParallel` and can be used to train models with zero-dp.
- [Low Level Zero Plugin](#low-level-zero-plugin): It wraps the `colossalai.zero.low_level.LowLevelZeroOptimizer` and can be used to train models with zero-dp. It only supports zero stage-1 and stage-2.
- [Gemini Plugin](#gemini-plugin): It wraps the [Gemini](../features/zero_with_chunk.md) which implements Zero-3 with chunk-based and heterogeneous memory management.
- [Hybrid Pararllel Plugin](#hybrid-parallel-plugin): It provides a tidy interface that integrates the power of Shardformer, pipeline manager, mixied precision training, TorchDDP and Zero stage 1/2 feature. With this plugin, transformer models can be easily trained with any combination of tensor parallel, pipeline parallel and data parallel (DDP/Zero) efficiently, along with various kinds of optimization tools for acceleration and memory saving. Detailed information about supported parallel strategies and optimization tools is explained in the section below.
More plugins are coming soon.
## Choosing Your Plugin
Generally only one plugin is used to train a model. Our recommended use case for each plugin is as follows.
- [Torch DDP Plugin](#torch-ddp-plugin): It is suitable for models with less than 2 billion parameters (e.g. Bert-3m, GPT2-1.5b).
- [Torch FSDP Plugin](#torch-fsdp-plugin) / [Low Level Zero Plugin](#low-level-zero-plugin): It is suitable for models with less than 10 billion parameters (e.g. GPTJ-6b, MegatronLM-8b).
- [Gemini Plugin](#gemini-plugin): It is suitable for models with more than 10 billion parameters (e.g. TuringNLG-17b) and is ideal for scenarios with **high cross-node bandwidth and medium to small-scale clusters (below a thousand cards)** (e.g. Llama2-70b).
- [Hybrid Pararllel Plugin](#hybrid-parallel-plugin): It is suitable for models with more than 60 billion parameters, or special models such as those with exceptionally long sequences, very large vocabularies, and is best suited for scenarios with **low cross-node bandwidth and large-scale clusters (a thousand cards or more)** (e.g. GPT3-175b, Bloom-176b).
## Plugins
### Low Level Zero Plugin
@ -50,24 +59,6 @@ This plugin implements Zero-3 with chunk-based and heterogeneous memory manageme
{{ autodoc:colossalai.booster.plugin.GeminiPlugin }}
### Torch DDP Plugin
More details can be found in [Pytorch Docs](https://pytorch.org/docs/main/generated/torch.nn.parallel.DistributedDataParallel.html#torch.nn.parallel.DistributedDataParallel).
{{ autodoc:colossalai.booster.plugin.TorchDDPPlugin }}
### Torch FSDP Plugin
> ⚠ This plugin is not available when torch version is lower than 1.12.0.
> ⚠ This plugin does not support save/load sharded model checkpoint now.
> ⚠ This plugin does not support optimizer that use multi params group.
More details can be found in [Pytorch Docs](https://pytorch.org/docs/main/fsdp.html).
{{ autodoc:colossalai.booster.plugin.TorchFSDPPlugin }}
### Hybrid Parallel Plugin
@ -87,5 +78,22 @@ This plugin implements the combination of various parallel training strategies a
{{ autodoc:colossalai.booster.plugin.HybridParallelPlugin }}
### Torch DDP Plugin
More details can be found in [Pytorch Docs](https://pytorch.org/docs/main/generated/torch.nn.parallel.DistributedDataParallel.html#torch.nn.parallel.DistributedDataParallel).
{{ autodoc:colossalai.booster.plugin.TorchDDPPlugin }}
### Torch FSDP Plugin
> ⚠ This plugin is not available when torch version is lower than 1.12.0.
> ⚠ This plugin does not support save/load sharded model checkpoint now.
> ⚠ This plugin does not support optimizer that use multi params group.
More details can be found in [Pytorch Docs](https://pytorch.org/docs/main/fsdp.html).
{{ autodoc:colossalai.booster.plugin.TorchFSDPPlugin }}
<!-- doc-test-command: echo -->

View File

@ -1,6 +1,7 @@
# Booster 插件
作者: [Hongxin Liu](https://github.com/ver217), [Baizhou Zhang](https://github.com/Fridge003)
作者: [Hongxin Liu](https://github.com/ver217), [Baizhou Zhang](https://github.com/Fridge003), [Pengtai Xu](https://github.com/ppt0011)
**前置教程:**
- [Booster API](./booster_api.md)
@ -11,14 +12,20 @@
我们现在提供以下插件:
- [Low Level Zero 插件](#low-level-zero-插件): 它包装了 `colossalai.zero.low_level.LowLevelZeroOptimizer`,可用于使用 Zero-dp 训练模型。它仅支持 Zero 阶段1和阶段2。
- [Gemini 插件](#gemini-插件): 它包装了 [Gemini](../features/zero_with_chunk.md)Gemini 实现了基于Chunk内存管理和异构内存管理的 Zero-3。
- [Torch DDP 插件](#torch-ddp-插件): 它包装了 `torch.nn.parallel.DistributedDataParallel` 并且可用于使用数据并行训练模型。
- [Torch FSDP 插件](#torch-fsdp-插件): 它包装了 `torch.distributed.fsdp.FullyShardedDataParallel` 并且可用于使用 Zero-dp 训练模型。
- [Low Level Zero 插件](#low-level-zero-插件): 它包装了 `colossalai.zero.low_level.LowLevelZeroOptimizer`,可用于使用 Zero-dp 训练模型。它仅支持 Zero 阶段1和阶段2。
- [Gemini 插件](#gemini-插件): 它包装了 [Gemini](../features/zero_with_chunk.md)Gemini 实现了基于Chunk内存管理和异构内存管理的 Zero-3。
- [Hybrid Pararllel 插件](#hybrid-parallel-插件): 它为Shardformer流水线管理器混合精度运算TorchDDP以及Zero-1/Zero-2功能提供了一个统一且简洁的接口。使用该插件可以简单高效地实现transformer模型在张量并行流水线并行以及数据并行DDP, Zero间任意组合并行训练策略同时支持多种训练速度和内存的优化工具。有关这些训练策略和优化工具的具体信息将在下一章中阐述。
更多插件即将推出。
## 插件选择
- [Torch DDP 插件](#torch-ddp-插件): 适用于参数少于 20 亿的模型(例如 Bert-3m、GPT2-1.5b)。
- [Torch FSDP 插件](#torch-fsdp-插件) / [Low Level Zero 插件](#low-level-zero-插件): 适用于参数少于 100 亿的模型(例如 GPTJ-6b、MegatronLM-8b
- [Gemini 插件](#gemini-插件): 适合参数超过 100 亿的模型(例如 TuringNLG-17b且**跨节点带宽高、中小规模集群(千卡以下)**的场景(例如 Llama2-70b
- [Hybrid Pararllel 插件](#hybrid-parallel-插件): 适合参数超过 600 亿的模型、超长序列、超大词表等特殊模型,且**跨节点带宽低、大规模集群(千卡以上)**的场景(例如 GPT3-175b、Bloom-176b
## 插件
### Low Level Zero 插件
@ -50,6 +57,23 @@ Zero-2 不支持局部梯度累积。如果您坚持使用,虽然可以积累
{{ autodoc:colossalai.booster.plugin.GeminiPlugin }}
### Hybrid Parallel 插件
这个插件实现了多种并行训练策略和优化工具的组合。Hybrid Parallel插件支持的功能大致可以被分为以下四个部分
1. Shardformer: Shardformer负责在张量并行以及流水线并行下切分模型的逻辑以及前向/后向方法的重载这个插件为Shardformer功能提供了一个简单易用的接口。与此同时Shardformer还负责将包括fused normalization, flash attention (xformers), JIT和序列并行在内的各类优化工具融入重载后的前向/后向方法。更多关于Shardformer的信息请参考 [Shardformer文档](../features/shardformer.md)。
2. 混合精度训练插件支持fp16/bf16的混合精度训练。更多关于混合精度训练的参数配置的详细信息请参考 [混合精度训练文档](../features/mixed_precision_training_with_booster.md)。
3. Torch DDP: 当流水线并行和Zero不被使用的时候插件会自动采用Pytorch DDP作为数据并行的策略。更多关于Torch DDP的参数配置的详细信息请参考 [Pytorch DDP 文档](https://pytorch.org/docs/main/generated/torch.nn.parallel.DistributedDataParallel.html#torch.nn.parallel.DistributedDataParallel)。
4. Zero: 在初始化插件的时候,可以通过将`zero_stage`参数设置为1或2来让插件采用Zero 1/2作为数据并行的策略。Zero 1可以和流水线并行策略同时使用, 而Zero 2则不可以和流水线并行策略同时使用。更多关于Zero的参数配置的详细信息请参考 [Low Level Zero 插件](#low-level-zero-插件).
> ⚠ 在使用该插件的时候, 只有支持Shardformer的部分Huggingface transformers模型才能够使用张量并行、流水线并行以及优化工具。Llama 1、Llama 2、OPT、Bloom、Bert以及GPT2等主流transformers模型均已支持Shardformer。
> ⚠ 该插件当前只对模型和优化器支持分片的checkpoint方法。不分片的checkpoint方法会在未来的版本中被支持。
{{ autodoc:colossalai.booster.plugin.HybridParallelPlugin }}
### Torch DDP 插件
@ -69,23 +93,4 @@ Zero-2 不支持局部梯度累积。如果您坚持使用,虽然可以积累
{{ autodoc:colossalai.booster.plugin.TorchFSDPPlugin }}
### Hybrid Parallel 插件
这个插件实现了多种并行训练策略和优化工具的组合。Hybrid Parallel插件支持的功能大致可以被分为以下四个部分
1. Shardformer: Shardformer负责在张量并行以及流水线并行下切分模型的逻辑以及前向/后向方法的重载这个插件为Shardformer功能提供了一个简单易用的接口。与此同时Shardformer还负责将包括fused normalization, flash attention (xformers), JIT和序列并行在内的各类优化工具融入重载后的前向/后向方法。更多关于Shardformer的信息请参考 [Shardformer文档](../features/shardformer.md)。
2. 混合精度训练插件支持fp16/bf16的混合精度训练。更多关于混合精度训练的参数配置的详细信息请参考 [混合精度训练文档](../features/mixed_precision_training_with_booster.md)。
3. Torch DDP: 当流水线并行和Zero不被使用的时候插件会自动采用Pytorch DDP作为数据并行的策略。更多关于Torch DDP的参数配置的详细信息请参考 [Pytorch DDP 文档](https://pytorch.org/docs/main/generated/torch.nn.parallel.DistributedDataParallel.html#torch.nn.parallel.DistributedDataParallel)。
4. Zero: 在初始化插件的时候,可以通过将`zero_stage`参数设置为1或2来让插件采用Zero 1/2作为数据并行的策略。Zero 1可以和流水线并行策略同时使用, 而Zero 2则不可以和流水线并行策略同时使用。更多关于Zero的参数配置的详细信息请参考 [Low Level Zero 插件](#low-level-zero-插件).
> ⚠ 在使用该插件的时候, 只有支持Shardformer的部分Huggingface transformers模型才能够使用张量并行、流水线并行以及优化工具。Llama 1、Llama 2、OPT、Bloom、Bert以及GPT2等主流transformers模型均已支持Shardformer。
> ⚠ 该插件当前只对模型和优化器支持分片的checkpoint方法。不分片的checkpoint方法会在未来的版本中被支持。
{{ autodoc:colossalai.booster.plugin.HybridParallelPlugin }}
<!-- doc-test-command: echo -->