mirror of https://github.com/hpcaitech/ColossalAI
[doc] update booster tutorials (#3718)
* [booster] update booster tutorials#3717 * [booster] update booster tutorials#3717, fix * [booster] update booster tutorials#3717, update setup doc * [booster] update booster tutorials#3717, update setup doc * [booster] update booster tutorials#3717, update setup doc * [booster] update booster tutorials#3717, update setup doc * [booster] update booster tutorials#3717, update setup doc * [booster] update booster tutorials#3717, update setup doc * [booster] update booster tutorials#3717, rename colossalai booster.md * [booster] update booster tutorials#3717, rename colossalai booster.md * [booster] update booster tutorials#3717, rename colossalai booster.md * [booster] update booster tutorials#3717, fix * [booster] update booster tutorials#3717, fix * [booster] update tutorials#3717, update booster api doc * [booster] update tutorials#3717, modify file * [booster] update tutorials#3717, modify file * [booster] update tutorials#3717, modify file * [booster] update tutorials#3717, modify file * [booster] update tutorials#3717, modify file * [booster] update tutorials#3717, modify file * [booster] update tutorials#3717, modify file * [booster] update tutorials#3717, fix reference link * [booster] update tutorials#3717, fix reference link * [booster] update tutorials#3717, fix reference link * [booster] update tutorials#3717, fix reference link * [booster] update tutorials#3717, fix reference link * [booster] update tutorials#3717, fix reference link * [booster] update tutorials#3717, fix reference link * [booster] update tutorials#3713 * [booster] update tutorials#3713, modify filepull/3772/head
parent
05759839bd
commit
d449525acf
|
@ -32,7 +32,8 @@
|
||||||
"basics/engine_trainer",
|
"basics/engine_trainer",
|
||||||
"basics/configure_parallelization",
|
"basics/configure_parallelization",
|
||||||
"basics/model_checkpoint",
|
"basics/model_checkpoint",
|
||||||
"basics/colotensor_concept"
|
"basics/colotensor_concept",
|
||||||
|
"basics/booster_api"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
|
|
@ -0,0 +1,89 @@
|
||||||
|
# Booster API
|
||||||
|
Author: [Mingyan Jiang](https://github.com/jiangmingyan)
|
||||||
|
|
||||||
|
**Prerequisite:**
|
||||||
|
- [Distributed Training](../concepts/distributed_training.md)
|
||||||
|
- [Colossal-AI Overview](../concepts/colossalai_overview.md)
|
||||||
|
|
||||||
|
**Example Code**
|
||||||
|
- [Train with Booster](https://github.com/hpcaitech/ColossalAI/blob/main/examples/tutorial/new_api/cifar_resnet/README.md)
|
||||||
|
|
||||||
|
## Introduction
|
||||||
|
In our new design, `colossalai.booster` replaces the role of `colossalai.initialize` to inject features into your training components (e.g. model, optimizer, dataloader) seamlessly. With these new APIs, you can integrate your model with our parallelism features more friendly. Also calling `colossalai.booster` is the standard procedure before you run into your training loops. In the sections below, I will cover how `colossalai.booster` works and what we should take note of.
|
||||||
|
|
||||||
|
### Plugin
|
||||||
|
Plugin is an important component that manages parallel configuration (eg: The gemini plugin encapsulates the gemini acceleration solution). Currently supported plugins are as follows:
|
||||||
|
|
||||||
|
***GeminiPlugin:*** This plugin wrapps the Gemini acceleration solution, that ZeRO with chunk-based memory management.
|
||||||
|
|
||||||
|
***TorchDDPPlugin:*** This plugin wrapps the DDP acceleration solution, it implements data parallelism at the module level which can run across multiple machines.
|
||||||
|
|
||||||
|
***LowLevelZeroPlugin:*** This plugin wraps the 1/2 stage of Zero Redundancy Optimizer. Stage 1 : Shards optimizer states across data parallel workers/GPUs. Stage 2 : Shards optimizer states + gradients across data parallel workers/GPUs.
|
||||||
|
|
||||||
|
### API of booster
|
||||||
|
|
||||||
|
|
||||||
|
{{ autodoc:colossalai.booster.Booster }}
|
||||||
|
|
||||||
|
{{ autodoc:colossalai.booster.Booster.boost }}
|
||||||
|
|
||||||
|
{{ autodoc:colossalai.booster.Booster.backward }}
|
||||||
|
|
||||||
|
{{ autodoc:colossalai.booster.Booster.no_sync }}
|
||||||
|
|
||||||
|
{{ autodoc:colossalai.booster.Booster.save_model }}
|
||||||
|
|
||||||
|
{{ autodoc:colossalai.booster.Booster.load_model }}
|
||||||
|
|
||||||
|
{{ autodoc:colossalai.booster.Booster.save_optimizer }}
|
||||||
|
|
||||||
|
{{ autodoc:colossalai.booster.Booster.load_optimizer }}
|
||||||
|
|
||||||
|
{{ autodoc:colossalai.booster.Booster.save_lr_scheduler }}
|
||||||
|
|
||||||
|
{{ autodoc:colossalai.booster.Booster.load_lr_scheduler }}
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
In a typical workflow, you should launch distributed environment at the beginning of training script and create objects needed (such as models, optimizers, loss function, data loaders etc.) firstly, then call `colossalai.booster` to inject features into these objects, After that, you can use our booster APIs and these returned objects to continue the rest of your training processes.
|
||||||
|
|
||||||
|
A pseudo-code example is like below:
|
||||||
|
|
||||||
|
```python
|
||||||
|
import torch
|
||||||
|
from torch.optim import SGD
|
||||||
|
from torchvision.models import resnet18
|
||||||
|
|
||||||
|
import colossalai
|
||||||
|
from colossalai.booster import Booster
|
||||||
|
from colossalai.booster.plugin import TorchDDPPlugin
|
||||||
|
|
||||||
|
def train():
|
||||||
|
colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
|
||||||
|
plugin = TorchDDPPlugin()
|
||||||
|
booster = Booster(plugin=plugin)
|
||||||
|
model = resnet18()
|
||||||
|
criterion = lambda x: x.mean()
|
||||||
|
optimizer = SGD((model.parameters()), lr=0.001)
|
||||||
|
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)
|
||||||
|
model, optimizer, criterion, _, scheduler = booster.boost(model, optimizer, criterion, lr_scheduler=scheduler)
|
||||||
|
|
||||||
|
x = torch.randn(4, 3, 224, 224)
|
||||||
|
x = x.to('cuda')
|
||||||
|
output = model(x)
|
||||||
|
loss = criterion(output)
|
||||||
|
booster.backward(loss, optimizer)
|
||||||
|
optimizer.clip_grad_by_norm(1.0)
|
||||||
|
optimizer.step()
|
||||||
|
scheduler.step()
|
||||||
|
|
||||||
|
save_path = "./model"
|
||||||
|
booster.save_model(model, save_path, True, True, "", 10, use_safetensors=use_safetensors)
|
||||||
|
|
||||||
|
new_model = resnet18()
|
||||||
|
booster.load_model(new_model, save_path)
|
||||||
|
```
|
||||||
|
|
||||||
|
[more design details](https://github.com/hpcaitech/ColossalAI/discussions/3046)
|
||||||
|
|
||||||
|
|
||||||
|
<!-- doc-test-command: torchrun --standalone --nproc_per_node=1 booster_api.py -->
|
|
@ -87,14 +87,13 @@ import colossalai
|
||||||
args = colossalai.get_default_parser().parse_args()
|
args = colossalai.get_default_parser().parse_args()
|
||||||
|
|
||||||
# launch distributed environment
|
# launch distributed environment
|
||||||
colossalai.launch(config=<CONFIG>,
|
colossalai.launch(config=args.config,
|
||||||
rank=args.rank,
|
rank=args.rank,
|
||||||
world_size=args.world_size,
|
world_size=args.world_size,
|
||||||
host=args.host,
|
host=args.host,
|
||||||
port=args.port,
|
port=args.port,
|
||||||
backend=args.backend
|
backend=args.backend
|
||||||
)
|
)
|
||||||
|
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
||||||
|
@ -107,12 +106,21 @@ First, we need to set the launch method in our code. As this is a wrapper of the
|
||||||
use `colossalai.launch_from_torch`. The arguments required for distributed environment such as rank, world size, host and port are all set by the PyTorch
|
use `colossalai.launch_from_torch`. The arguments required for distributed environment such as rank, world size, host and port are all set by the PyTorch
|
||||||
launcher and can be read from the environment variable directly.
|
launcher and can be read from the environment variable directly.
|
||||||
|
|
||||||
|
config.py
|
||||||
|
```python
|
||||||
|
BATCH_SIZE = 512
|
||||||
|
LEARNING_RATE = 3e-3
|
||||||
|
WEIGHT_DECAY = 0.3
|
||||||
|
NUM_EPOCHS = 2
|
||||||
|
```
|
||||||
|
train.py
|
||||||
```python
|
```python
|
||||||
import colossalai
|
import colossalai
|
||||||
|
|
||||||
colossalai.launch_from_torch(
|
colossalai.launch_from_torch(
|
||||||
config=<CONFIG>,
|
config="./config.py",
|
||||||
)
|
)
|
||||||
|
...
|
||||||
```
|
```
|
||||||
|
|
||||||
Next, we can easily start multiple processes with `colossalai run` in your terminal. Below is an example to run the code
|
Next, we can easily start multiple processes with `colossalai run` in your terminal. Below is an example to run the code
|
||||||
|
|
|
@ -29,7 +29,7 @@ CUDA_EXT=1 pip install colossalai
|
||||||
|
|
||||||
## Download From Source
|
## Download From Source
|
||||||
|
|
||||||
> The version of Colossal-AI will be in line with the main branch of the repository. Feel free to raise an issue if you encounter any problem. :)
|
> The version of Colossal-AI will be in line with the main branch of the repository. Feel free to raise an issue if you encounter any problem.
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
git clone https://github.com/hpcaitech/ColossalAI.git
|
git clone https://github.com/hpcaitech/ColossalAI.git
|
||||||
|
@ -39,13 +39,13 @@ cd ColossalAI
|
||||||
pip install -r requirements/requirements.txt
|
pip install -r requirements/requirements.txt
|
||||||
|
|
||||||
# install colossalai
|
# install colossalai
|
||||||
pip install .
|
CUDA_EXT=1 pip install .
|
||||||
```
|
```
|
||||||
|
|
||||||
If you don't want to install and enable CUDA kernel fusion (compulsory installation when using fused optimizer):
|
If you don't want to install and enable CUDA kernel fusion (compulsory installation when using fused optimizer), just don't specify the `CUDA_EXT`:
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
CUDA_EXT=1 pip install .
|
pip install .
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,89 @@
|
||||||
|
# booster 使用
|
||||||
|
作者: [Mingyan Jiang](https://github.com/jiangmingyan)
|
||||||
|
|
||||||
|
**预备知识:**
|
||||||
|
- [分布式训练](../concepts/distributed_training.md)
|
||||||
|
- [Colossal-AI 总览](../concepts/colossalai_overview.md)
|
||||||
|
|
||||||
|
**示例代码**
|
||||||
|
- [使用booster训练](https://github.com/hpcaitech/ColossalAI/blob/main/examples/tutorial/new_api/cifar_resnet/README.md)
|
||||||
|
|
||||||
|
## 简介
|
||||||
|
在我们的新设计中, `colossalai.booster` 代替 `colossalai.initialize` 将特征(例如,模型、优化器、数据加载器)无缝注入您的训练组件中。 使用booster API, 您可以更友好地将我们的并行策略整合到待训练模型中. 调用 `colossalai.booster` 是您进入训练循环前的基本操作。
|
||||||
|
在下面的章节中,我们将介绍 `colossalai.booster` 是如何工作的以及使用时我们要注意的细节。
|
||||||
|
|
||||||
|
### Booster插件
|
||||||
|
Booster插件是管理并行配置的重要组件(eg:gemini插件封装了gemini加速方案)。目前支持的插件如下:
|
||||||
|
|
||||||
|
***GeminiPlugin:*** GeminiPlugin插件封装了 gemini 加速解决方案,即基于块内存管理的 ZeRO优化方案。
|
||||||
|
|
||||||
|
***TorchDDPPlugin:*** TorchDDPPlugin插件封装了DDP加速方案,实现了模型级别的数据并行,可以跨多机运行。
|
||||||
|
|
||||||
|
***LowLevelZeroPlugin:*** LowLevelZeroPlugin插件封装了零冗余优化器的 1/2 阶段。阶段 1:切分优化器参数,分发到各并发进程或并发GPU上。阶段 2:切分优化器参数及梯度,分发到各并发进程或并发GPU上。
|
||||||
|
|
||||||
|
### Booster接口
|
||||||
|
|
||||||
|
{{ autodoc:colossalai.booster.Booster }}
|
||||||
|
|
||||||
|
{{ autodoc:colossalai.booster.Booster.boost }}
|
||||||
|
|
||||||
|
{{ autodoc:colossalai.booster.Booster.backward }}
|
||||||
|
|
||||||
|
{{ autodoc:colossalai.booster.Booster.no_sync }}
|
||||||
|
|
||||||
|
{{ autodoc:colossalai.booster.Booster.save_model }}
|
||||||
|
|
||||||
|
{{ autodoc:colossalai.booster.Booster.load_model }}
|
||||||
|
|
||||||
|
{{ autodoc:colossalai.booster.Booster.save_optimizer }}
|
||||||
|
|
||||||
|
{{ autodoc:colossalai.booster.Booster.load_optimizer }}
|
||||||
|
|
||||||
|
{{ autodoc:colossalai.booster.Booster.save_lr_scheduler }}
|
||||||
|
|
||||||
|
{{ autodoc:colossalai.booster.Booster.load_lr_scheduler }}
|
||||||
|
|
||||||
|
## 使用方法及示例
|
||||||
|
|
||||||
|
在使用colossalai训练时,首先需要在训练脚本的开头启动分布式环境,并创建需要使用的模型、优化器、损失函数、数据加载器等对象。之后,调用`colossalai.booster` 将特征注入到这些对象中,您就可以使用我们的booster API去进行您接下来的训练流程。
|
||||||
|
|
||||||
|
以下是一个伪代码示例,将展示如何使用我们的booster API进行模型训练:
|
||||||
|
|
||||||
|
```python
|
||||||
|
import torch
|
||||||
|
from torch.optim import SGD
|
||||||
|
from torchvision.models import resnet18
|
||||||
|
|
||||||
|
import colossalai
|
||||||
|
from colossalai.booster import Booster
|
||||||
|
from colossalai.booster.plugin import TorchDDPPlugin
|
||||||
|
|
||||||
|
def train():
|
||||||
|
colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
|
||||||
|
plugin = TorchDDPPlugin()
|
||||||
|
booster = Booster(plugin=plugin)
|
||||||
|
model = resnet18()
|
||||||
|
criterion = lambda x: x.mean()
|
||||||
|
optimizer = SGD((model.parameters()), lr=0.001)
|
||||||
|
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)
|
||||||
|
model, optimizer, criterion, _, scheduler = booster.boost(model, optimizer, criterion, lr_scheduler=scheduler)
|
||||||
|
|
||||||
|
x = torch.randn(4, 3, 224, 224)
|
||||||
|
x = x.to('cuda')
|
||||||
|
output = model(x)
|
||||||
|
loss = criterion(output)
|
||||||
|
booster.backward(loss, optimizer)
|
||||||
|
optimizer.clip_grad_by_norm(1.0)
|
||||||
|
optimizer.step()
|
||||||
|
scheduler.step()
|
||||||
|
|
||||||
|
save_path = "./model"
|
||||||
|
booster.save_model(model, save_path, True, True, "", 10, use_safetensors=use_safetensors)
|
||||||
|
|
||||||
|
new_model = resnet18()
|
||||||
|
booster.load_model(new_model, save_path)
|
||||||
|
```
|
||||||
|
|
||||||
|
[更多的设计细节请参考](https://github.com/hpcaitech/ColossalAI/discussions/3046)
|
||||||
|
|
||||||
|
<!-- doc-test-command: torchrun --standalone --nproc_per_node=1 booster_api.py -->
|
|
@ -74,7 +74,7 @@ import colossalai
|
||||||
args = colossalai.get_default_parser().parse_args()
|
args = colossalai.get_default_parser().parse_args()
|
||||||
|
|
||||||
# launch distributed environment
|
# launch distributed environment
|
||||||
colossalai.launch(config=<CONFIG>,
|
colossalai.launch(config=args.config,
|
||||||
rank=args.rank,
|
rank=args.rank,
|
||||||
world_size=args.world_size,
|
world_size=args.world_size,
|
||||||
host=args.host,
|
host=args.host,
|
||||||
|
@ -93,12 +93,21 @@ PyTorch自带的启动器需要在每个节点上都启动命令才能启动多
|
||||||
首先,我们需要在代码里指定我们的启动方式。由于这个启动器是PyTorch启动器的封装,那么我们自然而然应该使用`colossalai.launch_from_torch`。
|
首先,我们需要在代码里指定我们的启动方式。由于这个启动器是PyTorch启动器的封装,那么我们自然而然应该使用`colossalai.launch_from_torch`。
|
||||||
分布式环境所需的参数,如 rank, world size, host 和 port 都是由 PyTorch 启动器设置的,可以直接从环境变量中读取。
|
分布式环境所需的参数,如 rank, world size, host 和 port 都是由 PyTorch 启动器设置的,可以直接从环境变量中读取。
|
||||||
|
|
||||||
|
config.py
|
||||||
|
```python
|
||||||
|
BATCH_SIZE = 512
|
||||||
|
LEARNING_RATE = 3e-3
|
||||||
|
WEIGHT_DECAY = 0.3
|
||||||
|
NUM_EPOCHS = 2
|
||||||
|
```
|
||||||
|
train.py
|
||||||
```python
|
```python
|
||||||
import colossalai
|
import colossalai
|
||||||
|
|
||||||
colossalai.launch_from_torch(
|
colossalai.launch_from_torch(
|
||||||
config=<CONFIG>,
|
config="./config.py",
|
||||||
)
|
)
|
||||||
|
...
|
||||||
```
|
```
|
||||||
|
|
||||||
接下来,我们可以轻松地在终端使用`colossalai run`来启动训练。下面的命令可以在当前机器上启动一个4卡的训练任务。
|
接下来,我们可以轻松地在终端使用`colossalai run`来启动训练。下面的命令可以在当前机器上启动一个4卡的训练任务。
|
||||||
|
|
|
@ -28,7 +28,7 @@ CUDA_EXT=1 pip install colossalai
|
||||||
|
|
||||||
## 从源安装
|
## 从源安装
|
||||||
|
|
||||||
> 此文档将与版本库的主分支保持一致。如果您遇到任何问题,欢迎给我们提 issue :)
|
> 此文档将与版本库的主分支保持一致。如果您遇到任何问题,欢迎给我们提 issue。
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
git clone https://github.com/hpcaitech/ColossalAI.git
|
git clone https://github.com/hpcaitech/ColossalAI.git
|
||||||
|
@ -38,13 +38,13 @@ cd ColossalAI
|
||||||
pip install -r requirements/requirements.txt
|
pip install -r requirements/requirements.txt
|
||||||
|
|
||||||
# install colossalai
|
# install colossalai
|
||||||
|
CUDA_EXT=1 pip install .
|
||||||
|
```
|
||||||
|
|
||||||
|
如果您不想安装和启用 CUDA 内核融合(使用融合优化器时强制安装),您可以不添加`CUDA_EXT=1`:
|
||||||
|
|
||||||
|
```shell
|
||||||
pip install .
|
pip install .
|
||||||
```
|
```
|
||||||
|
|
||||||
如果您不想安装和启用 CUDA 内核融合(使用融合优化器时强制安装):
|
|
||||||
|
|
||||||
```shell
|
|
||||||
NO_CUDA_EXT=1 pip install .
|
|
||||||
```
|
|
||||||
|
|
||||||
<!-- doc-test-command: echo "installation.md does not need test" -->
|
<!-- doc-test-command: echo "installation.md does not need test" -->
|
||||||
|
|
Loading…
Reference in New Issue