
[doc] update nvme offload doc (#3014)

* [doc] update nvme offload doc

* [doc] add doc testing cmd and requirements

* [doc] add api reference

* [doc] add dependencies
ver217 committed 378d827c6b
  1. docs/requirements-doc-test.txt (4 lines changed)
  2. docs/source/en/features/nvme_offload.md (214 lines changed)
  3. docs/source/zh-Hans/features/nvme_offload.md (202 lines changed)

docs/requirements-doc-test.txt

@@ -1,2 +1,6 @@
colossalai
torch
packaging
tensornvme
psutil
transformers

docs/source/en/features/nvme_offload.md

@@ -1,3 +1,4 @@
<!-- doc-test-command: torchrun --standalone --nproc_per_node=1 nvme_offload.py -->
# NVMe offload
Author: Hongxin Liu
@@ -36,12 +37,225 @@ pip install tensornvme
We implement NVMe offload of optimizer states for Adam ([CPUAdam](https://colossalai.readthedocs.io/en/latest/colossalai/colossalai.nn.optimizer.cpu_adam.html) and [HybridAdam](https://colossalai.readthedocs.io/en/latest/colossalai/colossalai.nn.optimizer.hybrid_adam.html)).
<!--- doc-test-ignore-start -->
```python
from colossalai.nn.optimizer import CPUAdam, HybridAdam
optimizer = HybridAdam(model.parameters(), lr=1e-3, nvme_offload_fraction=1.0, nvme_offload_dir='./')
```
<!--- doc-test-ignore-end -->
`nvme_offload_fraction` is the fraction of optimizer states to be offloaded to NVMe. `nvme_offload_dir` is the directory to save NVMe offload files. If `nvme_offload_dir` is `None`, a random temporary directory will be used.
It's compatible with all parallel methods in ColossalAI.
> ⚠ It only offloads optimizer states on CPU. This means it only affects CPU training or Zero/Gemini with offloading.
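For instance, you can offload only part of the optimizer states and keep the rest in CPU memory. Below is a minimal sketch (assuming a `model` is already built; the fraction `0.5` and the `./offload` directory are illustrative choices, not defaults):
<!--- doc-test-ignore-start -->
```python
from colossalai.nn.optimizer import HybridAdam

# Offload half of the optimizer states to NVMe files under ./offload;
# the remaining half stays in CPU memory.
optimizer = HybridAdam(model.parameters(), lr=1e-3,
                       nvme_offload_fraction=0.5,
                       nvme_offload_dir='./offload')
```
<!--- doc-test-ignore-end -->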
## Examples
Let's start with two simple examples: training GPT with different methods. These examples rely on `transformers`.
We should install the dependencies first:
```shell
pip install psutil transformers
```
First, we import essential packages and modules:
```python
import os
import time
from typing import Dict, Optional
import psutil
import torch
import torch.nn as nn
from transformers.models.gpt2.configuration_gpt2 import GPT2Config
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel
import colossalai
from colossalai.nn.optimizer import HybridAdam
from colossalai.nn.parallel import zero_model_wrapper, zero_optim_wrapper
from colossalai.utils.model.colo_init_context import ColoInitContext
```
Then we define a loss function:
```python
class GPTLMLoss(nn.Module):

    def __init__(self):
        super().__init__()
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, logits, labels):
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        # Flatten the tokens
        return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)),
                            shift_labels.view(-1))
```
And we define some utility functions, which generate random data, compute the number of parameters of a model, and get the memory usage of the current process:
```python
def get_data(batch_size: int, seq_len: int,
             vocab_size: int, device: Optional[str] = None) -> Dict[str, torch.Tensor]:
    device = torch.cuda.current_device() if device is None else device
    input_ids = torch.randint(vocab_size, (batch_size, seq_len),
                              device=device)
    attn_mask = torch.ones_like(input_ids)
    return dict(input_ids=input_ids, attention_mask=attn_mask)


def get_model_numel(model: nn.Module) -> int:
    return sum(p.numel() for p in model.parameters())


def get_mem_usage() -> int:
    proc = psutil.Process(os.getpid())
    return proc.memory_info().rss
```
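`get_mem_usage()` reports the resident set size (RSS) of the current process, which is what the `Mem usage` lines in the outputs below measure.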
We first try to train the GPT model on CPU:
```python
def train_cpu(nvme_offload_fraction: float = 0.0):
    config = GPT2Config()
    model = GPT2LMHeadModel(config)
    criterion = GPTLMLoss()
    optimizer = HybridAdam(model.parameters(), nvme_offload_fraction=nvme_offload_fraction)
    print(f'Model numel: {get_model_numel(model) / 1024**3:.3f} B')

    start = time.time()
    for step in range(3):
        data = get_data(4, 128, config.vocab_size, device='cpu')
        outputs = model(**data)
        loss = criterion(outputs.logits, data['input_ids'])
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        print(f'[{step}] loss: {loss.item():.3f}')

    print(f'Time: {time.time() - start:.3f} s')
    print(f'Mem usage: {get_mem_usage() / 1024**2:.3f} MB')
```
Run without NVMe offload:
```python
train_cpu(0.0)
```
We may get the following output:
```
Model numel: 0.116 B
[0] loss: 10.953
[1] loss: 10.974
[2] loss: 10.965
Time: 7.739 s
Mem usage: 5966.445 MB
```
And then run with (full) NVMe offload:
```python
train_cpu(1.0)
```
We may get:
```
Model numel: 0.116 B
[0] loss: 10.951
[1] loss: 10.994
[2] loss: 10.984
Time: 8.527 s
Mem usage: 4968.016 MB
```
For GPT2-S, which has 0.116 billion parameters, its optimizer states take about 0.928 GB of memory, and NVMe offload saves about 998 MB of memory, which meets our expectations.
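As a quick sanity check of that figure, recall that Adam keeps two fp32 states (momentum and variance) per parameter; the sketch below just redoes this arithmetic:
<!--- doc-test-ignore-start -->
```python
# 'Model numel: 0.116 B' above is numel / 1024**3, so recover numel first.
numel = 0.116 * 1024**3
# Two fp32 states (momentum and variance), 4 bytes each, per parameter.
state_bytes = numel * 2 * 4
print(f'{state_bytes / 1024**3:.3f} GB')  # ~0.928 GB
```
<!--- doc-test-ignore-end -->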
Then we can train the GPT model with Gemini. The placement policy of Gemini should be `"auto"`, `"cpu"`, or `"const"`.
```python
def train_gemini_cpu(nvme_offload_fraction: float = 0.0):
    colossalai.launch_from_torch({})
    config = GPT2Config()
    with ColoInitContext(device=torch.cuda.current_device()):
        model = GPT2LMHeadModel(config)
    criterion = GPTLMLoss()
    optimizer = HybridAdam(model.parameters(), nvme_offload_fraction=nvme_offload_fraction)
    print(f'Model numel: {get_model_numel(model) / 1024**3:.3f} B')

    gemini_config = dict(strict_ddp_mode=True, device=torch.cuda.current_device(),
                         placement_policy='cpu', pin_memory=True, hidden_dim=config.n_embd)
    model = zero_model_wrapper(model, zero_stage=3, gemini_config=gemini_config)
    optimizer = zero_optim_wrapper(model, optimizer, initial_scale=2**5)

    start = time.time()
    for step in range(3):
        data = get_data(4, 128, config.vocab_size)
        outputs = model(**data)
        loss = criterion(outputs.logits, data['input_ids'])
        optimizer.backward(loss)
        optimizer.step()
        optimizer.zero_grad()
        print(f'[{step}] loss: {loss.item():.3f}')

    print(f'Time: {time.time() - start:.3f} s')
    print(f'Mem usage: {get_mem_usage() / 1024**2:.3f} MB')
```
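Note that `colossalai.launch_from_torch({})` initializes the distributed environment, so this script must be launched with `torchrun` (e.g. `torchrun --standalone --nproc_per_node=1 nvme_offload.py`, the doc-test command at the top of this page).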
Run without NVMe offload:
```python
train_gemini_cpu(0.0)
```
We may get:
```
Model numel: 0.116 B
searching chunk configuration is completed in 0.27 s.
used number: 118.68 MB, wasted number: 0.75 MB
total wasted percentage is 0.63%
[0] loss: 10.953
[1] loss: 10.938
[2] loss: 10.969
Time: 2.997 s
Mem usage: 5592.227 MB
```
And run with (full) NVMe offload:
```python
train_gemini_cpu(1.0)
```
We may get:
```
Model numel: 0.116 B
searching chunk configuration is completed in 0.27 s.
used number: 118.68 MB, wasted number: 0.75 MB
total wasted percentage is 0.63%
[0] loss: 10.953
[1] loss: 10.938
[2] loss: 10.969
Time: 3.691 s
Mem usage: 5298.344 MB
```
NVMe offload saves about 294 MB of memory. Note that enabling `pin_memory` of Gemini can accelerate training but increases memory usage, so this result also meets our expectation. If we disable `pin_memory`, we can still observe a memory usage drop of about 900 MB.
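Below is a minimal sketch of the same Gemini configuration with pinned memory disabled; only `pin_memory` changes, the other fields are as in `train_gemini_cpu` above:
<!--- doc-test-ignore-start -->
```python
# Trade some training speed for lower host-memory usage.
gemini_config = dict(strict_ddp_mode=True, device=torch.cuda.current_device(),
                     placement_policy='cpu', pin_memory=False, hidden_dim=config.n_embd)
```
<!--- doc-test-ignore-end -->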
## API Reference
{{ autodoc:colossalai.nn.optimizer.HybridAdam }}
{{ autodoc:colossalai.nn.optimizer.CPUAdam }}

docs/source/zh-Hans/features/nvme_offload.md

@@ -1,3 +1,4 @@
<!-- doc-test-command: torchrun --standalone --nproc_per_node=1 nvme_offload.py -->
# NVMe offload
Author: Hongxin Liu
@@ -36,12 +37,213 @@ pip install tensornvme
We implement NVMe offload of optimizer states for Adam ([CPUAdam](https://colossalai.readthedocs.io/en/latest/colossalai/colossalai.nn.optimizer.cpu_adam.html) and [HybridAdam](https://colossalai.readthedocs.io/en/latest/colossalai/colossalai.nn.optimizer.hybrid_adam.html)).
<!--- doc-test-ignore-start -->
```python
from colossalai.nn.optimizer import CPUAdam, HybridAdam
optimizer = HybridAdam(model.parameters(), lr=1e-3, nvme_offload_fraction=1.0, nvme_offload_dir='./')
```
<!--- doc-test-ignore-end -->
`nvme_offload_fraction` is the fraction of optimizer states to be offloaded to NVMe. `nvme_offload_dir` is the directory to save NVMe offload files. If `nvme_offload_dir` is `None`, a random temporary directory will be used.
It's compatible with all parallel methods in ColossalAI.
> ⚠ It only offloads optimizer states on CPU. This means it only affects CPU training or Zero/Gemini with offloading.
## Examples
Let's start with two simple examples: training GPT with different methods. These examples rely on `transformers`.
We should install the dependencies first:
```shell
pip install psutil transformers
```
First, we import essential packages and modules:
```python
import os
import time
from typing import Dict, Optional
import psutil
import torch
import torch.nn as nn
from transformers.models.gpt2.configuration_gpt2 import GPT2Config
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel
import colossalai
from colossalai.nn.optimizer import HybridAdam
from colossalai.nn.parallel import zero_model_wrapper, zero_optim_wrapper
from colossalai.utils.model.colo_init_context import ColoInitContext
```
Then we define a loss function:
```python
class GPTLMLoss(nn.Module):

    def __init__(self):
        super().__init__()
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, logits, labels):
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        # Flatten the tokens
        return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)),
                            shift_labels.view(-1))
```
And we define some utility functions, which generate random data, compute the number of parameters of a model, and get the memory usage of the current process:
```python
def get_data(batch_size: int, seq_len: int,
             vocab_size: int, device: Optional[str] = None) -> Dict[str, torch.Tensor]:
    device = torch.cuda.current_device() if device is None else device
    input_ids = torch.randint(vocab_size, (batch_size, seq_len),
                              device=device)
    attn_mask = torch.ones_like(input_ids)
    return dict(input_ids=input_ids, attention_mask=attn_mask)


def get_model_numel(model: nn.Module) -> int:
    return sum(p.numel() for p in model.parameters())


def get_mem_usage() -> int:
    proc = psutil.Process(os.getpid())
    return proc.memory_info().rss
```
We first try to train the GPT model on CPU:
```python
def train_cpu(nvme_offload_fraction: float = 0.0):
    config = GPT2Config()
    model = GPT2LMHeadModel(config)
    criterion = GPTLMLoss()
    optimizer = HybridAdam(model.parameters(), nvme_offload_fraction=nvme_offload_fraction)
    print(f'Model numel: {get_model_numel(model) / 1024**3:.3f} B')

    start = time.time()
    for step in range(3):
        data = get_data(4, 128, config.vocab_size, device='cpu')
        outputs = model(**data)
        loss = criterion(outputs.logits, data['input_ids'])
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        print(f'[{step}] loss: {loss.item():.3f}')

    print(f'Time: {time.time() - start:.3f} s')
    print(f'Mem usage: {get_mem_usage() / 1024**2:.3f} MB')
```
Run without NVMe offload:
```python
train_cpu(0.0)
```
We may get the following output:
```
Model numel: 0.116 B
[0] loss: 10.953
[1] loss: 10.974
[2] loss: 10.965
Time: 7.739 s
Mem usage: 5966.445 MB
```
And then run with (full) NVMe offload:
```python
train_cpu(1.0)
```
We may get:
```
Model numel: 0.116 B
[0] loss: 10.951
[1] loss: 10.994
[2] loss: 10.984
Time: 8.527 s
Mem usage: 4968.016 MB
```
For GPT2-S, which has 0.116 billion parameters, its optimizer states take about 0.928 GB of memory, and NVMe offload saves about 998 MB of memory, which meets our expectations.
Then we can train the GPT model with Gemini. The placement policy of Gemini should be `"auto"`, `"cpu"`, or `"const"`.
```python
def train_gemini_cpu(nvme_offload_fraction: float = 0.0):
    colossalai.launch_from_torch({})
    config = GPT2Config()
    with ColoInitContext(device=torch.cuda.current_device()):
        model = GPT2LMHeadModel(config)
    criterion = GPTLMLoss()
    optimizer = HybridAdam(model.parameters(), nvme_offload_fraction=nvme_offload_fraction)
    print(f'Model numel: {get_model_numel(model) / 1024**3:.3f} B')

    gemini_config = dict(strict_ddp_mode=True, device=torch.cuda.current_device(),
                         placement_policy='cpu', pin_memory=True, hidden_dim=config.n_embd)
    model = zero_model_wrapper(model, zero_stage=3, gemini_config=gemini_config)
    optimizer = zero_optim_wrapper(model, optimizer, initial_scale=2**5)

    start = time.time()
    for step in range(3):
        data = get_data(4, 128, config.vocab_size)
        outputs = model(**data)
        loss = criterion(outputs.logits, data['input_ids'])
        optimizer.backward(loss)
        optimizer.step()
        optimizer.zero_grad()
        print(f'[{step}] loss: {loss.item():.3f}')

    print(f'Time: {time.time() - start:.3f} s')
    print(f'Mem usage: {get_mem_usage() / 1024**2:.3f} MB')
```
Run without NVMe offload:
```python
train_gemini_cpu(0.0)
```
We may get:
```
Model numel: 0.116 B
searching chunk configuration is completed in 0.27 s.
used number: 118.68 MB, wasted number: 0.75 MB
total wasted percentage is 0.63%
[0] loss: 10.953
[1] loss: 10.938
[2] loss: 10.969
Time: 2.997 s
Mem usage: 5592.227 MB
```
And run with (full) NVMe offload:
```python
train_gemini_cpu(1.0)
```
We may get:
```
Model numel: 0.116 B
searching chunk configuration is completed in 0.27 s.
used number: 118.68 MB, wasted number: 0.75 MB
total wasted percentage is 0.63%
[0] loss: 10.953
[1] loss: 10.938
[2] loss: 10.969
Time: 3.691 s
Mem usage: 5298.344 MB
```
NVMe offload saves about 294 MB of memory. Note that enabling `pin_memory` of Gemini can accelerate training but increases memory usage, so this result also meets our expectation. If we disable `pin_memory`, we can still observe a memory usage drop of about 900 MB.
## API Reference
{{ autodoc:colossalai.nn.optimizer.HybridAdam }}
{{ autodoc:colossalai.nn.optimizer.CPUAdam }}
