[doc] add lazy init docs (#4808)

pull/4815/head
Hongxin Liu 2023-09-27 10:24:04 +08:00 committed by GitHub
parent a22706337a
commit da15fdb9ca
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 162 additions and 24 deletions

View File

@ -472,30 +472,11 @@ class LazyTensor(torch.Tensor):
class LazyInitContext:
"""Context manager for lazy initialization. Enables initializing the model without allocating real memory.
Usage:
1. The model is initialized, but no real memory is allocated.
>>> ctx = LazyInitContext()
>>> with ctx:
>>> model = MyModel().cuda()
2. The model is initialized with ``MetaTensor`` as weights, but still no real memory is allocated.
>>> with ctx.traceable(model):
>>> gm = symbolic_trace(model, meta_args=meta_args)
>>> # Solve the execution strategy and apply the strategy to the model
>>> strategy = StrategyAndSpec()
3. The model is initialized with ``torch.Tensor`` as weights, and real memory is allocated. (single device)
>>> model = ctx.materialize(model)
3. The model is initialized with sharded ``torch.Tensor`` as weights, and real memory is allocated. (distributed scenario)
>>> model = apply_strategy_to_all_params(model, strategy)
>>> model = ctx.distribute(model)
Warnings:
This API is still experimental and further modifications can be made to it.
For example:
1. Quantization strategies can be applied before allocating real memory.
2. Lazy initialization seems slower than normal initialization.
Args:
tensor_cls (Union[_MyTensor, LazyTensor], optional): This is only for test. Defaults to LazyTensor.
default_device (Optional[Union[torch.device, str, int]], optional): Defalt device for initialization.
If it's cuda, initilization will be accelerated, but cuda memory will be allocated. By default, it's cpu.
Defaults to None.
"""
_replaced: bool = False

View File

@ -55,6 +55,7 @@
},
"features/pipeline_parallel",
"features/nvme_offload",
"features/lazy_init",
"features/cluster_utils"
]
},

View File

@ -32,6 +32,8 @@ Plugin is an important component that manages parallel configuration (eg: The ge
More details about usages of each plugin can be found in chapter [Booster Plugins](./booster_plugins.md).
Some plugins support lazy initialization, which can be used to save memory when initializating large models. For more details, please see [Lazy Initialization](../features/lazy_init.md).
### API of booster
{{ autodoc:colossalai.booster.Booster }}

View File

@ -0,0 +1,76 @@
# Lazy initialization
Author: [Hongxiu Liu](https://github.com/ver217)
**Prerequisite:**
- [Train with booster](../basics/booster_api.md)
## Introduction
Lazy initialization defers model initialization. It saves memory when initializing large models.
If your model has `N` billion parameters and your memory (or GPU memory) is `M` GB, we recommend you use lazy initialization when `4N >= M`. Otherwise, it is optional.
## Usage
Lazy initialization must be used with booster.
### API reference
{{ autodoc:colossalai.lazy.LazyInitContext }}
### Example
```python
import colossalai
from colossalai.lazy import LazyInitContext
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from transformers import LlamaForCausalLM, LlamaConfig, BertForPreTraining
colossalai.launch({})
plugin = GeminiPlugin()
booster = Booster(plugin)
# 1. Initialize model from scratch
# Initialization on cuda will accelerate the initialization process but take more GPU memory.
with LazyInitContext(default_device="cuda"):
model = LlamaForCausalLM(LlamaConfig(hidden_size=64, intermediate_size=172, num_hidden_layers=4, num_attention_heads=4))
model, *_ = booster.boost(model)
# 2. Initialize model from pretrained
with LazyInitContext():
model = BertForPreTraining.from_pretrained("prajjwal1/bert-tiny")
model, *_ = booster.boost(model)
```
> ⚠️ Lazy initialization from pretrained is supported for colossalai>0.3.3 or main branch.
## Limitations
As we claimed, lazy initialization must be used with booster. And only several plugins support it.
| Plugin | Supported | Remarks |
|-----------------|-----------|--------------|
| Gemini | Yes | |
| Hybrid Parallel | Yes | |
| Low Level Zero | No | No need |
| Torch DDP | No | Incompatible |
| Torch FSDP | No | Incompatible |
Not all models can be lazily initialized. In some cases, a part of parameters/buffers may be early initialized. But don't worry, this part usually takes a small proportion of the whole model.
And some models are not supported at all which will raise an error. We tested models in torchvision, diffusers, timm, transformers, torchaudio and torchrec. Below models are not supported:
| Model | Category |
|-------------------------------|--------------|
| wav2vec2_base | torchaudio |
| hubert_base | torchaudio |
| ViTModel | transformers |
| ViTForMaskedImageModeling | transformers |
| ViTForImageClassification | transformers |
| Blip2Model | transformers |
| Blip2ForConditionalGeneration | transformers |
<!-- doc-test-command: torchrun --standalone --nproc_per_node=2 lazy_iniy.py -->

View File

@ -35,6 +35,8 @@ Booster 插件是管理并行配置的重要组件eggemini 插件封装了
若想了解更多关于插件的用法细节,请参考[Booster 插件](./booster_plugins.md)章节。
有一些插件支持懒惰初始化,它能节省初始化大模型时的内存占用。详情请参考[懒惰初始化](../features/lazy_init.md)。
### Booster 接口
<!--TODO: update autodoc -->

View File

@ -0,0 +1,76 @@
# 懒惰初始化
作者: [Hongxiu Liu](https://github.com/ver217)
**前置教程:**
- [Train with booster](../basics/booster_api.md)
## 简介
懒惰初始化延迟了模型的初始化。它能够节省在大模型初始化时的内存占用。
如果你的模型有 `N` 十亿个参数并且你的内存(或显存)为 `M` GB, 我们推荐您在 `4N >= M` 时使用懒惰初始化。否则,懒惰初始化不是必须的。
## 使用
懒惰初始化必须与 booster 一起使用。
### API 参考
{{ autodoc:colossalai.lazy.LazyInitContext }}
### 例子
```python
import colossalai
from colossalai.lazy import LazyInitContext
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from transformers import LlamaForCausalLM, LlamaConfig, BertForPreTraining
colossalai.launch({})
plugin = GeminiPlugin()
booster = Booster(plugin)
# 1. Initialize model from scratch
# Initialization on cuda will accelerate the initialization process but take more GPU memory.
with LazyInitContext(default_device="cuda"):
model = LlamaForCausalLM(LlamaConfig(hidden_size=64, intermediate_size=172, num_hidden_layers=4, num_attention_heads=4))
model, *_ = booster.boost(model)
# 2. Initialize model from pretrained
with LazyInitContext():
model = BertForPreTraining.from_pretrained("prajjwal1/bert-tiny")
model, *_ = booster.boost(model)
```
> ⚠️ 使用懒惰初始化加载预训练模型在 colossalai>0.3.3 或主分支上支持。
## 限制
我们提到,懒惰初始化必须与 booster 一起使用。只有几个插件支持它。
| 插件 | 支持情况 | 备注 |
|-----------------|---------|--------|
| Gemini | 是 | |
| Hybrid Parallel | 是 | |
| Low Level Zero | 否 | 不需要 |
| Torch DDP | 否 | 不兼容 |
| Torch FSDP | 否 | 不兼容 |
不是所有的模型都可以懒惰初始化。在某些情况下,一部分参数/缓冲区可能会被提前初始化。但是不用担心,这部分通常只占整个模型的一小部分。
并且一些模型完全不支持,会引发错误。我们测试了 torchvision, diffusers, timm, transformers, torchaudio 和 torchrec 中的模型。以下模型不受支持:
| 模型 | 分类 |
|-------------------------------|--------------|
| wav2vec2_base | torchaudio |
| hubert_base | torchaudio |
| ViTModel | transformers |
| ViTForMaskedImageModeling | transformers |
| ViTForImageClassification | transformers |
| Blip2Model | transformers |
| Blip2ForConditionalGeneration | transformers |
<!-- doc-test-command: torchrun --standalone --nproc_per_node=2 lazy_iniy.py -->