mirror of https://github.com/hpcaitech/ColossalAI
# Parallelization

## Configure the Combination of Parallelization

We support multiple forms of parallelism in our library.

Hybrid parallelism in our codebase refers to the combination of data parallelism, pipeline parallelism
and tensor parallelism (1D, 2D, 2.5D, 3D). Each type of parallelism requires a different network topology and thus
a different initializer for its distributed process group. You can initialize the corresponding process groups by
setting `parallel` in your config. The parallel configuration is given as a dictionary in the
configuration file, and this dictionary must follow the format below. The data parallel size is
inferred automatically from the pipeline and tensor parallel sizes you provide. The distributed
environment will be set up by `colossalai.launch`.
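As a rough illustration of that inference, the sketch below assumes (our assumption, not a quote of Colossal-AI's code) that the total number of processes equals the product of the data, pipeline and tensor parallel sizes, so the data parallel size is what remains after dividing out the other two. The helper `infer_data_parallel_size` is hypothetical.

```python
def infer_data_parallel_size(world_size: int, pipeline_size: int = 1, tensor_size: int = 1) -> int:
    """Hypothetical helper: infer the data parallel size, assuming
    world_size == data_size * pipeline_size * tensor_size."""
    model_parallel_size = pipeline_size * tensor_size
    if world_size % model_parallel_size != 0:
        raise ValueError(
            f"world size {world_size} is not divisible by "
            f"pipeline size x tensor size = {model_parallel_size}"
        )
    return world_size // model_parallel_size


# 16 processes with 2 pipeline stages and 4-way tensor parallelism
# leave 16 // (2 * 4) = 2 data parallel replicas.
print(infer_data_parallel_size(16, pipeline_size=2, tensor_size=4))  # 2
```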
```python
# sample format (schematic)
# parallel = dict(
#     pipeline=dict(size=<int>),
#     tensor=dict(size=<int>, mode=<'1d' | '2d' | '2.5d' | '3d'>, <extra mode-specific kwargs>)
# )

# this is ok
parallel = dict(
    pipeline=dict(size=2),
    tensor=dict(size=4, mode='2d')
)

# this is ok
parallel = dict(
    pipeline=2,
    tensor=dict(size=4, mode='2d')
)

# this is not ok,
# as you need to specify the mode for tensor parallelism
parallel = dict(
    pipeline=2,
    tensor=4
)

# this is ok as well, as tensor will default to size 1
# and mode None
parallel = dict(
    pipeline=2
)

# this is ok as well, as pipeline will default to size 1
parallel = dict(
    tensor=dict(size=4, mode='2d')
)
```

The name of the dictionary variable must be **parallel**. All the arguments, even **parallel** itself, are optional, and
the data, pipeline and tensor parallel sizes default to 1. The value of data, pipeline and tensor can be an int
representing the size of that parallel dimension, or a dictionary with a key called `size`. The key `mode` specifies
the tensor parallel method, which is why a tensor parallel setting must be a dictionary that includes `mode`.

**You can choose not to have `parallel` in your configuration, in which case both pipeline and tensor default to size 1.**
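The rules above can be summarized as a small normalization step. The sketch below is a hypothetical helper, `normalize_parallel_config`, not part of Colossal-AI: it expands a `parallel` dict into explicit `pipeline` and `tensor` settings, treats an int `pipeline` as shorthand for `dict(size=...)`, fills in the defaults, and rejects a tensor setting without a `mode`, mirroring the "ok" and "not ok" examples.

```python
from typing import Any, Dict, Optional

TENSOR_MODES = {'1d', '2d', '2.5d', '3d'}


def normalize_parallel_config(parallel: Optional[Dict[str, Any]] = None) -> Dict[str, Dict[str, Any]]:
    """Hypothetical helper (not a Colossal-AI API) that expands a `parallel`
    dict into explicit pipeline/tensor settings, following the rules above."""
    parallel = dict(parallel or {})  # a missing `parallel` behaves like an empty dict

    pipeline = parallel.get('pipeline', 1)
    if isinstance(pipeline, int):
        pipeline = dict(size=pipeline)  # an int is shorthand for dict(size=int)
    elif not isinstance(pipeline, dict) or 'size' not in pipeline:
        raise ValueError("pipeline must be an int or a dict with a 'size' key")

    tensor = parallel.get('tensor')
    if tensor is None:
        tensor = dict(size=1, mode=None)  # default: no tensor parallelism
    elif not isinstance(tensor, dict) or 'size' not in tensor or 'mode' not in tensor:
        # e.g. tensor=4 is rejected because the mode is unknown
        raise ValueError("tensor must be a dict with 'size' and 'mode' keys")
    elif tensor['mode'] not in TENSOR_MODES:
        raise ValueError(f"unknown tensor parallel mode: {tensor['mode']!r}")

    # copy so that callers cannot mutate the original config through the result
    return dict(pipeline=dict(pipeline), tensor=dict(tensor))


print(normalize_parallel_config(dict(pipeline=2, tensor=dict(size=4, mode='2d'))))
# {'pipeline': {'size': 2}, 'tensor': {'size': 4, 'mode': '2d'}}
```

Extra keys such as `depth` for 2.5D are passed through unchanged in the `tensor` entry.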
## Data Parallel

Data parallelism is the most common way to distribute a training task: the data is split into several shards and each
device trains on a single shard. The configuration for data parallelism is detected automatically and set for you, so
you do not have to set it explicitly in your configuration. There are two ways to handle the all-reduce of gradients
in data parallel training in Colossal-AI:

1. If you specify gradient handlers, gradients will be all-reduced according to the gradient handlers.
2. Otherwise, PyTorch DistributedDataParallel will be used.

In most cases, you will be using the second mode unless you need complex handling of the gradients.
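To see why an all-reduce is needed at all, the framework-free sketch below simulates data parallelism for a one-parameter linear model with a mean-squared-error loss. Each simulated device computes a gradient on its own shard, and averaging those gradients (what an all-reduce followed by division by the number of replicas computes) recovers the full-batch gradient when the shards have equal size. The function names are illustrative only and are not Colossal-AI APIs.

```python
from typing import List, Tuple

Sample = Tuple[float, float]  # (x, y)


def mse_grad(w: float, data: List[Sample]) -> float:
    """Gradient of mean((w * x - y) ** 2) with respect to w."""
    return sum(2.0 * (w * x - y) * x for x, y in data) / len(data)


def shard(data: List[Sample], num_replicas: int) -> List[List[Sample]]:
    """Split the dataset into contiguous shards, one per replica
    (assumes len(data) is divisible by num_replicas)."""
    per_replica = len(data) // num_replicas
    return [data[r * per_replica:(r + 1) * per_replica] for r in range(num_replicas)]


def all_reduce_mean(values: List[float]) -> float:
    """Simulated all-reduce: every replica ends up with the average of all values."""
    return sum(values) / len(values)


data = [(float(x), 3.0 * x + 1.0) for x in range(8)]
w = 0.5
local_grads = [mse_grad(w, s) for s in shard(data, num_replicas=4)]
synced_grad = all_reduce_mean(local_grads)
print(abs(synced_grad - mse_grad(w, data)) < 1e-9)  # True
```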
## 1D, 2D, 2.5D and 3D Parallel

To enable hybrid parallelism, we provide several tensor parallel methods. Below is the list of papers which match each
tensor parallel method. These parallel modes need to work with the distributed layers provided by Colossal-AI.

- 1D: [Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism](https://arxiv.org/abs/1909.08053)

- 2D: [An Efficient 2D Method for Training Super-Large Deep Learning Models](https://arxiv.org/abs/2104.05343)
  2D parallel relies on the SUMMA matrix multiplication algorithm and splits the input data, model weights and layer
  outputs along two different dimensions. The tensor chunks are distributed over a 2D mesh of $P = N^2$ devices where
  $N$ is the number of tensor chunks in a single dimension.

- 2.5D: [2.5-dimensional distributed model training](https://arxiv.org/abs/2105.14500)
  Inspired by the 2.5D matrix multiplication algorithm, 2.5D parallel introduces a novel tensor parallelism which
  further parallelizes 2D tensor parallelism. A total of $P = N^2 \cdot d$ processors are arranged into $d$ layers,
  where each layer performs matrix multiplication operations independently with a dimension $N$.

- 3D: [Maximizing Parallelism in Distributed Training for Huge Neural Networks](https://arxiv.org/abs/2105.14450)
  We also introduce a 3D tensor parallelism that parallelizes neural networks on a 3D processor cube. This method
  achieves the optimal $O(P^{1/3})$ communication overhead on $P$ processors, while both computation and memory usage
  are evenly distributed through optimized load balancing of parameters as well as activations.
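To make the 2D description more concrete, the sketch below simulates the block structure of SUMMA on an $N \times N$ mesh with NumPy on a single machine. Process $(i, j)$ owns blocks $A_{ij}$, $B_{ij}$ and $C_{ij}$; at step $k$ the $k$-th block column of $A$ is broadcast along mesh rows and the $k$-th block row of $B$ along mesh columns, and every process accumulates $C_{ij} \mathrel{+}= A_{ik} B_{kj}$. This is only an illustration of the algorithm: the broadcasts are plain array indexing here, not Colossal-AI's communication code.

```python
import numpy as np


def split_blocks(m: np.ndarray, n: int) -> list:
    """Partition a matrix into an n x n grid of equal blocks; blocks[i][j] lives on process (i, j)."""
    rows = np.split(m, n, axis=0)
    return [np.split(r, n, axis=1) for r in rows]


def summa(a: np.ndarray, b: np.ndarray, n: int) -> np.ndarray:
    """Simulate SUMMA on an n x n process mesh and return the assembled product a @ b."""
    a_blocks, b_blocks = split_blocks(a, n), split_blocks(b, n)
    c_blocks = [[np.zeros((a.shape[0] // n, b.shape[1] // n)) for _ in range(n)] for _ in range(n)]
    for k in range(n):
        for i in range(n):
            for j in range(n):
                # Process (i, j) receives A[i][k] from its mesh row and B[k][j] from its mesh column.
                c_blocks[i][j] += a_blocks[i][k] @ b_blocks[k][j]
    return np.block(c_blocks)


rng = np.random.default_rng(0)
x = rng.standard_normal((8, 6))  # e.g. a batch of activations
w = rng.standard_normal((6, 4))  # e.g. a weight matrix
print(np.allclose(summa(x, w, n=2), x @ w))  # True
```

All matrix dimensions must be divisible by $N$ for the equal-block split used here.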
```python
# 1D parallel
parallel = dict(
    tensor=dict(size=4, mode='1d')
)

# 2D parallel
parallel = dict(
    tensor=dict(size=4, mode='2d')
)

# 2.5D parallel
parallel = dict(
    tensor=dict(size=8, mode='2.5d', depth=2)
)

# 3D parallel
parallel = dict(
    tensor=dict(size=8, mode='3d')
)
```

Once you specify the tensor parallel mode in your configuration, you can proceed to use its corresponding distributed
operator. For example, if your mode is '2d', you can use `colossalai.nn.Linear2D` in your model construction.
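The sizes in these examples follow the device layouts described above: 2D needs $P = N^2$, 2.5D needs $P = N^2 \cdot d$ with `depth` $= d$, and 3D arranges devices in a cube, so $P = N^3$. The hypothetical helper below (not a Colossal-AI function) checks a tensor parallel size against these layouts and returns $N$; treating any 1D size as valid is our assumption.

```python
from typing import Optional


def _exact_root(value: int, k: int) -> Optional[int]:
    """Return the integer r with r ** k == value, or None if there is none."""
    r = round(value ** (1.0 / k))
    for candidate in (r - 1, r, r + 1):  # guard against floating-point rounding
        if candidate > 0 and candidate ** k == value:
            return candidate
    return None


def mesh_dim(size: int, mode: str, depth: int = 1) -> int:
    """Hypothetical check: the number of chunks N per mesh dimension for a tensor parallel size."""
    if mode == '1d':
        return size  # 1D splits along a single dimension
    if mode == '2d':
        n = _exact_root(size, 2)  # P = N^2
    elif mode == '2.5d':
        n = _exact_root(size // depth, 2) if size % depth == 0 else None  # P = N^2 * d
    elif mode == '3d':
        n = _exact_root(size, 3)  # P = N^3
    else:
        raise ValueError(f"unknown tensor parallel mode: {mode!r}")
    if n is None:
        raise ValueError(f"size {size} does not fit a {mode} layout (depth={depth})")
    return n


print(mesh_dim(4, '2d'), mesh_dim(8, '2.5d', depth=2), mesh_dim(8, '3d'))  # 2 2 2
```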
## Pipeline Parallel (experimental)

Pipeline parallelism splits the model into several partitions by layer. For example, let's assume we have a simple
model which consists of two linear layers. We have two GPUs, and we can allocate the first linear layer to the first GPU
and the second layer to the second GPU.
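Generalizing that example, a simple way to assign $L$ sequential layers to $S$ stages is to give each stage a contiguous, roughly equal slice. The sketch below returns the half-open layer range owned by each stage; the `partition_uniform` name is our own, and Colossal-AI may balance stages differently (for instance by parameter count).

```python
from typing import List, Tuple


def partition_uniform(num_layers: int, num_stages: int) -> List[Tuple[int, int]]:
    """Hypothetical helper: split layers [0, num_layers) into num_stages contiguous ranges.
    The first num_layers % num_stages stages receive one extra layer."""
    if num_stages <= 0 or num_layers < num_stages:
        raise ValueError("need at least one layer per stage")
    base, extra = divmod(num_layers, num_stages)
    ranges, start = [], 0
    for stage in range(num_stages):
        end = start + base + (1 if stage < extra else 0)
        ranges.append((start, end))  # stage owns layers start .. end - 1
        start = end
    return ranges


# Two linear layers on two GPUs: layer 0 goes to stage 0, layer 1 to stage 1.
print(partition_uniform(2, 2))   # [(0, 1), (1, 2)]
print(partition_uniform(10, 4))  # [(0, 3), (3, 6), (6, 8), (8, 10)]
```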
You can set the number of pipeline stages in your configuration file. When the pipeline size is larger than 1,
Colossal-AI will automatically create the pipeline schedule which defines the forward and backward steps.

```python
parallel = dict(
    pipeline=dict(size=4),  # number of pipeline stages
)
```

As PyTorch is based on a dynamic computation graph, the computation flow is not known until execution. To support
pipeline parallelism, you have the following two ways to split your model:
1. Split your model directly. Below is an example of ResNet-18 split into two pipeline stages.

```python
import torch.nn as nn
from torchvision.models import resnet18

from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc

model = resnet18(num_classes=10)

if gpc.get_local_rank(ParallelMode.PIPELINE) == 0:
    # first stage: the stem and the first two residual stages
    model = nn.Sequential(
        model.conv1,
        model.bn1,
        model.relu,
        model.maxpool,
        model.layer1,
        model.layer2
    )
elif gpc.get_local_rank(ParallelMode.PIPELINE) == 1:
    # second stage: the remaining residual stages and the classifier head;
    # nn.Flatten() stands in for the torch.flatten(x, 1) call in resnet18's forward()
    model = nn.Sequential(
        model.layer3,
        model.layer4,
        model.avgpool,
        nn.Flatten(),
        model.fc
    )
```
2. Make sure your model inherits from `colossalai.nn.model.ModelFromConfig` and is registered in the
`MODELS` registry. Define the `self.layers_cfg` attribute, pass in a dict/Config object which specifies the
parameters of your model, and use `colossalai.builder.pipeline.PipelineModelInitializer` to partition the layers.

```python
from colossalai.builder import PipelineModelInitializer
from colossalai.nn.model import ModelFromConfig
from colossalai.registry import MODELS


@MODELS.register_module
class MyModel(ModelFromConfig):

    def __init__(self, arg1, arg2):
        super().__init__()
        # ... use arg1 and arg2 to set up the model ...
        self.layers_cfg = [
            dict(type='Linear', in_features=3, out_features=512),
            dict(type='Linear', in_features=512, out_features=512),
            # ... more layer configs ...
        ]


model_cfg = dict(
    type='MyModel',
    arg1=1,
    arg2=2,
    # ... other model parameters ...
)

initializer = PipelineModelInitializer(model_cfg, num_chunks=1)
model = initializer.initialize()
```
When your model is split into partitions, you can use `PipelineSchedule` to execute training.

```python
import colossalai
from colossalai.engine.schedule import PipelineSchedule

# model, optimizer, criterion and train_dataloader are assumed to be built beforehand
engine, train_dataloader, _, _ = colossalai.initialize(model, optimizer, criterion, train_dataloader)

schedule = PipelineSchedule(num_microbatches=4)

# execute a training epoch
data_iter = iter(train_dataloader)

for _ in range(len(train_dataloader)):
    output, label, loss = schedule.forward_backward_step(engine,
                                                         data_iter,
                                                         forward_only=False)
```
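The `num_microbatches` argument matters because each batch is cut into micro-batches that flow through the stages one after another, so stages sit idle while the pipeline fills and drains. The sketch below simulates a simple fill-and-drain (GPipe-style) forward schedule, which is our assumption for illustration and not necessarily the exact schedule `PipelineSchedule` uses, and measures the idle ("bubble") fraction, which equals $(S - 1) / (M + S - 1)$ for $S$ stages and $M$ micro-batches.

```python
from typing import List, Optional


def forward_timeline(num_stages: int, num_microbatches: int) -> List[List[Optional[int]]]:
    """timeline[t][s] is the micro-batch that stage s processes in time step t (None = idle).
    Micro-batch m reaches stage s at step m + s, assuming one time unit per stage."""
    steps = num_microbatches + num_stages - 1
    timeline = []
    for t in range(steps):
        row = []
        for s in range(num_stages):
            m = t - s
            row.append(m if 0 <= m < num_microbatches else None)
        timeline.append(row)
    return timeline


def bubble_fraction(num_stages: int, num_microbatches: int) -> float:
    """Fraction of stage-steps that are idle in the fill-and-drain schedule."""
    timeline = forward_timeline(num_stages, num_microbatches)
    idle = sum(slot is None for row in timeline for slot in row)
    return idle / (len(timeline) * num_stages)


for t, row in enumerate(forward_timeline(num_stages=4, num_microbatches=4)):
    print(t, ['-' if m is None else m for m in row])
print(bubble_fraction(4, 4))   # 3/7, about 0.43
print(bubble_fraction(4, 16))  # 3/19, about 0.16
```

More micro-batches shrink the bubble, at the cost of smaller per-micro-batch work on each stage.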
This feature is still in development and is only experimental for now.
## Sequence Parallel (experimental)

Sequence parallelism supports long-sequence modelling, such as document-level text understanding and medical imaging.
This method is proposed in [Sequence Parallelism: Making 4D Parallelism Possible](https://arxiv.org/abs/2105.13120).
This feature is still in development and is only experimental for now.
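The cited paper splits the sequence dimension across devices, so each device holds only a chunk of the queries, keys and values, and (as we understand it) exchanges key and value chunks around a ring of devices to compute attention. The NumPy code below is our own single-process simulation of that idea, assuming single-head attention without masking; it is not Colossal-AI's implementation, and the function names are hypothetical.

```python
import numpy as np


def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    x = x - x.max(axis=axis, keepdims=True)  # numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def full_attention(q: np.ndarray, k: np.ndarray, v: np.ndarray) -> np.ndarray:
    return softmax(q @ k.T / np.sqrt(q.shape[-1])) @ v


def ring_attention(q: np.ndarray, k: np.ndarray, v: np.ndarray, num_devices: int) -> np.ndarray:
    """Simulation: device i holds the i-th sequence chunk of q, k and v."""
    qs, ks, vs = [np.split(t, num_devices) for t in (q, k, v)]
    scale = 1.0 / np.sqrt(q.shape[-1])
    outputs = []
    for i in range(num_devices):  # each iteration plays the role of one device
        # Pass 1: key chunks travel around the ring; at step s device i holds
        # chunk (i - s) mod P and scores its queries against it.
        scores = [None] * num_devices
        for step in range(num_devices):
            src = (i - step) % num_devices
            scores[src] = qs[i] @ ks[src].T * scale
        probs = softmax(np.concatenate(scores, axis=1))  # softmax over the full sequence
        # Pass 2: value chunks travel around the ring the same way; each one is
        # weighted by the matching slice of the attention probabilities.
        prob_chunks = np.split(probs, num_devices, axis=1)
        out = np.zeros((qs[i].shape[0], v.shape[1]))
        for step in range(num_devices):
            src = (i - step) % num_devices
            out += prob_chunks[src] @ vs[src]
        outputs.append(out)
    return np.concatenate(outputs)


rng = np.random.default_rng(0)
q, k, v = rng.standard_normal((3, 8, 4))  # sequence length 8, hidden size 4
print(np.allclose(ring_attention(q, k, v, num_devices=4), full_attention(q, k, v)))  # True
```

The sequence length must be divisible by the number of devices for the equal split used here.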