mirror of https://github.com/hpcaitech/ColossalAI
[doc] Update booster user documents. (#4669)
* update booster_api.md * update booster_checkpoint.md * update booster_plugins.md * move transformers importing inside function * fix Dict typing * fix autodoc bug * small fixpull/4658/head
parent
bce0f16702
commit
1d454733c4
|
@ -1,6 +1,6 @@
|
|||
import warnings
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Callable, Iterator, List, Optional, Union
|
||||
from typing import Any, Callable, Dict, Iterator, List, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
@ -24,29 +24,31 @@ class Booster:
|
|||
Booster is a high-level API for training neural networks. It provides a unified interface for
|
||||
training with different precision, accelerator, and plugin.
|
||||
|
||||
Examples:
|
||||
```python
|
||||
colossalai.launch(...)
|
||||
plugin = GeminiPlugin(...)
|
||||
booster = Booster(precision='fp16', plugin=plugin)
|
||||
|
||||
model = GPT2()
|
||||
optimizer = HybridAdam(model.parameters())
|
||||
dataloader = Dataloader(Dataset)
|
||||
lr_scheduler = LinearWarmupScheduler()
|
||||
criterion = GPTLMLoss()
|
||||
```python
|
||||
# Following is pseudocode
|
||||
|
||||
model, optimizer, lr_scheduler, dataloader = booster.boost(model, optimizer, lr_scheduler, dataloader)
|
||||
colossalai.launch(...)
|
||||
plugin = GeminiPlugin(...)
|
||||
booster = Booster(precision='fp16', plugin=plugin)
|
||||
|
||||
for epoch in range(max_epochs):
|
||||
for input_ids, attention_mask in dataloader:
|
||||
outputs = model(input_ids, attention_mask)
|
||||
loss = criterion(outputs.logits, input_ids)
|
||||
booster.backward(loss, optimizer)
|
||||
optimizer.step()
|
||||
lr_scheduler.step()
|
||||
optimizer.zero_grad()
|
||||
```
|
||||
model = GPT2()
|
||||
optimizer = HybridAdam(model.parameters())
|
||||
dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8)
|
||||
lr_scheduler = LinearWarmupScheduler()
|
||||
criterion = GPTLMLoss()
|
||||
|
||||
model, optimizer, criterion, dataloader, lr_scheduler = booster.boost(model, optimizer, criterion, dataloader, lr_scheduler)
|
||||
|
||||
for epoch in range(max_epochs):
|
||||
for input_ids, attention_mask in dataloader:
|
||||
outputs = model(input_ids.cuda(), attention_mask.cuda())
|
||||
loss = criterion(outputs.logits, input_ids)
|
||||
booster.backward(loss, optimizer)
|
||||
optimizer.step()
|
||||
lr_scheduler.step()
|
||||
optimizer.zero_grad()
|
||||
```
|
||||
|
||||
Args:
|
||||
device (str or torch.device): The device to run the training. Default: None.
|
||||
|
@ -60,7 +62,7 @@ class Booster:
|
|||
|
||||
def __init__(self,
|
||||
device: Optional[str] = None,
|
||||
mixed_precision: Union[MixedPrecision, str] = None,
|
||||
mixed_precision: Optional[Union[MixedPrecision, str]] = None,
|
||||
plugin: Optional[Plugin] = None) -> None:
|
||||
if plugin is not None:
|
||||
assert isinstance(
|
||||
|
@ -110,14 +112,19 @@ class Booster:
|
|||
lr_scheduler: Optional[LRScheduler] = None,
|
||||
) -> List[Union[nn.Module, Optimizer, LRScheduler, DataLoader]]:
|
||||
"""
|
||||
Boost the model, optimizer, criterion, lr_scheduler, and dataloader.
|
||||
Wrap and inject features to the passed in model, optimizer, criterion, lr_scheduler, and dataloader.
|
||||
|
||||
Args:
|
||||
model (nn.Module): The model to be boosted.
|
||||
optimizer (Optimizer): The optimizer to be boosted.
|
||||
criterion (Callable): The criterion to be boosted.
|
||||
dataloader (DataLoader): The dataloader to be boosted.
|
||||
lr_scheduler (LRScheduler): The lr_scheduler to be boosted.
|
||||
model (nn.Module): Convert model into a wrapped model for distributive training.
|
||||
The model might be decorated or partitioned by plugin's strategy after execution of this method.
|
||||
optimizer (Optimizer, optional): Convert optimizer into a wrapped optimizer for distributive training.
|
||||
The optimizer's param groups or states might be decorated or partitioned by plugin's strategy after execution of this method. Defaults to None.
|
||||
criterion (Callable, optional): The function that calculates loss. Defaults to None.
|
||||
dataloader (DataLoader, optional): The prepared dataloader for training. Defaults to None.
|
||||
lr_scheduler (LRScheduler, optional): The learning scheduler for training. Defaults to None.
|
||||
|
||||
Returns:
|
||||
List[Union[nn.Module, Optimizer, LRScheduler, DataLoader]]: The list of boosted input arguments.
|
||||
"""
|
||||
# TODO(FrankLeeeee): consider multi-model and multi-optimizer case
|
||||
# TODO(FrankLeeeee): consider multi-dataloader case
|
||||
|
@ -138,10 +145,10 @@ class Booster:
|
|||
return model, optimizer, criterion, dataloader, lr_scheduler
|
||||
|
||||
def backward(self, loss: torch.Tensor, optimizer: Optimizer) -> None:
|
||||
"""Backward pass.
|
||||
"""Execution of backward during training step.
|
||||
|
||||
Args:
|
||||
loss (torch.Tensor): The loss to be backpropagated.
|
||||
loss (torch.Tensor): The loss for backpropagation.
|
||||
optimizer (Optimizer): The optimizer to be updated.
|
||||
"""
|
||||
# TODO(frank lee): implement this method with plugin
|
||||
|
@ -153,9 +160,31 @@ class Booster:
|
|||
criterion: Callable[[Any, Any], torch.Tensor],
|
||||
optimizer: Optional[Optimizer] = None,
|
||||
return_loss: bool = True,
|
||||
return_outputs: bool = False) -> dict:
|
||||
# run pipeline forward backward pass
|
||||
# return loss or outputs if needed
|
||||
return_outputs: bool = False) -> Dict[str, Any]:
|
||||
"""
|
||||
Execute forward & backward when utilizing pipeline parallel.
|
||||
Return loss or Huggingface style model outputs if needed.
|
||||
|
||||
Warning: This function is tailored for the scenario of pipeline parallel.
|
||||
As a result, please don't do the forward/backward pass in the conventional way (model(input)/loss.backward())
|
||||
when doing pipeline parallel training with booster, which will cause unexpected errors.
|
||||
|
||||
Args:
|
||||
data_iter(Iterator): The iterator for getting the next batch of data. Usually there are two ways to obtain this argument:
|
||||
1. wrap the dataloader to iterator through: iter(dataloader)
|
||||
2. get the next batch from dataloader, and wrap this batch to iterator: iter([batch])
|
||||
model (nn.Module): The model to execute forward/backward, it should be a model wrapped by a plugin that supports pipeline.
|
||||
criterion: (Callable[[Any, Any], torch.Tensor]): Criterion to be used. It should take two arguments: model outputs and inputs, and returns loss tensor.
|
||||
'lambda y, x: loss_fn(y)' can turn a normal loss function into a valid two-argument criterion here.
|
||||
optimizer (Optimizer, optional): The optimizer for execution of backward. Can be None when only doing forward (i.e. evaluation). Defaults to None.
|
||||
return_loss (bool, optional): Whether to return loss in the dict returned by this method. Defaults to True.
|
||||
return_output (bool, optional): Whether to return Huggingface style model outputs in the dict returned by this method. Defaults to False.
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: Output dict in the form of {'loss': ..., 'outputs': ...}.
|
||||
ret_dict['loss'] is the loss of forward if return_loss is set to True, else None.
|
||||
ret_dict['outputs'] is the Huggingface style model outputs during forward if return_output is set to True, else None.
|
||||
"""
|
||||
assert isinstance(self.plugin,
|
||||
PipelinePluginBase), f'The plugin {self.plugin.__class__.__name__} does not support pipeline.'
|
||||
return self.plugin.execute_pipeline(data_iter, model, criterion, optimizer, return_loss, return_outputs)
|
||||
|
@ -175,7 +204,7 @@ class Booster:
|
|||
assert self.plugin.support_no_sync(), f'The plugin {self.plugin.__class__.__name__} does not support no_sync.'
|
||||
return self.plugin.no_sync(model, optimizer)
|
||||
|
||||
def load_model(self, model: Union[nn.Module, ModelWrapper], checkpoint: str, strict: bool = True):
|
||||
def load_model(self, model: Union[nn.Module, ModelWrapper], checkpoint: str, strict: bool = True) -> None:
|
||||
"""Load model from checkpoint.
|
||||
|
||||
Args:
|
||||
|
@ -195,7 +224,7 @@ class Booster:
|
|||
gather_dtensor: bool = True,
|
||||
prefix: Optional[str] = None,
|
||||
size_per_shard: int = 1024,
|
||||
use_safetensors: bool = False):
|
||||
use_safetensors: bool = False) -> None:
|
||||
"""Save model to checkpoint.
|
||||
|
||||
Args:
|
||||
|
@ -203,7 +232,7 @@ class Booster:
|
|||
checkpoint (str): Path to the checkpoint. It must be a local path.
|
||||
It is a file path if ``shard=False``. Otherwise, it is a directory path.
|
||||
shard (bool, optional): Whether to save checkpoint a sharded way.
|
||||
If true, the checkpoint will be a folder. Otherwise, it will be a single file. Defaults to False.
|
||||
If true, the checkpoint will be a folder with the same format as Huggingface transformers checkpoint. Otherwise, it will be a single file. Defaults to False.
|
||||
gather_dtensor (bool, optional): whether to gather the distributed tensor to the first device. Default: True.
|
||||
prefix (str, optional): A prefix added to parameter and buffer
|
||||
names to compose the keys in state_dict. Defaults to None.
|
||||
|
@ -218,7 +247,7 @@ class Booster:
|
|||
size_per_shard=size_per_shard,
|
||||
use_safetensors=use_safetensors)
|
||||
|
||||
def load_optimizer(self, optimizer: Optimizer, checkpoint: str):
|
||||
def load_optimizer(self, optimizer: Optimizer, checkpoint: str) -> None:
|
||||
"""Load optimizer from checkpoint.
|
||||
|
||||
Args:
|
||||
|
@ -237,7 +266,7 @@ class Booster:
|
|||
shard: bool = False,
|
||||
gather_dtensor: bool = True,
|
||||
prefix: Optional[str] = None,
|
||||
size_per_shard: int = 1024):
|
||||
size_per_shard: int = 1024) -> None:
|
||||
"""
|
||||
Save optimizer to checkpoint.
|
||||
|
||||
|
@ -254,7 +283,7 @@ class Booster:
|
|||
"""
|
||||
self.checkpoint_io.save_optimizer(optimizer, checkpoint, shard, gather_dtensor, prefix, size_per_shard)
|
||||
|
||||
def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
|
||||
def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str) -> None:
|
||||
"""Save lr scheduler to checkpoint.
|
||||
|
||||
Args:
|
||||
|
@ -263,7 +292,7 @@ class Booster:
|
|||
"""
|
||||
self.checkpoint_io.save_lr_scheduler(lr_scheduler, checkpoint)
|
||||
|
||||
def load_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
|
||||
def load_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str) -> None:
|
||||
"""Load lr scheduler from checkpoint.
|
||||
|
||||
Args:
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
# Booster API
|
||||
|
||||
Author: [Mingyan Jiang](https://github.com/jiangmingyan) [Jianghai Chen](https://github.com/CjhHa1)
|
||||
Author: [Mingyan Jiang](https://github.com/jiangmingyan), [Jianghai Chen](https://github.com/CjhHa1), [Baizhou Zhang](https://github.com/Fridge003)
|
||||
|
||||
**Prerequisite:**
|
||||
|
||||
|
@ -9,32 +9,35 @@ Author: [Mingyan Jiang](https://github.com/jiangmingyan) [Jianghai Chen](https:/
|
|||
|
||||
**Example Code**
|
||||
|
||||
- [Train with Booster](https://github.com/hpcaitech/ColossalAI/blob/main/examples/tutorial/new_api/cifar_resnet/README.md)
|
||||
- [Train with Booster](https://github.com/hpcaitech/ColossalAI/blob/main/examples/tutorial/new_api/cifar_resnet)
|
||||
|
||||
## Introduction
|
||||
|
||||
In our new design, `colossalai.booster` replaces the role of `colossalai.initialize` to inject features into your training components (e.g. model, optimizer, dataloader) seamlessly. With these new APIs, you can integrate your model with our parallelism features more friendly. Also calling `colossalai.booster` is the standard procedure before you run into your training loops. In the sections below, I will cover how `colossalai.booster` works and what we should take note of.
|
||||
In our new design, `colossalai.booster` replaces the role of `colossalai.initialize` to inject features into your training components (e.g. model, optimizer, dataloader) seamlessly. With these new APIs, you can integrate your model with our parallelism features more friendly. Also, calling `colossalai.booster` is the standard procedure before you run into your training loops. In the sections below, we will cover how `colossalai.booster` works and what we should take note of.
|
||||
|
||||
### Plugin
|
||||
|
||||
Plugin is an important component that manages parallel configuration (eg: The gemini plugin encapsulates the gemini acceleration solution). Currently supported plugins are as follows:
|
||||
|
||||
**_HybridParallelPlugin:_** This plugin wraps the hybrid parallel training acceleration solution. It provides an interface for any combination of tensor parallel, pipeline parallel and data parallel strategies including DDP and ZeRO.
|
||||
|
||||
**_GeminiPlugin:_** This plugin wraps the Gemini acceleration solution, that ZeRO with chunk-based memory management.
|
||||
|
||||
**_TorchDDPPlugin:_** This plugin wraps the DDP acceleration solution of Pytorch. It implements data parallelism at the module level which can run across multiple machines.
|
||||
**_TorchDDPPlugin:_** This plugin wraps the DDP acceleration solution of Pytorch. It implements data parallel at the module level which can run across multiple machines.
|
||||
|
||||
**_LowLevelZeroPlugin:_** This plugin wraps the 1/2 stage of Zero Redundancy Optimizer. Stage 1 : Shards optimizer states across data parallel workers/GPUs. Stage 2 : Shards optimizer states + gradients across data parallel workers/GPUs.
|
||||
|
||||
|
||||
**_TorchFSDPPlugin:_** This plugin wraps the FSDP acceleration solution of Pytorch and can be used to train models with zero-dp.
|
||||
|
||||
More details about usages of each plugin can be found in chapter [Booster Plugins](./booster_plugins.md).
|
||||
|
||||
### API of booster
|
||||
|
||||
{{ autodoc:colossalai.booster.Booster }}
|
||||
|
||||
## Usage
|
||||
|
||||
In a typical workflow, you should launch distributed environment at the beginning of training script and create objects needed (such as models, optimizers, loss function, data loaders etc.) firstly, then call `colossalai.booster` to inject features into these objects, After that, you can use our booster APIs and these returned objects to continue the rest of your training processes.
|
||||
In a typical workflow, you should launch distributed environment at the beginning of training script and create objects needed (such as models, optimizers, loss function, data loaders etc.) firstly, then call `booster.boost` to inject features into these objects, After that, you can use our booster APIs and these returned objects to continue the rest of your training processes.
|
||||
|
||||
A pseudo-code example is like below:
|
||||
|
||||
|
@ -48,15 +51,21 @@ from colossalai.booster import Booster
|
|||
from colossalai.booster.plugin import TorchDDPPlugin
|
||||
|
||||
def train():
|
||||
# launch colossalai
|
||||
colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
|
||||
|
||||
# create plugin and objects for training
|
||||
plugin = TorchDDPPlugin()
|
||||
booster = Booster(plugin=plugin)
|
||||
model = resnet18()
|
||||
criterion = lambda x: x.mean()
|
||||
optimizer = SGD((model.parameters()), lr=0.001)
|
||||
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)
|
||||
|
||||
# use booster.boost to wrap the training objects
|
||||
model, optimizer, criterion, _, scheduler = booster.boost(model, optimizer, criterion, lr_scheduler=scheduler)
|
||||
|
||||
# do training as normal, except that the backward should be called by booster
|
||||
x = torch.randn(4, 3, 224, 224)
|
||||
x = x.to('cuda')
|
||||
output = model(x)
|
||||
|
@ -65,14 +74,16 @@ def train():
|
|||
optimizer.clip_grad_by_norm(1.0)
|
||||
optimizer.step()
|
||||
scheduler.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
# checkpointing using booster api
|
||||
save_path = "./model"
|
||||
booster.save_model(model, save_path, True, True, "", 10, use_safetensors=use_safetensors)
|
||||
booster.save_model(model, save_path, shard=True, size_per_shard=10, use_safetensors=True)
|
||||
|
||||
new_model = resnet18()
|
||||
booster.load_model(new_model, save_path)
|
||||
```
|
||||
|
||||
[more design details](https://github.com/hpcaitech/ColossalAI/discussions/3046)
|
||||
For more design details please see [this page](https://github.com/hpcaitech/ColossalAI/discussions/3046).
|
||||
|
||||
<!-- doc-test-command: torchrun --standalone --nproc_per_node=1 booster_api.py -->
|
||||
|
|
|
@ -13,7 +13,7 @@ We've introduced the [Booster API](./booster_api.md) in the previous tutorial. I
|
|||
|
||||
{{ autodoc:colossalai.booster.Booster.save_model }}
|
||||
|
||||
Model must be boosted by `colossalai.booster.Booster` before saving. `checkpoint` is the path to saved checkpoint. It can be a file, if `shard=False`. Otherwise, it should be a directory. If `shard=True`, the checkpoint will be saved in a sharded way. This is useful when the checkpoint is too large to be saved in a single file. Our sharded checkpoint format is compatible with [huggingface/transformers](https://github.com/huggingface/transformers).
|
||||
Model must be boosted by `colossalai.booster.Booster` before saving. `checkpoint` is the path to saved checkpoint. It can be a file, if `shard=False`. Otherwise, it should be a directory. If `shard=True`, the checkpoint will be saved in a sharded way. This is useful when the checkpoint is too large to be saved in a single file. Our sharded checkpoint format is compatible with [huggingface/transformers](https://github.com/huggingface/transformers), so you can use huggingface `from_pretrained` method to load model from our sharded checkpoint.
|
||||
|
||||
{{ autodoc:colossalai.booster.Booster.load_model }}
|
||||
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
# Booster Plugins
|
||||
|
||||
Author: [Hongxin Liu](https://github.com/ver217)
|
||||
Author: [Hongxin Liu](https://github.com/ver217), [Baizhou Zhang](https://github.com/Fridge003)
|
||||
|
||||
**Prerequisite:**
|
||||
- [Booster API](./booster_api.md)
|
||||
|
@ -15,6 +15,7 @@ We currently provide the following plugins:
|
|||
- [Gemini Plugin](#gemini-plugin): It wraps the [Gemini](../features/zero_with_chunk.md) which implements Zero-3 with chunk-based and heterogeneous memory management.
|
||||
- [Torch DDP Plugin](#torch-ddp-plugin): It is a wrapper of `torch.nn.parallel.DistributedDataParallel` and can be used to train models with data parallelism.
|
||||
- [Torch FSDP Plugin](#torch-fsdp-plugin): It is a wrapper of `torch.distributed.fsdp.FullyShardedDataParallel` and can be used to train models with zero-dp.
|
||||
- [Hybrid Pararllel Plugin](#hybrid-parallel-plugin): It provides a tidy interface that integrates the power of Shardformer, pipeline manager, mixied precision training, TorchDDP and Zero stage 1/2 feature. With this plugin, transformer models can be easily trained with any combination of tensor parallel, pipeline parallel and data parallel (DDP/Zero) efficiently, along with various kinds of optimization tools for acceleration and memory saving. Detailed information about supported parallel strategies and optimization tools is explained in the section below.
|
||||
|
||||
More plugins are coming soon.
|
||||
|
||||
|
@ -43,8 +44,6 @@ We've tested compatibility on some famous models, following models may not be su
|
|||
|
||||
Compatibility problems will be fixed in the future.
|
||||
|
||||
> ⚠ This plugin can only load optimizer checkpoint saved by itself with the same number of processes now. This will be fixed in the future.
|
||||
|
||||
### Gemini Plugin
|
||||
|
||||
This plugin implements Zero-3 with chunk-based and heterogeneous memory management. It can train large models without much loss in speed. It also does not support local gradient accumulation. More details can be found in [Gemini Doc](../features/zero_with_chunk.md).
|
||||
|
@ -69,4 +68,24 @@ More details can be found in [Pytorch Docs](https://pytorch.org/docs/main/fsdp.h
|
|||
|
||||
{{ autodoc:colossalai.booster.plugin.TorchFSDPPlugin }}
|
||||
|
||||
|
||||
### Hybrid Parallel Plugin
|
||||
|
||||
This plugin implements the combination of various parallel training strategies and optimization tools. The features of HybridParallelPlugin can be generally divided into four parts:
|
||||
|
||||
1. Shardformer: This plugin provides an entrance to Shardformer, which controls model sharding under tensor parallel and pipeline parallel setting. Shardformer also overloads the logic of model's forward/backward process to ensure the smooth working of tp/pp. Also, optimization tools including fused normalization, flash attention (xformers), JIT and sequence parallel are injected into the overloaded forward/backward method by Shardformer.
|
||||
|
||||
2. Mixed Precision Training: Support for fp16/bf16 mixed precision training. More details about its arguments configuration can be found in [Mixed Precision Training Doc](../features/mixed_precision_training_with_booster.md).
|
||||
|
||||
3. Torch DDP: This plugin will automatically adopt Pytorch DDP as data parallel strategy when pipeline parallel and Zero is not used. More details about its arguments configuration can be found in [Pytorch DDP Docs](https://pytorch.org/docs/main/generated/torch.nn.parallel.DistributedDataParallel.html#torch.nn.parallel.DistributedDataParallel).
|
||||
|
||||
4. Zero: This plugin can adopt Zero 1/2 as data parallel strategy through setting the `zero_stage` argument as 1 or 2 when initializing plugin. Zero 1 is compatible with pipeline parallel strategy, while Zero 2 is not. More details about its argument configuration can be found in [Low Level Zero Plugin](#low-level-zero-plugin).
|
||||
|
||||
> ⚠ When using this plugin, only the subset of Huggingface transformers supported by Shardformer are compatible with tensor parallel, pipeline parallel and optimization tools. Mainstream transformers such as Llama 1, Llama 2, OPT, Bloom, Bert and GPT2 etc. are all supported by Shardformer.
|
||||
|
||||
> ⚠ This plugin only supports sharded checkpointing methods for model/optimizer at present. Unsharded checkpointing methods will be supported in future release.
|
||||
|
||||
{{ autodoc:colossalai.booster.plugin.HybridParallelPlugin }}
|
||||
|
||||
|
||||
<!-- doc-test-command: echo -->
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
# booster 使用
|
||||
|
||||
作者: [Mingyan Jiang](https://github.com/jiangmingyan) [Jianghai Chen](https://github.com/CjhHa1)
|
||||
作者: [Mingyan Jiang](https://github.com/jiangmingyan), [Jianghai Chen](https://github.com/CjhHa1), [Baizhou Zhang](https://github.com/Fridge003)
|
||||
|
||||
**预备知识:**
|
||||
|
||||
|
@ -11,17 +11,19 @@
|
|||
|
||||
<!-- update this url-->
|
||||
|
||||
- [使用 booster 训练](https://github.com/hpcaitech/ColossalAI/blob/main/examples/tutorial/new_api/cifar_resnet/README.md)
|
||||
- [使用 booster 训练](https://github.com/hpcaitech/ColossalAI/blob/main/examples/tutorial/new_api/cifar_resnet)
|
||||
|
||||
## 简介
|
||||
|
||||
在我们的新设计中, `colossalai.booster` 代替 `colossalai.initialize` 将特征(例如,模型、优化器、数据加载器)无缝注入您的训练组件中。 使用 booster API, 您可以更友好地将我们的并行策略整合到待训练模型中. 调用 `colossalai.booster` 是您进入训练循环前的基本操作。
|
||||
在我们的新设计中, `colossalai.booster` 代替 `colossalai.initialize` 将特征(例如,模型、优化器、数据加载器)无缝注入到您的训练组件中。 使用 booster API, 您可以更友好地将我们的并行策略整合到待训练模型中. 调用 `colossalai.booster` 是您进入训练流程前的正常操作。
|
||||
在下面的章节中,我们将介绍 `colossalai.booster` 是如何工作的以及使用时我们要注意的细节。
|
||||
|
||||
### Booster 插件
|
||||
|
||||
Booster 插件是管理并行配置的重要组件(eg:gemini 插件封装了 gemini 加速方案)。目前支持的插件如下:
|
||||
|
||||
**_HybridParallelPlugin:_** HybirdParallelPlugin 插件封装了混合并行的加速解决方案。它提供的接口可以在张量并行,流水线并行以及两种数据并行方法(DDP, Zero)间进行任意的组合。
|
||||
|
||||
**_GeminiPlugin:_** GeminiPlugin 插件封装了 gemini 加速解决方案,即基于块内存管理的 ZeRO 优化方案。
|
||||
|
||||
**_TorchDDPPlugin:_** TorchDDPPlugin 插件封装了Pytorch的DDP加速方案,实现了模型级别的数据并行,可以跨多机运行。
|
||||
|
@ -30,6 +32,7 @@ Booster 插件是管理并行配置的重要组件(eg:gemini 插件封装了
|
|||
|
||||
**_TorchFSDPPlugin:_** TorchFSDPPlugin封装了 Pytorch的FSDP加速方案,可以用于零冗余优化器数据并行(ZeroDP)的训练。
|
||||
|
||||
若想了解更多关于插件的用法细节,请参考[Booster 插件](./booster_plugins.md)章节。
|
||||
|
||||
### Booster 接口
|
||||
|
||||
|
@ -39,7 +42,7 @@ Booster 插件是管理并行配置的重要组件(eg:gemini 插件封装了
|
|||
|
||||
## 使用方法及示例
|
||||
|
||||
在使用 colossalai 训练时,首先需要在训练脚本的开头启动分布式环境,并创建需要使用的模型、优化器、损失函数、数据加载器等对象。之后,调用`colossalai.booster` 将特征注入到这些对象中,您就可以使用我们的 booster API 去进行您接下来的训练流程。
|
||||
在使用 colossalai 训练时,首先需要在训练脚本的开头启动分布式环境,并创建需要使用的模型、优化器、损失函数、数据加载器等对象。之后,调用`booster.boost` 将特征注入到这些对象中,您就可以使用我们的 booster API 去进行您接下来的训练流程。
|
||||
|
||||
以下是一个伪代码示例,将展示如何使用我们的 booster API 进行模型训练:
|
||||
|
||||
|
@ -53,15 +56,21 @@ from colossalai.booster import Booster
|
|||
from colossalai.booster.plugin import TorchDDPPlugin
|
||||
|
||||
def train():
|
||||
# launch colossalai
|
||||
colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
|
||||
|
||||
# create plugin and objects for training
|
||||
plugin = TorchDDPPlugin()
|
||||
booster = Booster(plugin=plugin)
|
||||
model = resnet18()
|
||||
criterion = lambda x: x.mean()
|
||||
optimizer = SGD((model.parameters()), lr=0.001)
|
||||
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)
|
||||
|
||||
# use booster.boost to wrap the training objects
|
||||
model, optimizer, criterion, _, scheduler = booster.boost(model, optimizer, criterion, lr_scheduler=scheduler)
|
||||
|
||||
# do training as normal, except that the backward should be called by booster
|
||||
x = torch.randn(4, 3, 224, 224)
|
||||
x = x.to('cuda')
|
||||
output = model(x)
|
||||
|
@ -70,14 +79,16 @@ def train():
|
|||
optimizer.clip_grad_by_norm(1.0)
|
||||
optimizer.step()
|
||||
scheduler.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
# checkpointing using booster api
|
||||
save_path = "./model"
|
||||
booster.save_model(model, save_path, True, True, "", 10, use_safetensors=use_safetensors)
|
||||
booster.save_model(model, save_path, shard=True, size_per_shard=10, use_safetensors=True)
|
||||
|
||||
new_model = resnet18()
|
||||
booster.load_model(new_model, save_path)
|
||||
```
|
||||
|
||||
[更多的设计细节请参考](https://github.com/hpcaitech/ColossalAI/discussions/3046)
|
||||
更多的Booster设计细节请参考这一[页面](https://github.com/hpcaitech/ColossalAI/discussions/3046)
|
||||
|
||||
<!-- doc-test-command: torchrun --standalone --nproc_per_node=1 booster_api.py -->
|
||||
|
|
|
@ -13,32 +13,32 @@
|
|||
|
||||
{{ autodoc:colossalai.booster.Booster.save_model }}
|
||||
|
||||
模型在保存前必须被 `colossalai.booster.Booster` 加速。 `checkpoint` 是要保存的 checkpoint 的路径。 如果 `shard=False`,它就是文件。 否则, 它就是文件夹。如果 `shard=True`,checkpoint 将以分片方式保存。当 checkpoint 太大而无法保存在单个文件中时,这很有用。我们的分片 checkpoint 格式与 [huggingface/transformers](https://github.com/huggingface/transformers) 兼容。
|
||||
模型在保存前必须被 `colossalai.booster.Booster` 封装。 `checkpoint` 是要保存的 checkpoint 的路径。 如果 `shard=False`,它就是文件。 否则, 它就是文件夹。如果 `shard=True`,checkpoint 将以分片方式保存,在 checkpoint 太大而无法保存在单个文件中时会很实用。我们的分片 checkpoint 格式与 [huggingface/transformers](https://github.com/huggingface/transformers) 兼容,所以用户可以使用huggingface的`from_pretrained`方法从分片checkpoint加载模型。
|
||||
|
||||
{{ autodoc:colossalai.booster.Booster.load_model }}
|
||||
|
||||
模型在加载前必须被 `colossalai.booster.Booster` 加速。它会自动检测 checkpoint 格式,并以相应的方式加载。
|
||||
模型在加载前必须被 `colossalai.booster.Booster` 封装。它会自动检测 checkpoint 格式,并以相应的方式加载。
|
||||
|
||||
## 优化器 Checkpoint
|
||||
|
||||
|
||||
{{ autodoc:colossalai.booster.Booster.save_optimizer }}
|
||||
|
||||
优化器在保存前必须被 `colossalai.booster.Booster` 加速。
|
||||
优化器在保存前必须被 `colossalai.booster.Booster` 封装。
|
||||
|
||||
{{ autodoc:colossalai.booster.Booster.load_optimizer }}
|
||||
|
||||
优化器在加载前必须被 `colossalai.booster.Booster` 加速。
|
||||
优化器在加载前必须被 `colossalai.booster.Booster` 封装。
|
||||
|
||||
## 学习率调度器 Checkpoint
|
||||
|
||||
{{ autodoc:colossalai.booster.Booster.save_lr_scheduler }}
|
||||
|
||||
学习率调度器在保存前必须被 `colossalai.booster.Booster` 加速。 `checkpoint` 是 checkpoint 文件的本地路径.
|
||||
学习率调度器在保存前必须被 `colossalai.booster.Booster` 封装。 `checkpoint` 是 checkpoint 文件的本地路径.
|
||||
|
||||
{{ autodoc:colossalai.booster.Booster.load_lr_scheduler }}
|
||||
|
||||
学习率调度器在加载前必须被 `colossalai.booster.Booster` 加速。 `checkpoint` 是 checkpoint 文件的本地路径.
|
||||
学习率调度器在加载前必须被 `colossalai.booster.Booster` 封装。 `checkpoint` 是 checkpoint 文件的本地路径.
|
||||
|
||||
## Checkpoint 设计
|
||||
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
# Booster 插件
|
||||
|
||||
作者: [Hongxin Liu](https://github.com/ver217)
|
||||
作者: [Hongxin Liu](https://github.com/ver217), [Baizhou Zhang](https://github.com/Fridge003)
|
||||
|
||||
**前置教程:**
|
||||
- [Booster API](./booster_api.md)
|
||||
|
@ -11,10 +11,11 @@
|
|||
|
||||
我们现在提供以下插件:
|
||||
|
||||
- [Low Level Zero 插件](#low-level-zero-plugin): 它包装了 `colossalai.zero.low_level.LowLevelZeroOptimizer`,可用于使用 Zero-dp 训练模型。它仅支持 Zero 阶段1和阶段2。
|
||||
- [Gemini 插件](#gemini-plugin): 它包装了 [Gemini](../features/zero_with_chunk.md),Gemini 实现了基于Chunk内存管理和异构内存管理的 Zero-3。
|
||||
- [Torch DDP 插件](#torch-ddp-plugin): 它包装了 `torch.nn.parallel.DistributedDataParallel` 并且可用于使用数据并行训练模型。
|
||||
- [Torch FSDP 插件](#torch-fsdp-plugin): 它包装了 `torch.distributed.fsdp.FullyShardedDataParallel` 并且可用于使用 Zero-dp 训练模型。
|
||||
- [Low Level Zero 插件](#low-level-zero-插件): 它包装了 `colossalai.zero.low_level.LowLevelZeroOptimizer`,可用于使用 Zero-dp 训练模型。它仅支持 Zero 阶段1和阶段2。
|
||||
- [Gemini 插件](#gemini-插件): 它包装了 [Gemini](../features/zero_with_chunk.md),Gemini 实现了基于Chunk内存管理和异构内存管理的 Zero-3。
|
||||
- [Torch DDP 插件](#torch-ddp-插件): 它包装了 `torch.nn.parallel.DistributedDataParallel` 并且可用于使用数据并行训练模型。
|
||||
- [Torch FSDP 插件](#torch-fsdp-插件): 它包装了 `torch.distributed.fsdp.FullyShardedDataParallel` 并且可用于使用 Zero-dp 训练模型。
|
||||
- [Hybrid Pararllel 插件](#hybrid-parallel-插件): 它为Shardformer,流水线管理器,混合精度运算,TorchDDP以及Zero-1/Zero-2功能提供了一个统一且简洁的接口。使用该插件可以简单高效地实现transformer模型在张量并行,流水线并行以及数据并行(DDP, Zero)间任意组合并行训练策略,同时支持多种训练速度和内存的优化工具。有关这些训练策略和优化工具的具体信息将在下一章中阐述。
|
||||
|
||||
更多插件即将推出。
|
||||
|
||||
|
@ -43,8 +44,6 @@ Zero-2 不支持局部梯度累积。如果您坚持使用,虽然可以积累
|
|||
|
||||
兼容性问题将在未来修复。
|
||||
|
||||
> ⚠ 该插件现在只能加载自己保存的且具有相同进程数的优化器 Checkpoint。这将在未来得到解决。
|
||||
|
||||
### Gemini 插件
|
||||
|
||||
这个插件实现了基于Chunk内存管理和异构内存管理的 Zero-3。它可以训练大型模型而不会损失太多速度。它也不支持局部梯度累积。更多详细信息,请参阅 [Gemini 文档](../features/zero_with_chunk.md).
|
||||
|
@ -70,4 +69,23 @@ Zero-2 不支持局部梯度累积。如果您坚持使用,虽然可以积累
|
|||
|
||||
{{ autodoc:colossalai.booster.plugin.TorchFSDPPlugin }}
|
||||
|
||||
|
||||
### Hybrid Parallel 插件
|
||||
|
||||
这个插件实现了多种并行训练策略和优化工具的组合。Hybrid Parallel插件支持的功能大致可以被分为以下四个部分:
|
||||
|
||||
1. Shardformer: Shardformer负责在张量并行以及流水线并行下切分模型的逻辑,以及前向/后向方法的重载,这个插件为Shardformer功能提供了一个简单易用的接口。与此同时,Shardformer还负责将包括fused normalization, flash attention (xformers), JIT和序列并行在内的各类优化工具融入重载后的前向/后向方法。
|
||||
|
||||
2. 混合精度训练:插件支持fp16/bf16的混合精度训练。更多关于混合精度训练的参数配置的详细信息请参考 [混合精度训练文档](../features/mixed_precision_training_with_booster.md)。
|
||||
|
||||
3. Torch DDP: 当流水线并行和Zero不被使用的时候,插件会自动采用Pytorch DDP作为数据并行的策略。更多关于Torch DDP的参数配置的详细信息请参考 [Pytorch DDP 文档](https://pytorch.org/docs/main/generated/torch.nn.parallel.DistributedDataParallel.html#torch.nn.parallel.DistributedDataParallel)。
|
||||
|
||||
4. Zero: 在初始化插件的时候,可以通过将`zero_stage`参数设置为1或2来让插件采用Zero 1/2作为数据并行的策略。Zero 1可以和流水线并行策略同时使用, 而Zero 2则不可以和流水线并行策略同时使用。更多关于Zero的参数配置的详细信息请参考 [Low Level Zero 插件](#low-level-zero-插件).
|
||||
|
||||
> ⚠ 在使用该插件的时候, 只有支持Shardformer的部分Huggingface transformers模型才能够使用张量并行、流水线并行以及优化工具。Llama 1、Llama 2、OPT、Bloom、Bert以及GPT2等主流transformers模型均已支持Shardformer。
|
||||
|
||||
> ⚠ 该插件当前只对模型和优化器支持分片的checkpoint方法。不分片的checkpoint方法会在未来的版本中被支持。
|
||||
|
||||
{{ autodoc:colossalai.booster.plugin.HybridParallelPlugin }}
|
||||
|
||||
<!-- doc-test-command: echo -->
|
||||
|
|
Loading…
Reference in New Issue