mirror of https://github.com/hpcaitech/ColossalAI
update markdown docs (english) (#60)
parent da01c234e1
commit 9a0466534c
README.md
@@ -42,21 +42,56 @@ pip install -v --no-cache-dir --global-option="--cuda_ext" .

```python
import colossalai
from colossalai.utils import get_dataloader

# my_config can be path to config file or a dictionary obj
# 'localhost' is only for single node, you need to specify
# the node name if using multiple nodes
colossalai.launch(
    config=my_config,
    rank=rank,
    world_size=world_size,
    backend='nccl',
    port=29500,
    host='localhost'
)

# build your model
model = ...

# build your dataset, the dataloader will have a distributed
# data sampler by default
train_dataset = ...
train_dataloader = get_dataloader(dataset=train_dataset,
                                  shuffle=True,
                                  )

# build your optimizer
optimizer = ...

# build your loss function
criterion = ...

# build your lr_scheduler
lr_scheduler = ...

engine, train_dataloader, _, _ = colossalai.initialize(
    model=model,
    optimizer=optimizer,
    criterion=criterion,
    train_dataloader=train_dataloader
)

# start training
engine.train()
for epoch in range(NUM_EPOCHS):
    for data, label in train_dataloader:
        engine.zero_grad()
        output = engine(data)
        loss = engine.criterion(output, label)
        engine.backward(loss)
        engine.step()
```
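In the snippet above, `my_config`, `rank`, and `world_size` are assumed to be provided by the user or the job launcher. As a rough sketch, `my_config` can be a plain dictionary carrying any of the documented configuration fields, for example:

```python
# a sketch of a config dictionary; see docs/config.md for the full set of fields
my_config = dict(
    parallel=dict(
        pipeline=dict(size=1),
        tensor=dict(size=1, mode=None)
    )
)
```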

### Write a Simple 2D Parallel Model
@@ -1,6 +1,7 @@

```python
from ._base_hook import BaseHook
from ._checkpoint_hook import SaveCheckpointHook, LoadCheckpointHook
from ._metric_hook import (LossHook, Accuracy2DHook, AccuracyHook, MetricHook,
                           Accuracy1DHook, Accuracy2p5DHook, Accuracy3DHook)
from ._log_hook import LogMetricByEpochHook, TensorboardHook, LogTimingByEpochHook, LogMemoryByEpochHook
from ._lr_scheduler_hook import LRSchedulerHook
```

@@ -8,6 +9,7 @@ __all__ = [

```python
__all__ = [
    'BaseHook', 'MetricHook',
    'LoadCheckpointHook', 'SaveCheckpointHook',
    'LossHook', 'AccuracyHook', 'Accuracy2DHook',
    'Accuracy1DHook', 'Accuracy2p5DHook', 'Accuracy3DHook',
    'LogMetricByEpochHook', 'TensorboardHook', 'LogTimingByEpochHook', 'LogMemoryByEpochHook',
    'LRSchedulerHook'
]
```
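For reference, the newly exported hooks are constructed in the same way as the existing ones when assembling a hook list for the trainer. A brief sketch (the `trainer`, the dataloaders and `NUM_EPOCHS` are assumed to already exist):

```python
from colossalai.trainer import hooks

# Accuracy1DHook is one of the newly exported hooks; Accuracy2p5DHook and
# Accuracy3DHook follow the same pattern
hook_list = [
    hooks.LossHook(),
    hooks.Accuracy1DHook(),
]

trainer.fit(train_dataloader=train_dataloader,
            test_dataloader=test_dataloader,
            epochs=NUM_EPOCHS,
            hooks=hook_list,
            display_progress=True)
```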
@@ -108,7 +108,14 @@ class DataParallelSampler(Sampler):

```python
        self.epoch = epoch


def get_dataloader(dataset,
                   shuffle=False,
                   seed=1024,
                   add_sampler=True,
                   drop_last=False,
                   pin_memory=False,
                   num_workers=0,
                   **kwargs):
    '''Set up a deterministic dataloader (also configure seed workers, samplers and whether shuffle or not)

    .. note: when pipeline parallel is enabled, shuffle cannot be True
```

@@ -141,9 +148,15 @@ def get_dataloader(dataset, shuffle=False, seed=1024, add_sampler=True, **kwargs

```python
        return DataLoader(dataset,
                          worker_init_fn=seed_worker,
                          shuffle=shuffle,
                          drop_last=drop_last,
                          pin_memory=pin_memory,
                          num_workers=num_workers,
                          **_kwargs)
    else:
        return DataLoader(dataset,
                          sampler=sampler,
                          worker_init_fn=seed_worker,
                          drop_last=drop_last,
                          pin_memory=pin_memory,
                          num_workers=num_workers,
                          **_kwargs)
```
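A brief usage sketch of the extended signature. The dataset, transform and batch size below are illustrative assumptions; any extra keyword arguments such as `batch_size` are forwarded to the underlying `DataLoader`:

```python
from torchvision import transforms
from torchvision.datasets import CIFAR10
from colossalai.utils import get_dataloader

# an assumed example dataset; any torch Dataset works here
train_dataset = CIFAR10(root='./data', train=True, download=True,
                        transform=transforms.ToTensor())

# a distributed sampler is attached automatically when add_sampler=True
# and more than one data-parallel process is running
train_dataloader = get_dataloader(train_dataset,
                                  shuffle=True,
                                  drop_last=True,
                                  pin_memory=True,
                                  num_workers=2,
                                  batch_size=64)
```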
@@ -103,7 +103,7 @@ class YourGradientHandler(BaseGradientHandler):

Afterwards, you can specify the gradient handler you want to use in your configuration file.

```python
gradient_handlers = [
    dict(type='YourGradientHandler'),
]
```

@@ -112,5 +112,4 @@ dist_initializer = [

Schedule entails how to execute a forward and backward pass. Currently, Colossal-AI provides pipeline and non-pipeline
schedules. If you want to modify how the forward and backward passes are executed, you can
inherit `colossalai.engine.schedule.BaseSchedule` and implement the `forward_backward_step` function, as sketched below.
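A minimal sketch of such a schedule, under the assumption that `BaseSchedule` provides a batch-loading helper comparable to the built-in schedules and that `engine` exposes `criterion` and `backward` as shown elsewhere in these docs:

```python
from colossalai.engine.schedule import BaseSchedule


class MySchedule(BaseSchedule):
    """A toy schedule running a plain, non-pipelined forward/backward pass."""

    def forward_backward_step(self, engine, data_iter, forward_only=False, return_loss=True):
        # load_batch is assumed to fetch the next batch and move it to the
        # right device; check the BaseSchedule API for the exact helper name
        data, label = self.load_batch(data_iter)
        output = engine(data)
        loss = engine.criterion(output, label)
        if not forward_only:
            engine.backward(loss)
        return output, label, loss if return_loss else None
```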
docs/amp.md
@@ -3,17 +3,31 @@

In Colossal-AI, we have incorporated different implementations of mixed precision training:
1. torch.cuda.amp
2. apex.amp
3. naive amp

The first two rely on the original implementation of [PyTorch](https://pytorch.org/docs/stable/amp.html)
(version 1.6 and above) and [Nvidia Apex](https://github.com/NVIDIA/apex). The last method is similar to Apex O2 level.

Among these methods, apex.amp is not compatible with tensor parallelism. This is because tensors are split across devices
in tensor parallelism, thus it is required to communicate among different processes to check if `inf` or `nan` occurs in the
whole model weights. **We modified the torch amp implementation so that it is compatible with tensor parallelism now.**

To use mixed precision training, you can easily specify the `fp16` field in the config file to be True. Currently, PyTorch and
Apex amp cannot be guaranteed to work with tensor and pipeline parallelism. We recommend you use torch amp as it generally
gives better accuracy than naive amp.

The AMP module is designed to be completely modular and can be used independently from other colossalai modules.
If you wish to only use amp in your code base without `colossalai.initialize`, you can use `colossalai.amp.convert_to_amp`.

```python
import colossalai
from colossalai.amp import AMP_TYPE

# example of using torch amp
model, optimizer, criterion = colossalai.amp.convert_to_amp(model,
                                                            optimizer,
                                                            criterion,
                                                            AMP_TYPE.TORCH)
```

## PyTorch AMP

@@ -21,7 +35,7 @@ PyTorch provides mixed precision training in version 1.6 and above. It provides

while keeping some operations such as reductions in `fp32`. You can configure the gradient scaler in the config file.

```python
from colossalai.amp import AMP_TYPE

fp16 = dict(
    mode=AMP_TYPE.TORCH,
```

@@ -43,7 +57,7 @@ will keep batch normalization in `fp32`.

The following code block shows a config file for Apex AMP.

```python
from colossalai.amp import AMP_TYPE

fp16 = dict(
    mode=AMP_TYPE.APEX,
```

@@ -71,10 +85,10 @@ and pipeline parallelism.

The following code block shows a config file for this mode.

```python
from colossalai.amp import AMP_TYPE

fp16 = dict(
    mode=AMP_TYPE.NAIVE,
    # below are the default values
    clip_grad=0,
    log_num_zeros_in_grad=False,
```
docs/config.md
@@ -3,185 +3,43 @@

Here is a config file example showing how to train a ViT model on the CIFAR10 dataset using Colossal-AI:

```python
from colossalai.amp import AMP_TYPE

# optional
# three keys: pipeline, tensor
# data parallel size is inferred
parallel = dict(
    pipeline=dict(size=1),
    tensor=dict(size=4, mode='2d'),
)

# optional
# pipeline or no pipeline schedule
fp16 = dict(
    mode=AMP_TYPE.NAIVE,
    initial_scale=2 ** 8
)

# optional
# if you are using complex gradient handling
# otherwise, you do not need this in your config file
# default gradient_handlers = None
gradient_handlers = [dict(type='MyHandler', arg1=1, arg2=2), ...]

# optional
# specific gradient accumulation size
# if your batch size is not large enough
gradient_accumulation = <int>

# optional
# add gradient clipping to your engine
# this config is not compatible with zero and AMP_TYPE.NAIVE
# but works with AMP_TYPE.TORCH and AMP_TYPE.APEX
# default clip_grad_norm = 0.0
clip_grad_norm = <float>

# optional
# cudnn setting
# default is like below
cudnn_benchmark = False
cudnn_deterministic = True
```
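For orientation, here is a brief sketch of how a few of these optional fields might be combined in one config file (the concrete values are illustrative assumptions):

```python
from colossalai.amp import AMP_TYPE

# parallel setting: 4-way 2D tensor parallelism, no pipeline
parallel = dict(
    pipeline=dict(size=1),
    tensor=dict(size=4, mode='2d'),
)

# torch amp, which is compatible with clip_grad_norm
fp16 = dict(
    mode=AMP_TYPE.TORCH
)

# accumulate gradients over 4 steps and clip their norm to 1.0
gradient_accumulation = 4
clip_grad_norm = 1.0
```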
@@ -20,6 +20,10 @@ pip install .

Install and enable CUDA kernel fusion (compulsory installation when using fused optimizer)

```shell
pip install -v --no-cache-dir --global-option="--cuda_ext" .

# install with editable enabled
pip install -v --no-cache-dir --global-option="--cuda_ext" -e .
```
@@ -4,17 +4,51 @@

We support multiple parallelization in our library.

Hybrid parallelism in our codebase refers to the combination of data parallelism, pipeline parallelism
and tensor parallelism (1D, 2D, 2.5D, 3D). Each parallelism requires a different network topology and thus
a different initializer for its distributed process group. You can initialize the corresponding process group by
setting `parallel` in our config. The parallel configuration can be easily deployed by a dictionary in
the configuration file. The configuration dictionary must obey the following format. Data parallel size will be
inferred automatically based on your inputs to pipeline parallelism and tensor parallelism. The distributed
environment will be set up by `colossalai.launch`.

```python
# general format
parallel = dict(
    pipeline=dict("size": int),
    tensor=dict("size": int, "mode": '1d' or '2d' or '2.5d' or '3d', "kwargs": Any)
)

# this is ok
parallel = dict(
    pipeline=dict(size=2),
    tensor=dict(size=4, mode='2d')
)

# this is ok
parallel = dict(
    pipeline=2,
    tensor=dict(size=4, mode='2d')
)

# this is not ok
# as you need to specify the mode for tensor parallelism
parallel = dict(
    pipeline=2,
    tensor=4
)

# this is ok as well, tensor will default to size 1 and mode None
parallel = dict(
    pipeline=2
)

# this is ok as well, pipeline will default to size 1
parallel = dict(
    tensor=dict(size=4, mode='2d')
)
```

The name of the dictionary variable should be **parallel**. All the arguments, even **parallel** itself, are optional and
@@ -22,20 +56,26 @@ data, pipeline, tensor parallel size will be set to defaulted value 1. The value

int representing the size of specific parallel dimension or a dictionary with a key called "size". The key "mode"
represents the way of tensor parallelism.

**You can choose to not have 'parallel' in your configuration, and both pipeline and tensor will default to size 1.**

## Data Parallel

Data parallel is the most common way to distribute your training task by splitting data into several shards and train on
a single shard on each device. The configuration for data parallel is detected automatically and set for you. You do not
have to explicitly set them in your configurations. There are two ways to handle the all-reduce in data parallel in Colossal-AI.

1. If you specify gradient handlers, gradients will be all-reduced according to the gradient handlers
2. Otherwise, PyTorch DistributedDataParallel will be used

In most cases, you will be using the second mode unless you have complex handling of the gradients.
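If you do take the first route, the handler is enabled through the config file. A brief sketch, where `MyHandler` stands for your own registered gradient handler:

```python
# a sketch; 'MyHandler' is a placeholder for your own registered gradient handler
gradient_handlers = [
    dict(type='MyHandler', arg1=1, arg2=2),
]
```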

## 1D, 2D, 2.5D and 3D Parallel

To enable hybrid parallelism, we provide an array of tensor parallelism methods. We provide the list of papers which match each
tensor parallel method. These parallel modes need to work with the distributed layers provided by Colossal-AI.

- 1D: [Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism](https://arxiv.org/abs/1909.08053)

- 2D: [An Efficient 2D Method for Training Super-Large Deep Learning Models](https://arxiv.org/abs/2104.05343)
2D parallel relies on the SUMMA matrix multiplication algorithm and splits the input data, model weights and layer
@@ -55,158 +95,134 @@ tensor parallel method. These parallel modes need to work with the distributed l

```python
# 1D parallel
parallel = dict(
    tensor=dict(size=4, mode='1d')
)

# 2D parallel
parallel = dict(
    tensor=dict(size=4, mode='2d')
)

# 2.5D parallel
parallel = dict(
    tensor=dict(size=8, mode='2.5d', depth=2)
)

# 3D parallel
parallel = dict(
    tensor=dict(size=8, mode='3d')
)
```

Once you specify the tensor parallel mode in your configuration, you can proceed to use its corresponding distributed
operator. For example, if your mode is '2d', you can use `colossalai.nn.Linear2D` in your model construction.
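For instance, a small block built from the 2D tensor-parallel linear layer might look like the following sketch (the constructor arguments are assumed to mirror `torch.nn.Linear`; check the layer API for the exact signature):

```python
import torch.nn as nn
from colossalai import nn as col_nn


class MLP2D(nn.Module):
    """A toy two-layer MLP using the 2D tensor-parallel linear layer."""

    def __init__(self, dim=256, hidden_dim=1024):
        super().__init__()
        self.dense_1 = col_nn.Linear2D(dim, hidden_dim)
        self.activation = nn.GELU()
        self.dense_2 = col_nn.Linear2D(hidden_dim, dim)

    def forward(self, x):
        return self.dense_2(self.activation(self.dense_1(x)))
```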

## Pipeline Parallel (experimental)

Pipeline parallelism is to split the model into several partitions by layer. For example, let's assume we have a simple
model which consists of two linear layers. We have two GPUs, and we can allocate the first linear layer to the first GPU
and the second layer to the second GPU.

You can set the number of pipeline stages in your configuration file. When pipeline size is larger than 1, Colossal-AI
will automatically create the pipeline schedule which defines the forward and backward step.

```python
parallel = dict(
    pipeline=dict(size=4), # number of pipeline stages
    tensor=dict(size=1, mode=None)
)
```

As PyTorch is based on a dynamic computation graph, the computation flow is not known until execution. To support pipeline parallelism, you have the following two ways to split your model:

1. Split your model directly. Below is an example of resnet split into two pipeline stages.

```python
import torch
import torch.nn as nn
from torchvision.models import resnet18
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc

model = resnet18(num_classes=10)

if gpc.get_local_rank(ParallelMode.PIPELINE) == 0:
    model = nn.Sequential(
        model.conv1,
        model.bn1,
        model.relu,
        model.maxpool,
        model.layer1,
        model.layer2
    )
elif gpc.get_local_rank(ParallelMode.PIPELINE) == 1:
    class Flatten(nn.Module):

        def forward(self, x):
            return torch.flatten(x, 1)

    model = nn.Sequential(
        model.layer3,
        model.layer4,
        model.avgpool,
        Flatten(),
        model.fc
    )
```

2. Make sure your model inherits `colossalai.nn.model.ModelFromConfig` and is registered in the
`MODELS` registry. Define the `self.layers_cfg` attribute.
Pass in a dict/Config object which specifies the parameters of your model.
Use `colossalai.builder.pipeline.PipelineModelInitializer` to partition the layers.

```python
from colossalai.builder import PipelineModelInitializer
from colossalai.nn.model import ModelFromConfig
from colossalai.registry import MODELS


@MODELS.register_module
class MyModel(ModelFromConfig):

    def __init__(self, arg1, arg2, ...):
        ...
        self.layers_cfg = [
            dict(type='Linear', in_features=3, out_features=512),
            dict(type='Linear', in_features=512, out_features=512),
            ...
        ]


model_cfg = dict(
    type='MyModel',
    arg1=1,
    arg2=2,
    ...
)

initializer = PipelineModelInitializer(model_cfg, num_chunks=1)
model = initializer.initialize()
```

When your model is split into partitions, you can use PipelineSchedule to execute training.

```python
import colossalai
from colossalai.engine.schedule import PipelineSchedule

engine, train_dataloader, _, _ = colossalai.initialize(model, optimizer, criterion, train_dataloader)

schedule = PipelineSchedule(num_microbatches=4)

# execute a training epoch
data_iter = iter(train_dataloader)

for i in range(len(train_dataloader)):
    output, label, loss = schedule.forward_backward_step(engine,
                                                         data_iter,
                                                         forward_only=False,
                                                         )
```

This feature is still in development and is only experimental for now.
@@ -1,6 +1,6 @@

# Colossal-AI Engine & Customize Your Trainer

## Colossal-AI engine

To better understand how the `Engine` class works, let's start from the conception of the process function in common
engines. The process function usually controls the behavior over a batch of a dataset, `Engine` class just controls the

@@ -16,15 +16,7 @@ def process_function(dataloader, model, criterion, optim):

```python
    optim.step()
```

The engine class is a high-level wrapper of these frequently-used functions while preserving the PyTorch-like function signature and integrating with our features.

```python
import torch
import torch.nn as nn
import torchvision.models as models
import colossalai
from colossalai.engine import Engine
from torchvision.datasets import CIFAR10

model = models.resnet18()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())

dataset = CIFAR10(...)
dataloader = colossalai.utils.get_dataloader(dataset)

engine, dataloader, _, _ = colossalai.initialize(model, optimizer, criterion, dataloader)

# example of a training iteration
for img, label in dataloader:
    engine.zero_grad()
    output = engine(img)
    loss = engine.criterion(output, label)
    engine.backward(loss)
    engine.step()
```

More information regarding the class can be found in the API references.

@@ -54,14 +53,14 @@ More information regarding the class can be found in the API references.

To learn how to customize a trainer which meets your needs, let's first take a look at the `Trainer` class. We highly
recommend that you read *Get Started*
section and *Colossal-AI engine* first.

The `Trainer` class enables researchers and engineers to use our system more conveniently. Instead of having to write
your own scripts, you can simply construct your own trainer by calling the `Trainer` class, just like what we did in the
following code block.

```python
trainer = Trainer(engine)
```

After that, you can use the `fit` method to train or evaluate your model. In order to make our `Trainer` class even more

@@ -71,26 +70,55 @@ class allows you to execute your hook functions at specified time. We have alrea

as listed below. What you need to do is just to pick the right ones which suit your needs. Detailed descriptions of the
class can be found in the API references.

These hook functions will record metrics, elapsed time and memory usage and write them to log after each epoch. Besides,
they print the current loss and accuracy to let users monitor the performance of the model.

```python
import colossalai
from colossalai.trainer import hooks, Trainer
from colossalai.utils import MultiTimer
from colossalai.logging import get_dist_logger

... = colossalai.initialize(...)

timer = MultiTimer()
logger = get_dist_logger()

# if you want to save log to file
logger.log_to_file('./logs/')

trainer = Trainer(
    engine=engine,
    timer=timer,
    logger=logger
)

hook_list = [
    hooks.LossHook(),
    hooks.LRSchedulerHook(lr_scheduler=lr_scheduler, by_epoch=False),
    hooks.AccuracyHook(),
    hooks.TensorboardHook(log_dir='./tb_logs', ranks=[0]),
    hooks.LogMetricByEpochHook(logger),
    hooks.LogMemoryByEpochHook(logger),
    hooks.LogTimingByEpochHook(timer, logger),
    hooks.SaveCheckpointHook(checkpoint_dir='./ckpt')
]

trainer.fit(
    train_dataloader=train_dataloader,
    epochs=NUM_EPOCHS,
    test_dataloader=test_dataloader,
    test_interval=1,
    hooks=hook_list,
    display_progress=True
)
```

### Hook

If you have your specific needs, feel free to extend our `BaseHook` class to add your own functions, or our `MetricHook`
class to write a metric collector. These hook functions can be called at different stages in the trainer's life cycle.
Besides, you can define the priorities of all hooks to arrange their execution order. More information can be
found in the API references.
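As a rough illustration only (the exact `BaseHook` constructor and callback signatures should be checked against the API references), a custom hook might look like this:

```python
from colossalai.logging import get_dist_logger
from colossalai.trainer import hooks


class LogEpochEndHook(hooks.BaseHook):
    """A toy hook that logs a message when a training epoch ends.

    The priority argument and the after_train_epoch callback are assumed here;
    consult the API references for the exact signatures.
    """

    def __init__(self, priority=10):
        super().__init__(priority=priority)
        self._logger = get_dist_logger()

    def after_train_epoch(self, trainer):
        self._logger.info('training epoch finished')
```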
docs/zero.md
@@ -1,7 +1,7 @@

# Zero Redundancy optimizer and zero offload

The Zero Redundancy Optimizer (ZeRO) removes the memory redundancies across data-parallel processes by partitioning three
model states (optimizer states, gradients, and parameters) instead of replicating them.
By doing so, memory efficiency is boosted drastically compared to classic data parallelism while the computational granularity
and communication efficiency are retained.

@@ -14,30 +14,26 @@ partition them during the forward and backward passes.

## Getting Started with ZeRO

If you are training models with Colossal-AI, enabling ZeRO DP and Offloading is easy by adding several lines in your configuration file. We support configuration for level 2 and 3. You have to use the [PyTorch native implementation](https://pytorch.org/tutorials/recipes/zero_redundancy_optimizer.html) for the level 1 optimizer.
Below are a few examples of ZeRO-3 configurations.

### Example of ZeRO-3 Configurations

Here we use `Adam` as the initial optimizer.

1. Use ZeRO to partition the optimizer states, gradients (level 2), and parameters (level 3).
```python
zero = dict(
    level=3,
    dynamic_loss_scale=True,
    clip_grad=1.0
)
```

2. Additionally offload the optimizer states and computations to the CPU.
```python
zero = dict(
    level=3,
    offload_optimizer_config=dict(
        device='cpu',
        pin_memory=True,
```

@@ -49,6 +45,7 @@ Here we use `Adam` as the initial optimizer.

3. Save even more memory by offloading parameters to the CPU memory.
```python
zero = dict(
    level=3,
    offload_optimizer_config=dict(
        device='cpu',
        pin_memory=True,
```

@@ -65,6 +62,7 @@ Here we use `Adam` as the initial optimizer.

4. Save even MORE memory by offloading to NVMe (if available on your system):
```python
zero = dict(
    level=3,
    offload_optimizer_config=dict(
        device='nvme',
        pin_memory=True,
```

@@ -81,7 +79,7 @@ Here we use `Adam` as the initial optimizer.

Note that `fp16` is automatically enabled when using ZeRO. This relies on `AMP_TYPE.NAIVE` in the Colossal-AI AMP module.
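The examples above all use level 3; a level 2 configuration presumably follows the same pattern (a sketch, assuming the same keys apply):

```python
# a sketch of a minimal level-2 configuration; fp16 is enabled automatically,
# so no separate fp16 block is required
zero = dict(
    level=2,
    dynamic_loss_scale=True,
    clip_grad=1.0
)
```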

### Training