@ -214,9 +214,56 @@ In addition, xFormers's `cutlass_op` can serve as a backup for flash attention.
Enabling `Shardformer` through `Booster` initialized with `HybridParallelPlugin` is the recommended way to awaken the power of Shardformer.
The main reason is that pipeline parallelism cannot successfully work without the calling of `execute_pipeline` method of `Booster`. Besides, `HybridParallelPlugin` provides the capacity to combine the features of `Shardformer` with other useful features, such as mixed precision training or Zero.
More details about this usage can be found in chapter [Booster API](../basics/booster_api.md) and [Booster Plugins](../basics/booster_plugins.md).
[Here](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/bert) is an example on how to trigger `Shardformer` through `HybridParallelPlugin`. Move to the root directory of this example, and execute
Then you can start finetuning a bert model wrapped by `Shardformer`. The process of wrapping is operated by `HybridParallelPlugin`.
Let's delve into the code of `finetune.py`:
In the `main` function, the plugin is created through the following codes:
```python
...
elif args.plugin == "hybrid_parallel":
# modify the param accordingly for finetuning test cases
plugin = HybridParallelPlugin(
tp_size=1,
pp_size=2,
num_microbatches=None,
microbatch_size=1,
enable_all_optimization=True,
zero_stage=1,
precision="fp16",
initial_scale=1,
)
```
Here you can change the configuration of plugin by setting `tp_size`, `pp_size` or `zero_stage` to other values. More details about plugin configuration can be found in [Booster Plugins Doc](../basics/booster_plugins.md).
If pipeline parallel is not enabled, just do the training in the same way of other booster plugins(first boost with Booster, then do forward and backward through normal way).
However, if pipeline parallel is enabled, there are several usages different from other normal cases:
1. Before doing forward or backward, the criterion function (loss function) is processed to meet the argument demand of running pipeline:
```python
def _criterion(outputs, inputs):
outputs = output_transform_fn(outputs)
loss = criterion(outputs)
return loss
```
2. In `train_epoch` function, dataloader is converted into `Iterator` class before running pipeline:
```python
train_dataloader_iter = iter(train_dataloader)
```
[Here](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/bert) is an example on how to trigger `Shardformer` through `HybridParallelPlugin`. Please be aware that there's a difference in the way of doing forward and backward between the situation of using pipeline and not using pipeline.
3. Do forward and backward passing through calling `Booster.execute_pipeline` method:
Backward passing has been completed by this method, so there is no need to call `loss.backward()` after executing this method.
More details about `Booster.execute_pipeline` can be found in [Booster API Doc](../basics/booster_api.md).
#### 2. Enabling Shardformer Through Shardformer APIs (Not Recommended)
@ -224,7 +271,26 @@ More details about this usage can be found in chapter [Booster API](../basics/bo
You can also use Shardformer through manually calling Shardformer APIs. However, this usage is not recommended since pipeline parallelism can't run without `Booster`.
is an example on how to trigger `Shardformer` through calling Shardformer APIs.
is an example on how to trigger `Shardformer` through calling Shardformer APIs. In the `train` function of example code, the model is wrapped by `Shardformer` through the following few codes:
```python
...
if dist.get_world_size() > 1:
tp_group = dist.new_group(backend="nccl")
# First create configuration for Shardformer
shard_config = ShardConfig(
tensor_parallel_process_group=tp_group,
enable_tensor_parallelism=True,
enable_all_optimization=True
)
# Then create ShardFormer object with created config
更多关于这一用法的细节可以参考 [Booster API 文档](../basics/booster_api.md)以及[Booster 插件文档](../basics/booster_plugins.md)。[这里](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/bert)是一个通过`HybridParallelPlugin`启动`Shardformer`的示例。