update markdown docs (english) (#60)

pull/63/head
Frank Lee 2021-12-10 14:37:33 +08:00 committed by GitHub
parent da01c234e1
commit 9a0466534c
10 changed files with 341 additions and 374 deletions

View File

@ -42,21 +42,56 @@ pip install -v --no-cache-dir --global-option="--cuda_ext" .
```python
import colossalai
from colossalai.trainer import Trainer
from colossalai.core import global_context as gpc
from colossalai.utils import get_dataloader
engine, train_dataloader, test_dataloader = colossalai.initialize()
trainer = Trainer(engine=engine,
verbose=True)
trainer.fit(
train_dataloader=train_dataloader,
test_dataloader=test_dataloader,
epochs=gpc.config.num_epochs,
hooks_cfg=gpc.config.hooks,
display_progress=True,
test_interval=5
# my_config can be path to config file or a dictionary obj
# 'localhost' is only for single node; you need to specify
# the node name if you are using multiple nodes
colossalai.launch(
config=my_config,
rank=rank,
world_size=world_size,
backend='nccl',
port=29500,
host='localhost'
)
# build your model
model = ...
# build your dataset, the dataloader will have a distributed data
# sampler by default
train_dataset = ...
train_dataloader = get_dataloader(dataset=train_dataset,
shuffle=True,
)
# build your optimizer
optimizer = ...
# build your loss function
criterion = ...
# build your lr_scheduler (optional)
lr_scheduler = ...
engine, train_dataloader, _, _ = colossalai.initialize(
model=model,
optimizer=optimizer,
criterion=criterion,
train_dataloader=train_dataloader
)
# start training
engine.train()
for epoch in range(NUM_EPOCHS):
for data, label in train_dataloader:
engine.zero_grad()
output = engine(data)
loss = engine.criterion(output, label)
engine.backward(loss)
engine.step()
```
### Write a Simple 2D Parallel Model

View File

@ -1,6 +1,7 @@
from ._base_hook import BaseHook
from ._checkpoint_hook import SaveCheckpointHook, LoadCheckpointHook
from ._metric_hook import LossHook, Accuracy2DHook, AccuracyHook, MetricHook
from ._metric_hook import (LossHook, Accuracy2DHook, AccuracyHook, MetricHook,
Accuracy1DHook, Accuracy2p5DHook, Accuracy3DHook)
from ._log_hook import LogMetricByEpochHook, TensorboardHook, LogTimingByEpochHook, LogMemoryByEpochHook
from ._lr_scheduler_hook import LRSchedulerHook
@ -8,6 +9,7 @@ __all__ = [
'BaseHook', 'MetricHook',
'LoadCheckpointHook', 'SaveCheckpointHook',
'LossHook', 'AccuracyHook', 'Accuracy2DHook',
'Accuracy1DHook', 'Accuracy2p5DHook', 'Accuracy3DHook',
'LogMetricByEpochHook', 'TensorboardHook', 'LogTimingByEpochHook', 'LogMemoryByEpochHook',
'LRSchedulerHook'
]

View File

@ -108,7 +108,14 @@ class DataParallelSampler(Sampler):
self.epoch = epoch
def get_dataloader(dataset, shuffle=False, seed=1024, add_sampler=True, **kwargs):
def get_dataloader(dataset,
shuffle=False,
seed=1024,
add_sampler=True,
drop_last=False,
pin_memory=False,
num_workers=0,
**kwargs):
'''Set up a deterministic dataloader (also configure seed workers, samplers and whether to shuffle or not)
.. note:: when pipeline parallel is enabled, shuffle cannot be True
@ -141,9 +148,15 @@ def get_dataloader(dataset, shuffle=False, seed=1024, add_sampler=True, **kwargs
return DataLoader(dataset,
worker_init_fn=seed_worker,
shuffle=shuffle,
drop_last=drop_last,
pin_memory=pin_memory,
num_workers=num_workers,
**_kwargs)
else:
return DataLoader(dataset,
sampler=sampler,
worker_init_fn=seed_worker,
drop_last=drop_last,
pin_memory=pin_memory,
num_workers=num_workers,
**_kwargs)

View File

@ -103,7 +103,7 @@ class YourGradientHandler(BaseGradientHandler):
Afterwards, you can specify the gradient handler you want to use in your configuration file.
```python
dist_initializer = [
gradient_handlers = [
dict(type='YourGradientHandler'),
]
```
@ -112,5 +112,4 @@ dist_initializer = [
Schedule entails how to execute a forward and backward pass. Currently, Colossal-AI provides pipeline and non-pipeline
schedules. If you want to modify how the forward and backward passes are executed, you can
inherit `colossalai.engine.BaseSchedule` and implement your idea. You can also add your schedule to the engine before
training.
inherit `colossalai.engine.schedule.BaseSchedule` and implement the `forward_backward_step` function, as sketched below.
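The sketch below is an illustration only; the real `BaseSchedule` interface may take additional arguments (for example `return_loss`), so check the API references before implementing your own schedule.
```python
from colossalai.engine.schedule import BaseSchedule


class MySchedule(BaseSchedule):
    # a hypothetical schedule that runs a plain forward and backward pass
    def forward_backward_step(self, engine, data_iter, forward_only=False, return_loss=True):
        # fetch one batch from the data iterator
        data, label = next(data_iter)
        # forward pass through the engine
        output = engine(data)
        loss = engine.criterion(output, label) if return_loss else None
        # run the backward pass unless only the forward output is needed
        if not forward_only:
            engine.backward(loss)
        return output, label, loss
```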

View File

@ -3,17 +3,31 @@
In Colossal-AI, we have incorporated different implementations of mixed precision training:
1. torch.cuda.amp
2. apex.amp
3. tensor-parallel amp
3. naive amp
The first two rely on the original implementation of [PyTorch](https://pytorch.org/docs/stable/amp.html)
(version 1.6 and above) and [Nvidia Apex](https://github.com/NVIDIA/apex). However, these two methods are not compatible
with tensor parallelism. This is because that tensors are split across devices in tensor parallelism, thus, it is required
to communicate among different processes to check if `inf` or `nan` occurs in the whole model weights. For the mixed
precision training with tensor parallelism, we adapted this feature from [Megatron-LM](https://github.com/NVIDIA/Megatron-LM).
(version 1.6 and above) and [Nvidia Apex](https://github.com/NVIDIA/apex). The last method is similar to Apex O2 level.
Among these methods, apex.amp is not compatible with tensor parallelism. This is because tensors are split across devices
in tensor parallelism, and thus it is required to communicate among different processes to check whether `inf` or `nan` occurs in the
whole model's weights. **We modified the torch amp implementation so that it is now compatible with tensor parallelism.**
To use mixed precision training, you can easily specify the `fp16` field in the config file to be True. Currently, PyTorch and
Apex amp cannot be guaranteed to work with tensor and pipeline parallelism, thus, only the last one is recommended if you
are using hybrid parallelism.
Apex amp cannot be guaranteed to work with tensor and pipeline parallelism. We recommend using torch amp as it generally
gives better accuracy than naive amp.
The AMP module is designed to be completely modular and can be used independently from other colossalai modules.
If you wish to only use amp in your code base without `colossalai.initialize`, you can use `colossalai.amp.convert_to_amp`.
```python
from colossalai.amp import AMP_TYPE
# example of using torch amp
model, optimizer, criterion = colossalai.amp.convert_to_amp(model,
optimizer,
criterion,
AMP_TYPE.TORCH)
```
## PyTorch AMP
@ -21,7 +35,7 @@ PyTorch provides mixed precision training in version 1.6 and above. It provides
while keeping some operations such as reductions in `fp32`. You can configure the gradient scaler in the config file.
```python
from colossalai.engine import AMP_TYPE
from colossalai.amp import AMP_TYPE
fp16=dict(
mode=AMP_TYPE.TORCH,
@ -43,7 +57,7 @@ will keep batch normalization in `fp32`.
The following code block shows a config file for Apex AMP.
```python
from colossalai.engine import AMP_TYPE
from colossalai.amp import AMP_TYPE
fp16 = dict(
mode=AMP_TYPE.APEX,
@ -71,10 +85,10 @@ and pipeline parallelism.
The following code block shows a config file for this mode.
```python
from colossalai.engine import AMP_TYPE
from colossalai.amp import AMP_TYPE
fp16 = dict(
mode=AMP_TYPE.PARALLEL,
mode=AMP_TYPE.NAIVE,
# below are the default values
clip_grad=0,
log_num_zeros_in_grad=False,

View File

@ -3,185 +3,43 @@
Here is a config file example showing how to train a ViT model on the CIFAR10 dataset using Colossal-AI:
```python
# build train_dataset and train_dataloader from this dictionary
# It is not compulsory in the config file; instead, you can pass this dictionary as an argument to colossalai.initialize()
train_data = dict(
# dictionary for building Dataset
dataset=dict(
# the type CIFAR10Dataset has to be registered
type='CIFAR10Dataset',
root='/path/to/data',
# transform pipeline
transform_pipeline=[
dict(type='Resize', size=IMG_SIZE),
dict(type='RandomCrop', size=IMG_SIZE, padding=4),
dict(type='RandomHorizontalFlip'),
dict(type='ToTensor'),
dict(type='Normalize',
mean=[0.4914, 0.4822, 0.4465],
std=[0.2023, 0.1994, 0.2010]),
]
),
# dictionary for building Dataloader
dataloader=dict(
batch_size=BATCH_SIZE,
pin_memory=True,
# num_workers=1,
shuffle=True,
)
)
# build test_dataset and test_dataloader from this dictionary
test_data = dict(
dataset=dict(
type='CIFAR10Dataset',
root='/path/to/data',
train=False,
transform_pipeline=[
dict(type='Resize', size=IMG_SIZE),
dict(type='ToTensor'),
dict(type='Normalize',
mean=[0.4914, 0.4822, 0.4465],
std=[0.2023, 0.1994, 0.2010]
),
]
),
dataloader=dict(
batch_size=BATCH_SIZE,
pin_memory=True,
# num_workers=1,
)
)
# compulsory
# build optimizer from this dictionary
optimizer = dict(
# Available types: 'ZeroRedundancyOptimizer_Level_1', 'ZeroRedundancyOptimizer_Level_2', 'ZeroRedundancyOptimizer_Level_3'
# 'Adam', 'Lamb', 'SGD', 'FusedLAMB', 'FusedAdam', 'FusedSGD', 'FP16Optimizer'
type='Adam',
lr=0.001,
weight_decay=0
)
# compulsory
# build loss function from this dictionary
loss = dict(
# Available types:
# 'CrossEntropyLoss2D', 'CrossEntropyLoss2p5D', 'CrossEntropyLoss3D'
type='CrossEntropyLoss2D',
)
# compulsory
# build model from this dictionary
model = dict(
# available types: 'PretrainBERT', 'VanillaResNet', 'VisionTransformerFromConfig'
type='VisionTransformerFromConfig',
# each key-value pair above refers to a layer
# input data pass through these layers recursively
tensor_splitting_cfg=dict(
type='ViTInputSplitter2D',
),
embedding_cfg=dict(
type='ViTPatchEmbedding2D',
img_size=IMG_SIZE,
patch_size=PATCH_SIZE,
embed_dim=DIM,
),
token_fusion_cfg=dict(
type='ViTTokenFuser2D',
img_size=IMG_SIZE,
patch_size=PATCH_SIZE,
embed_dim=DIM,
drop_rate=0.1
),
norm_cfg=dict(
type='LayerNorm2D',
normalized_shape=DIM,
eps=1e-6,
),
block_cfg=dict(
# ViTBlock is a submodule
type='ViTBlock',
attention_cfg=dict(
type='ViTSelfAttention2D',
hidden_size=DIM,
num_attention_heads=NUM_ATTENTION_HEADS,
attention_dropout_prob=0.,
hidden_dropout_prob=0.1,
checkpoint=True
),
droppath_cfg=dict(
type='VanillaViTDropPath',
),
mlp_cfg=dict(
type='ViTMLP2D',
in_features=DIM,
dropout_prob=0.1,
mlp_ratio=4,
checkpoint=True
),
norm_cfg=dict(
type='LayerNorm2D',
normalized_shape=DIM,
eps=1e-6,
),
),
head_cfg=dict(
type='ViTHead2D',
hidden_size=DIM,
num_classes=NUM_CLASSES,
),
embed_dim=DIM,
depth=DEPTH,
drop_path_rate=0.,
)
# hooks are built when initializing trainer
# possible hooks: 'BaseHook', 'MetricHook','LoadCheckpointHook'
# 'SaveCheckpointHook','LossHook', 'AccuracyHook', 'Accuracy2DHook'
# 'LogMetricByEpochHook', 'TensorboardHook','LogTimingByEpochHook', 'LogMemoryByEpochHook'
hooks = [
dict(type='LogMetricByEpochHook'),
dict(type='LogTimingByEpochHook'),
dict(type='LogMemoryByEpochHook'),
dict(type='Accuracy2DHook'),
dict(type='LossHook'),
# dict(type='TensorboardHook', log_dir='./tfb_logs'),
# dict(type='SaveCheckpointHook', interval=5, checkpoint_dir='./ckpt'),
# dict(type='LoadCheckpointHook', epoch=20, checkpoint_dir='./ckpt')
]
# three keys: pipeline, tensor, data
# if data=dict(size=1), which means no data parallelization, then there is no need to define it
# optional
# three keys: pipeline, tensor
# data parallel size is inferred
parallel = dict(
pipeline=dict(size=1),
tensor=dict(size=4, mode='2d'),
)
# not compulsory
# optional
# pipeline or no pipeline schedule
fp16 = dict(
mode=AMP_TYPE.PARALLEL,
mode=AMP_TYPE.NAIVE,
initial_scale=2 ** 8
)
# not compulsory
# build learning rate scheduler
lr_scheduler = dict(
type='LinearWarmupLR',
warmup_epochs=5
)
# optional
# if you are using complex gradient handling
# otherwise, you do not need this in your config file
# default gradient_handlers = None
gradient_handlers = [dict(type='MyHandler', arg1=1, arg2=2), ...]
schedule = dict(
num_microbatches=8
)
# optional
# specific gradient accumulation size
# if your batch size is not large enough
gradient_accumulation = <int>
# training stopping criterion
# you can give num_steps or num_epochs
num_epochs = 60
# optional
# add gradient clipping to your engine
# this config is not compatible with zero and AMP_TYPE.NAIVE
# but works with AMP_TYPE.TORCH and AMP_TYPE.APEX
# default clip_grad_norm = 0.0
clip_grad_norm = <float>
# optional
# cudnn setting
# default is like below
cudnn_benchmark = False
cudnn_deterministic = True
# config logging path
logging = dict(
root_path='./logs'
)
```
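Assuming the config above is saved as `./config.py` (the path here is for illustration only), it is consumed through `colossalai.launch`, and the values defined in it become available through the global context:
```python
import colossalai
from colossalai.core import global_context as gpc

# rank, world_size, host and port are placeholders here; in practice they come
# from your launcher or cluster environment
colossalai.launch(config='./config.py',
                  rank=rank,
                  world_size=world_size,
                  host='localhost',
                  port=29500,
                  backend='nccl')

# values defined in the config file are then available through gpc.config
print(gpc.config.num_epochs)
```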

View File

@ -20,6 +20,10 @@ pip install .
Install and enable CUDA kernel fusion (compulsory installation when using fused optimizer)
```
```shell
pip install -v --no-cache-dir --global-option="--cuda_ext" .
# install with editable mode enabled
pip install -v --no-cache-dir --global-option="--cuda_ext" -e .
```

View File

@ -4,38 +4,78 @@
We support multiple parallelization methods in our library.
Hybrid parallelism in our codebase, namely data parallelism, pipeline parallelism and tensor parallelism (
1D, 2D, 2.5D, 3D). You can initialize the corresponding process group by setting `parallel` in our config. The parallel
configuration can be easily deployed by a dictionary in configuration file. The configuration dictionary must obey the
following format. Data parallel size will be inferred automatically based on your inputs to pipeline parallelism and
tensor parallelism.
Hybrid parallelism in our codebase refers to the combination of data parallelism, pipeline parallelism
and tensor parallelism (1D, 2D, 2.5D, 3D). Each form of parallelism requires a different network topology and thus
a different initializer for its distributed process group. You can initialize the corresponding process group by
setting `parallel` in our config. The parallel configuration can be easily specified by a dictionary in the
configuration file, which must obey the following format. The data parallel size will be
inferred automatically based on your settings for pipeline parallelism and tensor parallelism. The distributed
environment will be set up by `colossalai.launch`.
```python
# configuration format
parallel = dict(
pipeline=dict(size=int),
tensor=dict(size=int, mode='1d' or '2d' or '2.5d' or '3d', kwargs=Any)
)
# this is ok
parallel = dict(
pipeline=dict(size=2),
tensor=dict(size=4, mode='2d')
)
# this is ok
parallel = dict(
pipeline=2,
tensor=dict(size=4, mode='2d')
)
# this is not ok
# as you need to specify the mode for tensor parallelism
parallel = dict(
pipeline=2,
tensor=4
)
# this is also ok; tensor will default to size 1
# and mode None
parallel = dict(
pipeline=2
)
# this is also ok; pipeline will default to size 1
parallel = dict(
tensor=dict(size=4, mode='2d')
)
```
The name of the dictionary variable should be **parallel**. All the arguments, even **parallel** itself, are optional, and the
data, pipeline and tensor parallel sizes will default to 1. The value of data, pipeline and tensor can be an
int representing the size of the specific parallel dimension, or a dictionary with a key called "size". The key "mode"
represents the way of tensor parallelism.
**You can choose not to have 'parallel' in your configuration, in which case both pipeline and tensor will default to size 1.**
## Data Parallel
Data parallel is the most common way to distribute your training task by splitting data into several shards and training on
a single shard on each device. The configuration for data parallel is detected automatically and set for you. You do not
have to explicitly set them in your configurations. When data parallel size is larger than 1, Colossal-AI automatically
adds the distributed data sampler to the dataloader to shard the dataset.
have to explicitly set them in your configurations. There are two ways to handle the all-reduce in data parallel in Colossal-AI.
1. If you specify gradient handlers, gradients will be all-reduced according to the gradient handlers.
2. Otherwise, PyTorch DistributedDataParallel will be used.
In most cases, you will be using the second mode unless you have complex handling of the gradients (a sketch of the first mode is given below).
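The first mode only takes a `gradient_handlers` entry in your config file; the handler name below is illustrative, so check the gradient handlers shipped with Colossal-AI for what is actually available.
```python
# optional: add this only if you need custom all-reduce behaviour; without it,
# PyTorch DistributedDataParallel synchronizes gradients for you
gradient_handlers = [
    dict(type='DataParallelGradientHandler'),
]
```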
## 1D, 2D, 2.5D and 3D Parallel
To enable hybrid parallelism, we provide an array of tensor parallelism methods. Below is the list of papers corresponding to each
tensor parallel method. These parallel modes need to work with the distributed layers provided by Colossal-AI.
-
1D: [Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism](https://arxiv.org/abs/1909.08053)
- 1D: [Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism](https://arxiv.org/abs/1909.08053)
- 2D: [An Efficient 2D Method for Training Super-Large Deep Learning Models](https://arxiv.org/abs/2104.05343)
2D parallel relies on the SUMMA matrix multiplication algorithm and splits the input data, model weights and layer
@ -55,158 +95,134 @@ tensor parallel method. These parallel modes need to work with the distributed l
```python
# 1D parallel
parallel = dict(
pipeline=dict(size=1), # number of pipeline stages
tensor=dict(size=4, mode='1d')
)
# 2D parallel
parallel = dict(
pipeline=dict(size=1), # number of pipeline stages
tensor=dict(size=4, mode='2d')
)
# 2.5D parallel
parallel = dict(
pipeline=dict(size=1), # number of pipeline stages
tensor=dict(size=8, mode='2.5d', depth=2)
)
# 3D parallel
parallel = dict(
pipeline=dict(size=1), # number of pipeline stages
tensor=dict(size=8, mode='3d')
)
```
Once you specify the tensor parallel mode in your configuration, you can proceed to use its corresponding distributed
operator. For example, if your mode is '2d', you can use `colossalai.nn.Linear2D` in your model construction.
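As an illustrative sketch (assuming `Linear2D` follows the usual `(in_features, out_features)` convention of `torch.nn.Linear`), a tensor-parallel MLP block could look like this:
```python
import torch.nn as nn
import colossalai.nn as col_nn


class MLP2D(nn.Module):
    # a hypothetical two-layer MLP built from the 2D tensor-parallel linear layer
    def __init__(self, dim: int = 256):
        super().__init__()
        # assumption: Linear2D takes (in_features, out_features) like torch.nn.Linear
        self.dense_1 = col_nn.Linear2D(dim, dim * 4)
        self.activation = nn.GELU()
        self.dense_2 = col_nn.Linear2D(dim * 4, dim)

    def forward(self, x):
        return self.dense_2(self.activation(self.dense_1(x)))
```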
## Pipeline Parallel (experimental)
Pipeline parallelism is to split the model into several partitions by layer. For example, let's assume we have a simple
model which consists of two linear layers. We have two GPUs, and we can allocate the first linear layer to the first GPU
and the second layer to the second GPU. This example of course wastes the computing resources and is only to demonstrate
the idea of pipeline parallelism.
As PyTorch is based on a dynamic computation graph, the computation flow is not known until execution. To support pipeline
parallelism in PyTorch, you may need to add one more attribute, `layers_cfg`, in your model class which tells Colossal-AI
the sequence of execution. One example you can refer to is `colossalai.nn.model.VanillaResNet`.
```python
from colossalai.nn import BaseModel
import torch
class VanillaResNet(BaseModel):
def __init__(
self,
num_cls: int,
block_type: str,
layers: List[int],
norm_layer_type: str = 'BatchNorm2d',
in_channels: int = 3,
groups: int = 1,
width_per_group: int = 64,
zero_init_residual: bool = False,
replace_stride_with_dilation: Optional[List[bool]] = None,
dilations=(1, 1, 1, 1)
) -> None:
super().__init__()
... # some model params
self.layers_cfg = [
# conv1
dict(type='Conv2d',
in_channels=in_channels,
out_channels=self.inplanes,
kernel_size=7,
stride=2,
padding=3,
bias=False),
# bn1
dict(
type=norm_layer_type,
num_features=self.inplanes
),
# relu
dict(
type='ReLU',
inplace=True
),
# maxpool
dict(
type='MaxPool2d',
kernel_size=3,
stride=2,
padding=1
),
# layer 1
dict(
inplanes=self.inplanes,
planes=64,
blocks=self.blocks[0],
dilation=self.dilations[0],
**self.reslayer_common_cfg
),
# layer 2
dict(
inplanes=64 * self.block_expansion,
planes=128,
blocks=self.blocks[1],
stride=2,
dilate=replace_stride_with_dilation[0],
dilation=self.dilations[1],
**self.reslayer_common_cfg
),
# layer 3
dict(
inplanes=128 * self.block_expansion,
planes=256,
blocks=layers[2],
stride=2,
dilate=replace_stride_with_dilation[1],
dilation=self.dilations[2],
**self.reslayer_common_cfg
),
# layer 4
dict(
inplanes=256 * self.block_expansion,
planes=512,
blocks=layers[3], stride=2,
dilate=replace_stride_with_dilation[2],
dilation=self.dilations[3],
**self.reslayer_common_cfg
),
# avg pool
dict(
type='AdaptiveAvgPool2d',
output_size=(1, 1)
),
# flatten
dict(
type='LambdaWrapper',
func=lambda mod, x: torch.flatten(x, 1)
),
# linear
dict(
type='Linear',
in_features=512 * self.block_expansion,
out_features=num_cls
)
]
```
and the second layer to the second GPU.
You can set the number of pipeline stages in your configuration file. When pipeline size is larger than 1, Colossal-AI
will automatically creates the pipeline schedule which defines the forward and backward step. You can specify how many
microbatches to run in each step in the `schedule` configuration.
will automatically create the pipeline schedule which defines the forward and backward step.
```python
parallel = dict(
pipeline=dict(size=1), # number of pipeline stages
tensor=dict(size=1, mode=None)
pipeline=dict(size=4), # number of pipeline stages
)
```
As PyTorch is based on a dynamic computation graph, the computation flow is not known until execution. To support pipeline parallelism, you have the following two ways to split your model:
1. Split your model directly. Below is an example of ResNet split into two pipeline stages.
```python
import torch
import torch.nn as nn
from torchvision.models import resnet18
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
model = resnet18(num_classes=10)
if gpc.get_local_rank(ParallelMode.PIPELINE) == 0:
model = nn.Sequential(
model.conv1,
model.bn1,
model.relu,
model.maxpool,
model.layer1,
model.layer2
)
elif gpc.get_local_rank(ParallelMode.PIPELINE) == 1:
class Flatten(nn.Module):
def forward(self, x):
return torch.flatten(x, 1)
model = nn.Sequential(
model.layer3,
model.layer4,
model.avgpool,
Flatten(),
model.fc
)
```
2. Make sure your model inherits `colossalai.nn.model.ModelFromConfig` and is registered into the
`MODELS` registry. Define the `self.layers_cfg` attribute.
Pass in a dict/Config object which specifies the parameters of your model.
Use `colossalai.builder.pipeline.PipelineModelInitializer` to partition the layers.
```python
from colossalai.builder import PipelineModelInitializer
from colossalai.nn.model import ModelFromConfig
from colossalai.registry import MODELS
@MODELS.register_module
class MyModel(ModelFromConfig):
def __init__(self, arg1, arg2, ...):
...
self.layers_cfg = [
dict(type='Linear', in_features=3, out_features=512),
dict(type='Linear', in_features=512, out_features=512),
...
]
model_cfg = dict(
type='MyModel',
arg1=1,
arg2=2
...
)
schedule = dict(
num_microbatches = 4 # set the number of microbatches per step
)
initializer = PipelineModelInitializer(model_cfg, num_chunks=1)
model = initializer.initialize()
```
When your model is split into partitions, you can use PipelineSchedule to execute training.
```python
import colossalai
from colossalai.engine.schedule import PipelineSchedule
engine, train_dataloader, _, _ = colossalai.initialize(model, optimizer, criterion, train_dataloader)
schedule = PipelineSchedule(num_microbatches=4)
# execute a training epoch
data_iter = iter(train_dataloader)
for i in range(len(train_dataloader)):
output, label, loss = schedule.forward_backward_step(engine,
data_iter,
forward_only=False,
)
```
This feature is still in development and is only experimental for now.

View File

@ -1,6 +1,6 @@
# Build your engine & Customize your trainer
# Colossal-AI Engine & Customize Your Trainer
## Build your engine
## Colossal-AI engine
To better understand how the `Engine` class works, let's start with the concept of the process function in common
engines. The process function usually controls the behavior over a batch of a dataset, and the `Engine` class just controls the
@ -16,15 +16,7 @@ def process_function(dataloader, model, criterion, optim):
optim.step()
```
In `ignite.engine` or `keras.engine`, the process function is always provided by users. However, it is tricky for users
to write their own process functions for pipeline parallelism. Aiming at offering accessible hybrid parallelism for
users, we provide the powerful `Engine` class. This class enables pipeline parallelism and offers
one-forward-one-backward non-interleaving strategy. Also, you can use pre-defined learning rate scheduler in
the `Engine` class to adjust learning rate during training.
In order to build your engine, just set variables `model`, `criterion`, `optimizer`, `lr_scheduler` and `schedule`. The
following code block provides an example. **The engine is automatically created from the config file for you if you
start with `colossalai.initialize`.**
The engine class is a high-level wrapper of these frequently-used functions while preserving the PyTorch-like function signature and integrating with our features.
```python
import torch
@ -32,18 +24,25 @@ import torch.nn as nn
import torchvision.models as models
import colossalai
from colossalai.engine import Engine
from torchvision.datasets import CIFAR10
model = models.resnet18()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())
schedule = colossalai.engine.NonPipelineSchedule()
MyEngine = Engine(
model=model,
criterion=criterion,
optimizer=optimizer,
step_schedule=schedule
)
dataset = CIFAR10(...)
dataloader = colossalai.utils.get_dataloader(dataset)
engine, dataloader, _, _ = colossalai.initialize(model, optimizer, criterion, dataloader)
# example of a training iteration
for img, label in dataloader:
engine.zero_grad()
output = engine(img)
loss = engine.criterion(output, label)
engine.backward(loss)
engine.step()
```
More information regarding the class can be found in the API references.
@ -54,14 +53,14 @@ More information regarding the class can be found in the API references.
To learn how to customize a trainer which meets your needs, let's first take a look at the `Trainer` class. We highly
recommend that you read the *Get Started*
section and *Build your engine* first.
section and *Colossal-AI engine* first.
The `Trainer` class enables researchers and engineers to use our system more conveniently. Instead of having to write
your own scripts, you can simply construct your own trainer by calling the `Trainer` class, just like what we did in the
following code block.
```python
MyTrainer = Trainer(my_engine)
trainer = Trainer(engine)
```
After that, you can use the `fit` method to train or evaluate your model. In order to make our `Trainer` class even more
@ -71,26 +70,55 @@ class allows you to execute your hook functions at specified time. We have alrea
as listed below. What you need to do is just pick the right ones that suit your needs. Detailed descriptions of the
class can be found in the API references.
```python
hooks = [
dict(type='LogMetricByEpochHook'),
dict(type='LogTimingByEpochHook'),
dict(type='LogMemoryByEpochHook'),
dict(type='AccuracyHook'),
dict(type='LossHook'),
dict(type='TensorboardHook', log_dir='./tfb_logs'),
dict(type='SaveCheckpointHook', interval=5, checkpoint_dir='./ckpt'),
dict(type='LoadCheckpointHook', epoch=20, checkpoint_dir='./ckpt')
]
```
These hook functions will record metrics, elapsed time and memory usage and write them to the log after each epoch. Besides,
they print the current loss and accuracy to let users monitor the performance of the model.
```python
import colossalai
from colossalai.trainer import hooks, Trainer
from colossalai.utils import MultiTimer
from colossalai.logging import get_dist_logger
... = colossalai.initialize(...)
timer = MultiTimer()
logger = get_dist_logger()
# if you want to save log to file
logger.log_to_file('./logs/')
trainer = Trainer(
engine=engine,
timer=timer,
logger=logger
)
hook_list = [
hooks.LossHook(),
hooks.LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=False),
hooks.AccuracyHook(),
hooks.TensorboardHook(log_dir='./tb_logs', ranks=[0]),
hooks.LogMetricByEpochHook(logger),
hooks.LogMemoryByEpochHook(logger),
hooks.LogTimingByEpochHook(timer, logger),
hooks.SaveCheckpointHook(checkpoint_dir='./ckpt')
]
trainer.fit(
train_dataloader=train_dataloader,
epochs=NUM_EPOCHS,
test_dataloader=test_dataloader,
test_interval=1,
hooks=hook_list,
display_progress=True
)
```
### Hook
If you have your specific needs, feel free to extend our `BaseHook` class to add your own functions, or our `MetricHook`
class to write a metric collector. These hook functions can be called at twelve timing in the trainer's life cycle.
class to write a metric collector. These hook functions can be called at different stages in the trainer's life cycle.
Besides, you can define the priorities of all hooks to arrange their execution order; a sketch of a custom hook is given below. More information can be
found in the API references.
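The following is only a sketch: the hook method names and the `priority` argument are assumptions based on the built-in hooks, so verify them against `BaseHook` in the API references.
```python
from colossalai.logging import get_dist_logger
from colossalai.trainer.hooks import BaseHook


class LogIterationHook(BaseHook):
    # hypothetical hook that logs the loss every `interval` training iterations;
    # the `priority` argument and the `after_train_iter` signature are assumptions
    def __init__(self, interval: int = 100, priority: int = 10):
        super().__init__(priority=priority)
        self.interval = interval
        self.logger = get_dist_logger()
        self._step = 0

    def after_train_iter(self, trainer, output, label, loss):
        self._step += 1
        if self._step % self.interval == 0:
            self.logger.info(f'step {self._step}: loss = {loss.item():.4f}')
```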

View File

@ -1,7 +1,7 @@
# Zero Redundancy optimizer and zero offload
The Zero Redundancy Optimizer (ZeRO) removes the memory redundancies across data-parallel processes by partitioning three
model states (optimizer states, gradients, and parameters) across data-parallel processes instead of replicating them.
model states (optimizer states, gradients, and parameters) instead of replicating them.
By doing so, memory efficiency is boosted drastically compared to classic data parallelism while the computational granularity
and communication efficiency are retained.
@ -14,30 +14,26 @@ partition them during the forward and backward passes.
## Getting Started with ZeRO
If you are training models with Colossal-AI, enabling ZeRO-3 offload is as simple as enabling it in your Colossal-AI configuration!
If you are training models with Colossal-AI, enabling ZeRO DP and offloading is easy: just add several lines in your configuration file. We support configuration for levels 2 and 3. You have to use the [PyTorch native implementation](https://pytorch.org/tutorials/recipes/zero_redundancy_optimizer.html) for the level 1 optimizer (a minimal sketch is given below).
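A minimal sketch of the level-1 setup with the PyTorch native optimizer (used outside the Colossal-AI config system; `model` is assumed to be an `nn.Module` you have already built inside an initialized process group):
```python
import torch
from torch.distributed.optim import ZeroRedundancyOptimizer

# level-1 ZeRO: only the optimizer states are partitioned across the ranks
optimizer = ZeroRedundancyOptimizer(model.parameters(),
                                    optimizer_class=torch.optim.Adam,
                                    lr=0.001,
                                    weight_decay=0)
```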
Below are a few examples of ZeRO-3 configurations.
### Example of ZeRO-3 Configurations
Here we use `Adam` as the initial optimizer.
1. Use ZeRO to partition the optimizer states (level 1), gradients (level 2), and parameters (level 3).
1. Use ZeRO to partition the optimizer states, gradients (level 2), and parameters (level 3).
```python
optimizer = dict(
type='Adam',
lr=0.001,
weight_decay=0
)
zero = dict(
type='ZeroRedundancyOptimizer_Level_3',
level=3,
dynamic_loss_scale=True,
clip_grad=1.0
)
```
2. Additionally offload the optimizer states and computations to the CPU.
```python
zero = dict(
level=3,
offload_optimizer_config=dict(
device='cpu',
pin_memory=True,
@ -49,6 +45,7 @@ Here we use `Adam` as the initial optimizer.
3. Save even more memory by offloading parameters to the CPU memory.
```python
zero = dict(
level=3,
offload_optimizer_config=dict(
device='cpu',
pin_memory=True,
@ -65,6 +62,7 @@ Here we use `Adam` as the initial optimizer.
4. Save even MORE memory by offloading to NVMe (if available on your system):
```python
zero = dict(
level=3,
offload_optimizer_config=dict(
device='nvme',
pin_memory=True,
@ -81,7 +79,7 @@ Here we use `Adam` as the initial optimizer.
)
```
Note that `fp16` is automatically enabled when using ZeRO.
Note that `fp16` is automatically enabled when using ZeRO. This relies on `AMP_TYPE.NAIVE` in the Colossal-AI AMP module.
### Training