[doc] update booster tutorials (#3718)

* [booster] update booster tutorials#3717

* [booster] update booster tutorials#3717, fix

* [booster] update booster tutorials#3717, update setup doc

* [booster] update booster tutorials#3717, rename colossalai booster.md

* [booster] update tutorials#3717, update booster api doc

* [booster] update tutorials#3717, modify file

* [booster] update tutorials#3717, fix reference link

* [booster] update tutorials#3713

* [booster] update tutorials#3713, modify file
jiangmingyan 2023-05-18 11:41:56 +08:00 committed by GitHub
parent 05759839bd
commit d449525acf
7 changed files with 213 additions and 17 deletions

View File

@ -32,7 +32,8 @@
"basics/engine_trainer",
"basics/configure_parallelization",
"basics/model_checkpoint",
"basics/colotensor_concept"
"basics/colotensor_concept",
"basics/booster_api"
]
},
{

View File

@ -0,0 +1,89 @@
# Booster API
Author: [Mingyan Jiang](https://github.com/jiangmingyan)
**Prerequisite:**
- [Distributed Training](../concepts/distributed_training.md)
- [Colossal-AI Overview](../concepts/colossalai_overview.md)
**Example Code**
- [Train with Booster](https://github.com/hpcaitech/ColossalAI/blob/main/examples/tutorial/new_api/cifar_resnet/README.md)
## Introduction
In our new design, `colossalai.booster` replaces the role of `colossalai.initialize` to seamlessly inject features into your training components (e.g. model, optimizer, dataloader). With these new APIs, you can integrate your model with our parallelism features more conveniently. Calling `colossalai.booster` is also the standard procedure before you enter your training loop. In the sections below, we will cover how `colossalai.booster` works and what we should take note of.
### Plugin
A plugin is an important component that manages a parallel configuration (e.g. the Gemini plugin encapsulates the Gemini acceleration solution). The currently supported plugins are as follows:
***GeminiPlugin:*** This plugin wraps the Gemini acceleration solution, i.e. ZeRO with chunk-based memory management.
***TorchDDPPlugin:*** This plugin wraps the DDP acceleration solution. It implements data parallelism at the module level and can run across multiple machines.
***LowLevelZeroPlugin:*** This plugin wraps stage 1/2 of the Zero Redundancy Optimizer. Stage 1: shards optimizer states across data-parallel workers/GPUs. Stage 2: shards optimizer states + gradients across data-parallel workers/GPUs.
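For illustration, a minimal sketch of picking a plugin and handing it to the `Booster` might look like the following (constructor arguments are left at their defaults; the `stage` argument of `LowLevelZeroPlugin` is shown only as an assumed example):
```python
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin

# pick exactly one plugin that matches your parallel strategy
plugin = TorchDDPPlugin()                # module-level data parallelism via PyTorch DDP
# plugin = GeminiPlugin()                # ZeRO with chunk-based memory management
# plugin = LowLevelZeroPlugin(stage=2)   # ZeRO stage 1/2 (the `stage` argument is an assumption)

booster = Booster(plugin=plugin)
```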
### API of booster
{{ autodoc:colossalai.booster.Booster }}
{{ autodoc:colossalai.booster.Booster.boost }}
{{ autodoc:colossalai.booster.Booster.backward }}
{{ autodoc:colossalai.booster.Booster.no_sync }}
{{ autodoc:colossalai.booster.Booster.save_model }}
{{ autodoc:colossalai.booster.Booster.load_model }}
{{ autodoc:colossalai.booster.Booster.save_optimizer }}
{{ autodoc:colossalai.booster.Booster.load_optimizer }}
{{ autodoc:colossalai.booster.Booster.save_lr_scheduler }}
{{ autodoc:colossalai.booster.Booster.load_lr_scheduler }}
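As an illustration of the `no_sync` entry above, the sketch below shows gradient accumulation; it assumes a plugin that supports synchronization control (e.g. `TorchDDPPlugin`), that `model` and `optimizer` were returned by `booster.boost`, and that `criterion`, `train_dataloader` and `accumulation_steps` are defined elsewhere:
```python
# a minimal sketch of gradient accumulation with booster.no_sync
for step, batch in enumerate(train_dataloader):
    if (step + 1) % accumulation_steps != 0:
        # intermediate micro-step: accumulate gradients without synchronization
        with booster.no_sync(model):
            loss = criterion(model(batch))
            booster.backward(loss, optimizer)
    else:
        # final micro-step: synchronize gradients and update the parameters
        loss = criterion(model(batch))
        booster.backward(loss, optimizer)
        optimizer.step()
        optimizer.zero_grad()
```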
## Usage
In a typical workflow, you should launch the distributed environment at the beginning of the training script and create the objects you need (such as models, optimizers, loss functions and data loaders) first, then call `colossalai.booster` to inject features into these objects. After that, you can use our booster APIs and the returned objects to run the rest of your training process.
A pseudo-code example is shown below:
```python
import torch
from torch.optim import SGD
from torchvision.models import resnet18
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin

def train(rank: int, world_size: int, port: int):
    # launch the distributed environment (rank, world_size and port are
    # passed in here so that the pseudo-code is self-contained)
    colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')

    # create the plugin and the booster
    plugin = TorchDDPPlugin()
    booster = Booster(plugin=plugin)

    # create the training components as usual
    model = resnet18().cuda()
    criterion = lambda x: x.mean()
    optimizer = SGD(model.parameters(), lr=0.001)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)

    # inject the plugin's features into the training components
    model, optimizer, criterion, _, scheduler = booster.boost(model, optimizer, criterion, lr_scheduler=scheduler)

    # run one training step
    x = torch.randn(4, 3, 224, 224).cuda()
    output = model(x)
    loss = criterion(output)
    booster.backward(loss, optimizer)
    optimizer.clip_grad_by_norm(1.0)
    optimizer.step()
    scheduler.step()

    # save a sharded checkpoint and load it back into a fresh model
    save_path = "./model"
    booster.save_model(model, save_path, shard=True, size_per_shard=10, use_safetensors=True)

    new_model = resnet18()
    booster.load_model(new_model, save_path)
```
More design details can be found in [this discussion](https://github.com/hpcaitech/ColossalAI/discussions/3046).
<!-- doc-test-command: torchrun --standalone --nproc_per_node=1 booster_api.py -->

View File

@ -87,14 +87,13 @@ import colossalai
args = colossalai.get_default_parser().parse_args()
# launch distributed environment
colossalai.launch(config=<CONFIG>,
colossalai.launch(config=args.config,
                  rank=args.rank,
                  world_size=args.world_size,
                  host=args.host,
                  port=args.port,
                  backend=args.backend
)
```
@ -107,12 +106,21 @@ First, we need to set the launch method in our code. As this is a wrapper of the
use `colossalai.launch_from_torch`. The arguments required for the distributed environment, such as rank, world size, host and port, are all set by the PyTorch
launcher and can be read directly from the environment variables.
config.py
```python
BATCH_SIZE = 512
LEARNING_RATE = 3e-3
WEIGHT_DECAY = 0.3
NUM_EPOCHS = 2
```
train.py
```python
import colossalai
colossalai.launch_from_torch(
    config=<CONFIG>,
    config="./config.py",
)
...
```
Next, we can easily start multiple processes with `colossalai run` in your terminal. Below is an example to run the code:
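(The exact command depends on your setup; the script name `train.py` and the GPU count here are placeholders.)
```shell
colossalai run --nproc_per_node 4 train.py
```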

View File

@ -29,7 +29,7 @@ CUDA_EXT=1 pip install colossalai
## Download From Source
> The version of Colossal-AI will be in line with the main branch of the repository. Feel free to raise an issue if you encounter any problem. :)
> The version of Colossal-AI will be in line with the main branch of the repository. Feel free to raise an issue if you encounter any problem.
```shell
git clone https://github.com/hpcaitech/ColossalAI.git
@ -39,13 +39,13 @@ cd ColossalAI
pip install -r requirements/requirements.txt
# install colossalai
pip install .
CUDA_EXT=1 pip install .
```
If you don't want to install and enable CUDA kernel fusion (compulsory installation when using fused optimizer):
If you don't want to install and enable CUDA kernel fusion (compulsory installation when using a fused optimizer), simply don't set `CUDA_EXT`:
```shell
CUDA_EXT=1 pip install .
pip install .
```

View File

@ -0,0 +1,89 @@
# Booster Usage
Author: [Mingyan Jiang](https://github.com/jiangmingyan)
**Prerequisite:**
- [Distributed Training](../concepts/distributed_training.md)
- [Colossal-AI Overview](../concepts/colossalai_overview.md)
**Example Code**
- [Train with Booster](https://github.com/hpcaitech/ColossalAI/blob/main/examples/tutorial/new_api/cifar_resnet/README.md)
## Introduction
In our new design, `colossalai.booster` replaces `colossalai.initialize` to seamlessly inject features into your training components (e.g. model, optimizer, dataloader). With the booster API, you can integrate our parallel strategies into the model to be trained more conveniently. Calling `colossalai.booster` is the basic step before you enter the training loop.
The sections below describe how `colossalai.booster` works and the details to pay attention to when using it.
### Booster Plugins
A booster plugin is an important component that manages the parallel configuration (e.g. the Gemini plugin encapsulates the Gemini acceleration solution). The currently supported plugins are as follows:
***GeminiPlugin:*** This plugin encapsulates the Gemini acceleration solution, i.e. ZeRO with chunk-based memory management.
***TorchDDPPlugin:*** This plugin encapsulates the DDP acceleration solution. It implements data parallelism at the module level and can run across multiple machines.
***LowLevelZeroPlugin:*** This plugin encapsulates stage 1/2 of the Zero Redundancy Optimizer. Stage 1: shards optimizer states across data-parallel processes/GPUs. Stage 2: shards optimizer states and gradients across data-parallel processes/GPUs.
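As a minimal sketch (constructor arguments are left at their defaults; the `stage` argument of `LowLevelZeroPlugin` is only an assumed example), a plugin is created first and then passed to the `Booster`:
```python
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin

# choose one plugin for your parallel strategy
plugin = TorchDDPPlugin()                # data parallelism via PyTorch DDP
# plugin = GeminiPlugin()                # ZeRO with chunk-based memory management
# plugin = LowLevelZeroPlugin(stage=2)   # ZeRO stage 1/2 (`stage` is an assumption)

booster = Booster(plugin=plugin)
```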
### Booster API
{{ autodoc:colossalai.booster.Booster }}
{{ autodoc:colossalai.booster.Booster.boost }}
{{ autodoc:colossalai.booster.Booster.backward }}
{{ autodoc:colossalai.booster.Booster.no_sync }}
{{ autodoc:colossalai.booster.Booster.save_model }}
{{ autodoc:colossalai.booster.Booster.load_model }}
{{ autodoc:colossalai.booster.Booster.save_optimizer }}
{{ autodoc:colossalai.booster.Booster.load_optimizer }}
{{ autodoc:colossalai.booster.Booster.save_lr_scheduler }}
{{ autodoc:colossalai.booster.Booster.load_lr_scheduler }}
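To illustrate the `no_sync` entry above, here is a sketch of gradient accumulation; it assumes a plugin that supports synchronization control (e.g. `TorchDDPPlugin`), that `model` and `optimizer` come from `booster.boost`, and that `criterion`, `train_dataloader` and `accumulation_steps` are defined elsewhere:
```python
# a minimal sketch of gradient accumulation with booster.no_sync
for step, batch in enumerate(train_dataloader):
    if (step + 1) % accumulation_steps != 0:
        # intermediate micro-step: accumulate gradients without synchronization
        with booster.no_sync(model):
            loss = criterion(model(batch))
            booster.backward(loss, optimizer)
    else:
        # final micro-step: synchronize gradients and update the parameters
        loss = criterion(model(batch))
        booster.backward(loss, optimizer)
        optimizer.step()
        optimizer.zero_grad()
```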
## Usage and Examples
When training with Colossal-AI, first launch the distributed environment at the beginning of the training script and create the objects you need, such as models, optimizers, loss functions and data loaders. Then call `colossalai.booster` to inject features into these objects, and you can use our booster APIs for the rest of your training process.
Below is a pseudo-code example that shows how to train a model with our booster API:
```python
import torch
from torch.optim import SGD
from torchvision.models import resnet18
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin

def train(rank: int, world_size: int, port: int):
    # launch the distributed environment (rank, world_size and port are
    # passed in here so that the pseudo-code is self-contained)
    colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')

    # create the plugin and the booster
    plugin = TorchDDPPlugin()
    booster = Booster(plugin=plugin)

    # create the training components as usual
    model = resnet18().cuda()
    criterion = lambda x: x.mean()
    optimizer = SGD(model.parameters(), lr=0.001)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)

    # inject the plugin's features into the training components
    model, optimizer, criterion, _, scheduler = booster.boost(model, optimizer, criterion, lr_scheduler=scheduler)

    # run one training step
    x = torch.randn(4, 3, 224, 224).cuda()
    output = model(x)
    loss = criterion(output)
    booster.backward(loss, optimizer)
    optimizer.clip_grad_by_norm(1.0)
    optimizer.step()
    scheduler.step()

    # save a sharded checkpoint and load it back into a fresh model
    save_path = "./model"
    booster.save_model(model, save_path, shard=True, size_per_shard=10, use_safetensors=True)

    new_model = resnet18()
    booster.load_model(new_model, save_path)
```
More design details can be found in [this discussion](https://github.com/hpcaitech/ColossalAI/discussions/3046).
<!-- doc-test-command: torchrun --standalone --nproc_per_node=1 booster_api.py -->

View File

@ -74,7 +74,7 @@ import colossalai
args = colossalai.get_default_parser().parse_args()
# launch distributed environment
colossalai.launch(config=<CONFIG>,
colossalai.launch(config=args.config,
                  rank=args.rank,
                  world_size=args.world_size,
                  host=args.host,
@ -93,12 +93,21 @@ The launcher that comes with PyTorch needs to run the command on every node in order to launch multi
First, we need to specify the launch method in our code. Since this launcher is a wrapper of the PyTorch launcher, it is natural to use `colossalai.launch_from_torch`.
The arguments required by the distributed environment, such as rank, world size, host and port, are all set by the PyTorch launcher and can be read directly from the environment variables.
config.py
```python
BATCH_SIZE = 512
LEARNING_RATE = 3e-3
WEIGHT_DECAY = 0.3
NUM_EPOCHS = 2
```
train.py
```python
import colossalai
colossalai.launch_from_torch(
    config=<CONFIG>,
    config="./config.py",
)
...
```
Next, we can easily launch training from the terminal with `colossalai run`. The command below starts a training job with 4 GPUs on the current machine.
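(For illustration only; `train.py` is a placeholder for your training script.)
```shell
colossalai run --nproc_per_node 4 train.py
```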

View File

@ -28,7 +28,7 @@ CUDA_EXT=1 pip install colossalai
## Install From Source
> This document will stay in line with the main branch of the repository. Feel free to raise an issue if you encounter any problem. :)
> This document will stay in line with the main branch of the repository. Feel free to raise an issue if you encounter any problem.
```shell
git clone https://github.com/hpcaitech/ColossalAI.git
@ -38,13 +38,13 @@ cd ColossalAI
pip install -r requirements/requirements.txt
# install colossalai
CUDA_EXT=1 pip install .
```
If you don't want to install and enable CUDA kernel fusion (compulsory installation when using a fused optimizer), simply omit `CUDA_EXT=1`:
```shell
pip install .
```
If you don't want to install and enable CUDA kernel fusion (compulsory installation when using a fused optimizer):
```shell
NO_CUDA_EXT=1 pip install .
```
<!-- doc-test-command: echo "installation.md does not need test" -->