mirror of https://github.com/hpcaitech/ColossalAI
[doc] add lazy init docs (#4808)
parent a22706337a
commit da15fdb9ca
@@ -472,30 +472,11 @@ class LazyTensor(torch.Tensor):
class LazyInitContext:
    """Context manager for lazy initialization. Enables initializing the model without allocating real memory.

    Usage:
        1. The model is initialized, but no real memory is allocated.
        >>> ctx = LazyInitContext()
        >>> with ctx:
        >>>     model = MyModel().cuda()

        2. The model is initialized with ``MetaTensor`` as weights, but still no real memory is allocated.
        >>> with ctx.traceable(model):
        >>>     gm = symbolic_trace(model, meta_args=meta_args)
        >>> # Solve the execution strategy and apply the strategy to the model
        >>> strategy = StrategyAndSpec()

        3. The model is initialized with ``torch.Tensor`` as weights, and real memory is allocated. (single device)
        >>> model = ctx.materialize(model)

        4. The model is initialized with sharded ``torch.Tensor`` as weights, and real memory is allocated. (distributed scenario)
        >>> model = apply_strategy_to_all_params(model, strategy)
        >>> model = ctx.distribute(model)

    Warnings:
        This API is still experimental and may be modified further. For example:
        1. Quantization strategies can be applied before allocating real memory.
        2. Lazy initialization seems slower than normal initialization.

    Args:
        tensor_cls (Union[_MyTensor, LazyTensor], optional): This is only for testing. Defaults to LazyTensor.
        default_device (Optional[Union[torch.device, str, int]], optional): Default device for initialization.
            If it's cuda, initialization will be accelerated, but cuda memory will be allocated.
            Defaults to None, which means cpu.
    """

    _replaced: bool = False
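
For reference, a minimal single-device sketch of steps 1 and 3 from the docstring above (the model definition is illustrative):

```python
import torch.nn as nn
from colossalai.lazy import LazyInitContext

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(8, 8)

# Step 1: build the model under the context; no real memory is allocated yet.
ctx = LazyInitContext()
with ctx:
    model = MyModel()

# Step 3: materialize on a single device; weights become real torch.Tensors.
model = ctx.materialize(model)
```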

@@ -55,6 +55,7 @@
    },
    "features/pipeline_parallel",
    "features/nvme_offload",
    "features/lazy_init",
    "features/cluster_utils"
  ]
},
@@ -32,6 +32,8 @@ Plugin is an important component that manages parallel configuration (eg: The ge

More details about the usage of each plugin can be found in the chapter [Booster Plugins](./booster_plugins.md).

Some plugins support lazy initialization, which can save memory when initializing large models. For more details, please see [Lazy Initialization](../features/lazy_init.md).

### API of booster

{{ autodoc:colossalai.booster.Booster }}

@@ -0,0 +1,76 @@
# Lazy initialization

Author: [Hongxiu Liu](https://github.com/ver217)

**Prerequisite:**
- [Train with booster](../basics/booster_api.md)

## Introduction

Lazy initialization defers model initialization. It saves memory when initializing large models.

If your model has `N` billion parameters and your memory (or GPU memory) is `M` GB, we recommend you use lazy initialization when `4N >= M`. Otherwise, it is optional.
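
As a rough worked check of this rule of thumb (a sketch: the factor of 4 assumes fp32 weights at 4 bytes per parameter, and the concrete numbers are illustrative):

```python
# 4 bytes per fp32 parameter => an N-billion-parameter model needs ~4N GB
# for its weights alone, before gradients and optimizer states.
n_billion_params = 7   # illustrative: a 7B-parameter model
memory_gb = 24         # illustrative: 24 GB of GPU memory
weights_gb = 4 * n_billion_params  # ~28 GB of fp32 weights
print(weights_gb >= memory_gb)     # True -> lazy initialization is recommended
```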

## Usage

Lazy initialization must be used with booster.

### API reference

{{ autodoc:colossalai.lazy.LazyInitContext }}

### Example

```python
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.lazy import LazyInitContext
from transformers import BertForPreTraining, LlamaConfig, LlamaForCausalLM

# This script is meant to be launched with torchrun (see the doc-test command
# below), so we initialize the distributed environment from the torch launcher.
colossalai.launch_from_torch({})
plugin = GeminiPlugin()
booster = Booster(plugin=plugin)

# 1. Initialize model from scratch
# Initialization on cuda will accelerate the initialization process but take more GPU memory.
with LazyInitContext(default_device="cuda"):
    model = LlamaForCausalLM(LlamaConfig(hidden_size=64, intermediate_size=172, num_hidden_layers=4, num_attention_heads=4))
model, *_ = booster.boost(model)

# 2. Initialize model from pretrained
with LazyInitContext():
    model = BertForPreTraining.from_pretrained("prajjwal1/bert-tiny")
model, *_ = booster.boost(model)
```

> ⚠️ Lazy initialization from pretrained is supported for colossalai>0.3.3 or the main branch.

## Limitations

As stated above, lazy initialization must be used with booster, and only a few plugins support it (a sketch with the Hybrid Parallel plugin follows the table below).

| Plugin          | Supported | Remarks      |
|-----------------|-----------|--------------|
| Gemini          | Yes       |              |
| Hybrid Parallel | Yes       |              |
| Low Level Zero  | No        | No need      |
| Torch DDP       | No        | Incompatible |
| Torch FSDP      | No        | Incompatible |
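
A minimal sketch with the other supported plugin; it assumes the distributed environment has been launched as in the example above, and the `HybridParallelPlugin` arguments shown are illustrative:

```python
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from colossalai.lazy import LazyInitContext
from transformers import LlamaConfig, LlamaForCausalLM

# Illustrative parallel sizes; pick values that match your world size.
plugin = HybridParallelPlugin(tp_size=2, pp_size=1)
booster = Booster(plugin=plugin)

with LazyInitContext():
    model = LlamaForCausalLM(LlamaConfig(hidden_size=64, num_hidden_layers=4, num_attention_heads=4))
model, *_ = booster.boost(model)
```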

Not all models can be lazily initialized: in some cases, a part of the parameters/buffers may be initialized early. Don't worry, though: this part usually takes a small proportion of the whole model, and you can find out which parameters were initialized eagerly with a check like the one below.
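
A hedged sketch for such a check; it assumes `LazyTensor` is exported from `colossalai.lazy` alongside `LazyInitContext` (adjust the import if the module layout differs):

```python
from colossalai.lazy import LazyInitContext, LazyTensor
from torchvision.models import resnet18

with LazyInitContext():
    model = resnet18()

# Parameters that are not LazyTensors were initialized eagerly.
eager = [name for name, p in model.named_parameters() if not isinstance(p, LazyTensor)]
print(f"{len(eager)} eagerly initialized parameters: {eager}")
```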

Some models are not supported at all and will raise an error. We tested the models in torchvision, diffusers, timm, transformers, torchaudio and torchrec; the models below are not supported:

| Model                         | Category     |
|-------------------------------|--------------|
| wav2vec2_base                 | torchaudio   |
| hubert_base                   | torchaudio   |
| ViTModel                      | transformers |
| ViTForMaskedImageModeling     | transformers |
| ViTForImageClassification     | transformers |
| Blip2Model                    | transformers |
| Blip2ForConditionalGeneration | transformers |

<!-- doc-test-command: torchrun --standalone --nproc_per_node=2 lazy_init.py -->

@@ -35,6 +35,8 @@ The Booster plugin is an important component that manages parallel configuration (eg: the gemini plugin encapsulates

For more details on the usage of each plugin, please refer to the chapter [Booster Plugins](./booster_plugins.md).

Some plugins support lazy initialization, which can save memory when initializing large models. For details, please refer to [Lazy Initialization](../features/lazy_init.md).

### Booster API

<!--TODO: update autodoc -->

@@ -0,0 +1,76 @@
# Lazy initialization

Author: [Hongxiu Liu](https://github.com/ver217)

**Prerequisite:**
- [Train with booster](../basics/booster_api.md)

## Introduction

Lazy initialization defers model initialization. It saves memory when initializing large models.

If your model has `N` billion parameters and your memory (or GPU memory) is `M` GB, we recommend you use lazy initialization when `4N >= M`. Otherwise, it is optional.

## Usage

Lazy initialization must be used with booster.

### API reference

{{ autodoc:colossalai.lazy.LazyInitContext }}

### Example

```python
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.lazy import LazyInitContext
from transformers import BertForPreTraining, LlamaConfig, LlamaForCausalLM

# This script is meant to be launched with torchrun (see the doc-test command
# below), so we initialize the distributed environment from the torch launcher.
colossalai.launch_from_torch({})
plugin = GeminiPlugin()
booster = Booster(plugin=plugin)

# 1. Initialize model from scratch
# Initialization on cuda will accelerate the initialization process but take more GPU memory.
with LazyInitContext(default_device="cuda"):
    model = LlamaForCausalLM(LlamaConfig(hidden_size=64, intermediate_size=172, num_hidden_layers=4, num_attention_heads=4))
model, *_ = booster.boost(model)

# 2. Initialize model from pretrained
with LazyInitContext():
    model = BertForPreTraining.from_pretrained("prajjwal1/bert-tiny")
model, *_ = booster.boost(model)
```

> ⚠️ Loading a pretrained model with lazy initialization is supported for colossalai>0.3.3 or the main branch.

## Limitations

As mentioned above, lazy initialization must be used with booster, and only a few plugins support it.

| Plugin          | Supported | Remarks      |
|-----------------|-----------|--------------|
| Gemini          | Yes       |              |
| Hybrid Parallel | Yes       |              |
| Low Level Zero  | No        | No need      |
| Torch DDP       | No        | Incompatible |
| Torch FSDP      | No        | Incompatible |

Not all models can be lazily initialized: in some cases, a part of the parameters/buffers may be initialized early, but this part usually takes a small proportion of the whole model.

Some models are not supported at all and will raise an error. We tested the models in torchvision, diffusers, timm, transformers, torchaudio and torchrec; the models below are not supported:

| Model                         | Category     |
|-------------------------------|--------------|
| wav2vec2_base                 | torchaudio   |
| hubert_base                   | torchaudio   |
| ViTModel                      | transformers |
| ViTForMaskedImageModeling     | transformers |
| ViTForImageClassification     | transformers |
| Blip2Model                    | transformers |
| Blip2ForConditionalGeneration | transformers |

<!-- doc-test-command: torchrun --standalone --nproc_per_node=2 lazy_init.py -->