[doc] polish shardformer doc (#4779)

* fix example format in docstring

* polish shardformer doc

@@ -229,16 +229,17 @@ class GeminiPlugin(DPPluginBase):
"""
Plugin for Gemini.

```python
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin

model, train_dataset, optimizer, criterion = ...
plugin = GeminiPlugin()
train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8)
booster = Booster(plugin=plugin)
model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion)
```

Args:
    chunk_config_dict (dict, optional): chunk configuration dictionary.

@@ -266,16 +266,17 @@ class HybridParallelPlugin(PipelinePluginBase):
Tensor parallel, pipeline parallel and data parallel (DDP/ZeRO) can be picked and combined in this plugin.
The sizes of tp and pp should be passed in by the user; the size of dp is then automatically computed as dp_size = world_size / (tp_size * pp_size).

```python
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

model, train_dataset, optimizer, criterion = ...
plugin = HybridParallelPlugin(tp_size=2, pp_size=2)
train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8)
booster = Booster(plugin=plugin)
model, optimizer, criterion, train_dataloader, _ = booster.boost(model, optimizer, criterion, train_dataloader)
```

Args:
    tp_size (int): The size of tensor parallelism. Tensor parallelism will not be used when tp_size is set to 1.

@@ -213,16 +213,17 @@ class LowLevelZeroPlugin(DPPluginBase):
"""
Plugin for low level zero.

```python
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin

model, train_dataset, optimizer, criterion = ...
plugin = LowLevelZeroPlugin()
train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8)
booster = Booster(plugin=plugin)
model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion)
```

Args:
    stage (int, optional): ZeRO stage. Defaults to 1.

@@ -130,16 +130,17 @@ class TorchDDPPlugin(DPPluginBase):
"""
Plugin for PyTorch DDP.

```python
from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin

model, train_dataset, optimizer, criterion = ...
plugin = TorchDDPPlugin()
train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8)
booster = Booster(plugin=plugin)
model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion)
```

Args:
    broadcast_buffers (bool, optional): Whether to broadcast buffers in the beginning of training. Defaults to True.

@@ -143,16 +143,17 @@ class TorchFSDPPlugin(DPPluginBase):
"""
Plugin for PyTorch FSDP.

```python
from colossalai.booster import Booster
from colossalai.booster.plugin import TorchFSDPPlugin

model, train_dataset, optimizer, criterion = ...
plugin = TorchFSDPPlugin()
train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8)
booster = Booster(plugin=plugin)
model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion)
```

Args:
    See https://pytorch.org/docs/stable/fsdp.html for details.

@@ -20,14 +20,16 @@ class DistCoordinator(metaclass=SingletonMeta):
- master: the process with rank 0
- node master: the process with local rank 0 on the current node

```python
from colossalai.cluster.dist_coordinator import DistCoordinator
coordinator = DistCoordinator()

if coordinator.is_master():
    do_something()

coordinator.print_on_master('hello world')
```

Attributes:
    rank (int): the rank of the current process
@@ -131,11 +133,13 @@ class DistCoordinator(metaclass=SingletonMeta):
other processes in the same process group. This is often useful when downloading is required
as we only want to download in one process to prevent file corruption.

```python
from colossalai.cluster import DistCoordinator
dist_coordinator = DistCoordinator()
with dist_coordinator.priority_execution():
    dataset = CIFAR10(root='./data', download=True)
```

Args:
    executor_rank (int): the process rank to execute without blocking, all other processes will be blocked
@@ -174,13 +178,14 @@ class DistCoordinator(metaclass=SingletonMeta):
"""
A function wrapper that only executes the wrapped function on the master process (rank 0).

```python
from colossalai.cluster import DistCoordinator
dist_coordinator = DistCoordinator()

@dist_coordinator.on_master_only()
def print_on_master(msg):
    print(msg)
```
"""
is_master = self.is_master(process_group)

@@ -214,9 +214,56 @@ In addition, xFormers's `cutlass_op` can serve as a backup for flash attention.
Enabling `Shardformer` through a `Booster` initialized with `HybridParallelPlugin` is the recommended way to awaken the power of Shardformer.
The main reason is that pipeline parallelism cannot work without calling the `execute_pipeline` method of `Booster`. Besides, `HybridParallelPlugin` provides the ability to combine the features of `Shardformer` with other useful features, such as mixed precision training or ZeRO.

[Here](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/bert) is an example of how to trigger `Shardformer` through `HybridParallelPlugin`. Move to the root directory of this example, and execute
```bash
torchrun --standalone --nproc_per_node 4 finetune.py --target_f1 0.86 --plugin "hybrid_parallel" --model_type "bert"
```
Then you can start finetuning a BERT model wrapped by `Shardformer`. The wrapping is handled by `HybridParallelPlugin`.

Let's delve into the code of `finetune.py`:

In the `main` function, the plugin is created with the following code:
```python
...
elif args.plugin == "hybrid_parallel":
# modify the param accordingly for finetuning test cases
plugin = HybridParallelPlugin(
tp_size=1,
pp_size=2,
num_microbatches=None,
microbatch_size=1,
enable_all_optimization=True,
zero_stage=1,
precision="fp16",
initial_scale=1,
)
```
Here you can change the plugin configuration by setting `tp_size`, `pp_size` or `zero_stage` to other values. More details about plugin configuration can be found in the [Booster Plugins Doc](../basics/booster_plugins.md).
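For instance, a tensor-parallel-only setup combined with ZeRO stage 2 could look roughly like the sketch below. The values are illustrative rather than taken from the example script; ZeRO stage 2 is generally only combinable with the plugin when pipeline parallelism is disabled.

```python
# Illustrative only: tensor parallelism plus ZeRO stage 2, no pipeline parallelism.
plugin = HybridParallelPlugin(
    tp_size=2,        # shard parameters across 2 tensor-parallel ranks
    pp_size=1,        # pipeline parallelism disabled
    zero_stage=2,     # ZeRO stage 2 for the data-parallel group
    precision="fp16",
    initial_scale=1,
)
```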
If pipeline parallelism is not enabled, training proceeds the same way as with other booster plugins: first boost the components with `Booster`, then run the forward and backward passes in the usual way, as sketched below.
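A minimal sketch of that standard (non-pipeline) training step, assuming `model`, `optimizer`, `criterion` and `train_dataloader` have already been boosted as shown above and that the criterion computes the loss directly from the model outputs:

```python
# Standard training step when pipeline parallelism is disabled:
# forward as usual, backward through the booster, then an optimizer step.
for batch in train_dataloader:
    batch = {k: v.cuda() for k, v in batch.items()}
    outputs = model(**batch)
    loss = criterion(outputs)
    booster.backward(loss, optimizer)  # lets the booster handle mixed precision / ZeRO bookkeeping
    optimizer.step()
    optimizer.zero_grad()
```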
However, if pipeline parallelism is enabled, there are several differences from the usual workflow:
1. Before running the forward or backward pass, the criterion function (loss function) is wrapped to meet the argument requirements of pipeline execution:
```python
def _criterion(outputs, inputs):
outputs = output_transform_fn(outputs)
loss = criterion(outputs)
return loss
```
2. In the `train_epoch` function, the dataloader is converted into an iterator before running the pipeline:
```python
train_dataloader_iter = iter(train_dataloader)
```
3. Run the forward and backward passes by calling the `Booster.execute_pipeline` method:
```python
outputs = booster.execute_pipeline(
train_dataloader_iter, model, _criterion, optimizer, return_loss=True, return_outputs=True
)
```
The backward pass is completed inside this method, so there is no need to call `loss.backward()` afterwards; see the sketch after this list for the rest of a typical training step.
More details about `Booster.execute_pipeline` can be found in the [Booster API Doc](../basics/booster_api.md).
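For context, here is a sketch of what typically follows this call; it assumes the returned dict exposes the loss under the `"loss"` key when `return_loss=True` and that `booster.plugin.stage_manager` can tell whether the current rank is the last pipeline stage, as in the example script:

```python
# Only the last pipeline stage receives a meaningful loss, so guard the read.
if booster.plugin.stage_manager.is_last_stage():
    loss = outputs["loss"]

# The backward pass already ran inside execute_pipeline; just step the optimizer.
optimizer.step()
optimizer.zero_grad()
```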
#### 2. Enabling Shardformer Through Shardformer APIs (Not Recommended)

@@ -224,7 +271,26 @@ More details about this usage can be found in chapter [Booster API](../basics/bo
You can also use Shardformer by manually calling the Shardformer APIs. However, this usage is not recommended, since pipeline parallelism can't run without `Booster`.

[Here](https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/shardformer/examples/convergence_benchmark.py)
is an example of how to trigger `Shardformer` by calling the Shardformer APIs. In the `train` function of the example code, the model is wrapped by `Shardformer` with the following lines:
```python
...
if dist.get_world_size() > 1:
tp_group = dist.new_group(backend="nccl")
# First create configuration for Shardformer
shard_config = ShardConfig(
tensor_parallel_process_group=tp_group,
enable_tensor_parallelism=True,
enable_all_optimization=True
)
# Then create ShardFormer object with created config
shard_former = ShardFormer(shard_config=shard_config)
# Finally shard the model using ShardFormer.optimize method
model, _ = shard_former.optimize(model)
...
```
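After `optimize` returns, the sharded model can be trained with a plain PyTorch loop. Below is a rough sketch under the assumption that every rank in the tensor-parallel group is fed the same batch (which tensor parallelism requires) and that the model returns a HuggingFace-style output with a `.loss` field:

```python
import torch

model = model.cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)  # create the optimizer after sharding

for batch in dataloader:
    batch = {k: v.cuda() for k, v in batch.items()}
    outputs = model(**batch)   # forward runs on the sharded (tensor-parallel) modules
    loss = outputs.loss        # assumes a HuggingFace-style ModelOutput
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
```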
### Precautions

@@ -241,6 +307,8 @@ is an example on how to trigger `Shardformer` through calling Shardformer APIs.
## How Shardformer Works

### Main Idea

Generally, Shardformer works through the following four kinds of *replacements*:
1. Replacing original PyTorch module (e.g. `nn.Linear`, `nn.Embedding`) with a crafted distributed module.

@@ -207,8 +207,56 @@ The configuration of Shardformer is controlled by the parameters of the `ShardConfig` class:
Enabling `Shardformer` through a `Booster` initialized with `HybridParallelPlugin` is the most recommended usage. The main reason is that pipeline parallelism cannot work properly without calling the `execute_pipeline` method of `Booster`. In addition, `HybridParallelPlugin` provides the ability to combine the features of `Shardformer` with other features such as mixed precision training or ZeRO.

[Here](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/bert) is an example of enabling `Shardformer` through `HybridParallelPlugin`.
Move to the root directory of the example and execute the command:
```bash
torchrun --standalone --nproc_per_node 4 finetune.py --target_f1 0.86 --plugin "hybrid_parallel" --model_type "bert"
```
Then you can start finetuning a BERT model wrapped by `Shardformer`; the wrapping is done by `HybridParallelPlugin`.

Next, let's dive into the code of `finetune.py`:

In the `main` function, the hybrid parallel plugin is created with the following code:
```python
...
elif args.plugin == "hybrid_parallel":
# modify the param accordingly for finetuning test cases
plugin = HybridParallelPlugin(
tp_size=1,
pp_size=2,
num_microbatches=None,
microbatch_size=1,
enable_all_optimization=True,
zero_stage=1,
precision="fp16",
initial_scale=1,
)
```
Here you can change the plugin configuration by setting different values for `tp_size`, `pp_size` or `zero_stage`. More information about plugin configuration can be found in the [Booster Plugins Doc](../basics/booster_plugins.md).

When pipeline parallelism is not enabled, the training workflow is the same as with other plugins: first wrap the model and optimizer with `Booster`, then run forward and backward passes in the normal way. However, when pipeline parallelism is enabled, there are several differences from the usual workflow:

1. Before running forward and backward passes, the criterion function (loss function) needs to be processed to meet the argument requirements of pipeline parallelism:
```python
def _criterion(outputs, inputs):
outputs = output_transform_fn(outputs)
loss = criterion(outputs)
return loss
```
2. In the `train_epoch` function, the dataloader needs to be converted into an iterator before running the pipeline's forward and backward passes:
```python
train_dataloader_iter = iter(train_dataloader)
```
3. Run the forward and backward passes by calling the `Booster.execute_pipeline` method:
```python
outputs = booster.execute_pipeline(
train_dataloader_iter, model, _criterion, optimizer, return_loss=True, return_outputs=True
)
```
This method performs the backward pass automatically, so there is no need to call `loss.backward()` afterwards.
More information about `Booster.execute_pipeline` can be found in the [Booster API Doc](../basics/booster_api.md).

#### 2. Enabling Shardformer Through Shardformer APIs (Not Recommended)
@@ -216,7 +264,26 @@ The configuration of Shardformer is controlled by the parameters of the `ShardConfig` class:
[Here](https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/shardformer/examples/convergence_benchmark.py)
is an example of enabling `Shardformer` by calling the Shardformer APIs.

In the `train` function of the example code, the model is wrapped with the following lines:
```python
...
if dist.get_world_size() > 1:
tp_group = dist.new_group(backend="nccl")
# First create configuration for Shardformer
shard_config = ShardConfig(
tensor_parallel_process_group=tp_group,
enable_tensor_parallelism=True,
enable_all_optimization=True
)
# Then create ShardFormer object with created config
shard_former = ShardFormer(shard_config=shard_config)
# Finally shard the model using ShardFormer.optimize method
model, _ = shard_former.optimize(model)
...
```
### Precautions

@@ -234,6 +301,8 @@ The configuration of Shardformer is controlled by the parameters of the `ShardConfig` class:
## How Shardformer Works

### Main Idea

Generally, Shardformer works through the following four kinds of *replacements*:
1. Replacing the original PyTorch module (e.g. `nn.Linear`, `nn.Embedding`) with a crafted distributed module.
