[doc] update document of zero with chunk. (#3855)

* [doc] fix title of mixed precision

* [doc]update document of zero with chunk

* [doc] update document of zero with chunk, fix

* [doc] update document of zero with chunk, fix

* [doc] update document of zero with chunk, fix

* [doc] update document of zero with chunk, add doc test

* [doc] update document of zero with chunk, add doc test

* [doc] update document of zero with chunk, fix installation

* [doc] update document of zero with chunk, fix zero with chunk doc

* [doc] update document of zero with chunk, fix zero with chunk doc
pull/3847/head^2
jiangmingyan 2023-05-30 18:41:56 +08:00 committed by GitHub
parent 5f79008c4a
commit 281b33f362
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 43 additions and 77 deletions

View File

@ -3,7 +3,7 @@
Author: [Hongxiu Liu](https://github.com/ver217), [Jiarui Fang](https://github.com/feifeibear), [Zijian Ye](https://github.com/ZijianYY)
**Prerequisite:**
- [Define Your Configuration](../basics/define_your_config.md)
- [Train with booster](../basics/booster_api.md)
**Example Code**
@ -97,6 +97,7 @@ For simplicity, we just use randomly generated data here.
First we only need to import `GPT2LMHeadModel` from `Huggingface transformers` to define our model, which does not require users to define or modify the model, so that users can use it more conveniently.
Define a GPT model:
```python
class GPTLMModel(nn.Module):
@ -182,34 +183,6 @@ def split_param_col_tp1d(param: ColoParameter, pg: ProcessGroup):
split_param_single_dim_tp1d(-1, param, pg)
```
Define a model which uses Gemini + ZeRO DDP:
```python
def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placement_policy: str = "auto"):
cai_version = colossalai.__version__
if version.parse(cai_version) > version.parse("0.1.10"):
from colossalai.nn.parallel import GeminiDDP
model = GeminiDDP(model,
device=get_current_device(),
placement_policy=placement_policy,
pin_memory=True,
search_range_mb=32)
elif version.parse(cai_version) <= version.parse("0.1.10") and version.parse(cai_version) >= version.parse("0.1.9"):
from colossalai.gemini import ChunkManager, GeminiManager
chunk_size = ChunkManager.search_chunk_size(model, 64 * 1024**2, 32)
gemini_manager = GeminiManager(placement_policy, chunk_manager)
chunk_manager = ChunkManager(chunk_size,
pg,
enable_distributed_storage=True,
init_device=GeminiManager.get_default_device(placement_policy))
model = ZeroDDP(model, gemini_manager)
else:
raise NotImplemented(f"CAI version {cai_version} is not supported")
return model
```
As we pre-train GPT in this example, we just use a simple language model loss.
Write a function to get random inputs:
```python
@ -219,9 +192,15 @@ def get_data(batch_size, seq_len, vocab_size):
return input_ids, attention_mask
```
Finally, we can define our training loop:
Finally, we define a model which uses Gemini + ZeRO DDP and define our training loop, As we pre-train GPT in this example, we just use a simple language model loss:
```python
from torch.optim import Adam
from colossalai.booster import Booster
from colossalai.zero import ColoInitContext
from colossalai.booster.plugin import GeminiPlugin
def main():
args = parse_args()
BATCH_SIZE = 8
@ -232,22 +211,23 @@ def main():
# build criterion
criterion = GPTLMLoss()
optimizer = Adam(model.parameters(), lr=0.001)
torch.manual_seed(123)
default_pg = ProcessGroup(tp_degree=args.tp_degree)
default_dist_spec = ShardSpec([-1], [args.tp_degree]) if args.shardinit else None
default_dist_spec = ShardSpec([-1], [args.tp_degree])
# build GPT model
with ColoInitContext(device='cpu', default_dist_spec=default_dist_spec, default_pg=default_pg):
model = gpt2_medium(checkpoint=True)
pg = default_pg
# Tensor Parallelism (TP)
tensor_parallelize(model, pg)
# Gemini + ZeRO DP, Note it must be used after TP
model = gemini_zero_dpp(model, pg, args.placement)
# build optimizer
optimizer = GeminiAdamOptimizer(model, lr=1e-3, initial_scale=2**5)
numel = sum([p.numel() for p in model.parameters()])
get_tflops_func = partial(get_tflops, numel, BATCH_SIZE, SEQ_LEN)
plugin = GeminiPlugin(placement_policy='cuda', max_norm=1.0, initial_scale=2**5)
booster = Booster(plugin=plugin)
model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)
torch.cuda.synchronize()
model.train()
for n in range(NUM_STEPS):
@ -256,10 +236,12 @@ def main():
optimizer.zero_grad()
outputs = model(input_ids, attn_mask)
loss = criterion(outputs, input_ids)
optimizer.backward(loss)
booster.backward(loss, optimizer)
optimizer.step()
torch.cuda.synchronize()
```
> ⚠️ Note: If you want to use the Gemini module, please do not use the [Gradient Accumulation](../features/gradient_accumulation.md) we mentioned before。
The complete example can be found on [Train GPT with Colossal-AI](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/gpt).
<!-- doc-test-command: torchrun --standalone --nproc_per_node=1 zero_with_chunk.py -->

View File

@ -1,4 +1,4 @@
# 梯度累积 (新版本)
# 梯度累积 (新版本)
作者: [Mingyan Jiang](https://github.com/jiangmingyan)

View File

@ -1,4 +1,4 @@
# 自动混合精度训练 (新版本)
# 自动混合精度训练 (新版本)
作者: [Mingyan Jiang](https://github.com/jiangmingyan)

View File

@ -4,7 +4,7 @@
**前置教程:**
- [定义配置文件](../basics/define_your_config.md)
- [booster使用](../basics/booster_api.md)
**示例代码**
@ -97,6 +97,8 @@ optimizer.step()
首先我们只需要引入`Huggingface transformers` 的 `GPT2LMHeadModel`来定义我们的模型,不需要用户进行模型的定义与修改,方便用户使用。
定义GPT模型
```python
class GPTLMModel(nn.Module):
@ -182,34 +184,6 @@ def split_param_col_tp1d(param: ColoParameter, pg: ProcessGroup):
split_param_single_dim_tp1d(-1, param, pg)
```
定义一个使用 Gemini + ZeRO DDP 的模型:
```python
def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placement_policy: str = "auto"):
cai_version = colossalai.__version__
if version.parse(cai_version) > version.parse("0.1.10"):
from colossalai.nn.parallel import GeminiDDP
model = GeminiDDP(model,
device=get_current_device(),
placement_policy=placement_policy,
pin_memory=True,
search_range_mb=32)
elif version.parse(cai_version) <= version.parse("0.1.10") and version.parse(cai_version) >= version.parse("0.1.9"):
from colossalai.gemini import ChunkManager, GeminiManager
chunk_size = ChunkManager.search_chunk_size(model, 64 * 1024**2, 32)
gemini_manager = GeminiManager(placement_policy, chunk_manager)
chunk_manager = ChunkManager(chunk_size,
pg,
enable_distributed_storage=True,
init_device=GeminiManager.get_default_device(placement_policy))
model = ZeroDDP(model, gemini_manager)
else:
raise NotImplemented(f"CAI version {cai_version} is not supported")
return model
```
由于我们在这个例子中对GPT进行预训练因此只使用了一个简单的语言模型损失函数。
写一个获得随机输入的函数:
```python
@ -219,9 +193,16 @@ def get_data(batch_size, seq_len, vocab_size):
return input_ids, attention_mask
```
最后,我们可以定义我们的训练循环:
最后使用booster注入 Gemini + ZeRO DDP 特性, 并定义训练循环。由于我们在这个例子中对GPT进行预训练因此只使用了一个简单的语言模型损失函数
```python
from torch.optim import Adam
from colossalai.booster import Booster
from colossalai.zero import ColoInitContext
from colossalai.booster.plugin import GeminiPlugin
def main():
args = parse_args()
BATCH_SIZE = 8
@ -232,22 +213,23 @@ def main():
# build criterion
criterion = GPTLMLoss()
optimizer = Adam(model.parameters(), lr=0.001)
torch.manual_seed(123)
default_pg = ProcessGroup(tp_degree=args.tp_degree)
default_dist_spec = ShardSpec([-1], [args.tp_degree]) if args.shardinit else None
default_dist_spec = ShardSpec([-1], [args.tp_degree])
# build GPT model
with ColoInitContext(device='cpu', default_dist_spec=default_dist_spec, default_pg=default_pg):
model = gpt2_medium(checkpoint=True)
pg = default_pg
# Tensor Parallelism (TP)
tensor_parallelize(model, pg)
# Gemini + ZeRO DP, Note it must be used after TP
model = gemini_zero_dpp(model, pg, args.placement)
# build optimizer
optimizer = GeminiAdamOptimizer(model, lr=1e-3, initial_scale=2**5)
numel = sum([p.numel() for p in model.parameters()])
get_tflops_func = partial(get_tflops, numel, BATCH_SIZE, SEQ_LEN)
plugin = GeminiPlugin(placement_policy='cuda', max_norm=1.0, initial_scale=2**5)
booster = Booster(plugin=plugin)
model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)
torch.cuda.synchronize()
model.train()
for n in range(NUM_STEPS):
@ -256,10 +238,12 @@ def main():
optimizer.zero_grad()
outputs = model(input_ids, attn_mask)
loss = criterion(outputs, input_ids)
optimizer.backward(loss)
booster.backward(loss, optimizer)
optimizer.step()
torch.cuda.synchronize()
```
> ⚠️ 注意如果你使用Gemini模块的话请不要使用我们之前提到过的[梯度累加](../features/gradient_accumulation.md)。
完整的例子代码可以在 [Train GPT with Colossal-AI](https://github.com/hpcaitech/ColossalAI/tree/main/examples/language/gpt). 获得。
<!-- doc-test-command: torchrun --standalone --nproc_per_node=1 zero_with_chunk.py -->

View File

@ -47,7 +47,7 @@ CUDA_EXT=1 pip install .
pip install .
```
如果您在使用CUDA 10.2您仍然可以从源码安装ColossalA。但是您需要手动下载cub库并将其复制到相应的目录。
如果您在使用CUDA 10.2您仍然可以从源码安装ColossalAI。但是您需要手动下载cub库并将其复制到相应的目录。
```bash
# clone the repository