[SC] add GPT example for auto checkpoint (#1889)

* [sc] SC tutorial for auto checkpoint

* [sc] polish examples

* [sc] polish readme

* [sc] polish readme and help information

* [sc] polish readme and help information
Boyuan Yao 2022-11-11 23:17:25 +08:00 committed by GitHub
parent 11ee8ae478
commit d5c5bc219e
9 changed files with 470 additions and 889 deletions

View File

@ -338,7 +338,7 @@ def metainfo_trace(gm: torch.fx.GraphModule, *args, verbose: bool = False, unit:
Returns:
torch.fx.GraphModule: The ``GraphModule`` annotated with MetaInfo.
"""
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
interp = MetaInfoProp(gm.to(device))
if is_compatible_with_meta():
from colossalai.fx.profiler import MetaTensor

View File

@ -15,3 +15,82 @@ export DATA=/path/to/data
```bash
colossalai run --nproc_per_node 4 auto_parallel_demo.py
```
## Auto Checkpoint Benchmarking
We provide three demos for testing the performance of auto checkpoint. The tests `demo_resnet50.py` and `demo_gpt2_medium.py` demonstrate the solver's ability to search for a checkpoint strategy that fits within a given memory budget.
The usage of these two tests is as follows:
```bash
python demo_resnet50.py --help
usage: ResNet50 Auto Activation Benchmark [-h] [--batch_size BATCH_SIZE] [--num_steps NUM_STEPS] [--sample_points SAMPLE_POINTS] [--free_memory FREE_MEMORY]
[--start_factor START_FACTOR]
optional arguments:
-h, --help show this help message and exit
--batch_size BATCH_SIZE
batch size for benchmark, default 128
--num_steps NUM_STEPS
number of test steps for benchmark, default 5
--sample_points SAMPLE_POINTS
number of sample points for benchmark from start memory budget to maximum memory budget (free_memory), default 15
--free_memory FREE_MEMORY
maximum memory budget in MB for benchmark, default 11000 MB
--start_factor START_FACTOR
start memory budget factor for benchmark, the start memory budget will be free_memory / start_factor, default 4
# run with default settings
python demo_resnet50.py
python demo_gpt2_medium.py --help
usage: GPT2 medium Auto Activation Benchmark [-h] [--batch_size BATCH_SIZE] [--num_steps NUM_STEPS] [--sample_points SAMPLE_POINTS] [--free_memory FREE_MEMORY]
[--start_factor START_FACTOR]
optional arguments:
-h, --help show this help message and exit
--batch_size BATCH_SIZE
batch size for benchmark, default 8
--num_steps NUM_STEPS
number of test steps for benchmark, default 5
--sample_points SAMPLE_POINTS
number of sample points for benchmark from start memory budget to maximum memory budget (free_memory), default 15
--free_memory FREE_MEMORY
maximum memory budget in MB for benchmark, default 56000 MB
--start_factor START_FACTOR
start memory budget factor for benchmark, the start memory budget will be free_memory / start_factor, default 10
# run with default settings
python demo_gpt2_medium.py
```
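Under the hood, both demos follow the same auto checkpoint workflow used by the scripts added in this PR: trace the model into a `torch.fx` graph, annotate it with meta information, ask the rotor solver for an activation checkpoint schedule under a memory budget, and then benchmark the rewritten graph. The sketch below is a minimal single-GPU illustration of those core calls; the 8 GB budget and the batch size of 128 are placeholder values, and the actual demos additionally launch a Colossal-AI context before tracing.
```python
import torch
import torchvision.models as tm

from colossalai.auto_parallel.checkpoint import CheckpointSolverRotor
from colossalai.fx import metainfo_trace, symbolic_trace

# 1. trace the model into a torch.fx GraphModule
model = tm.resnet50()
gm = symbolic_trace(model)

# 2. annotate every node with memory / compute metadata using a meta-tensor sample input
gm = metainfo_trace(gm, torch.empty(128, 3, 224, 224, device='meta'))

# 3. solve for an activation checkpoint schedule under a memory budget given in bytes
budget = 8 * 1024**3    # hypothetical 8 GB budget
solver = CheckpointSolverRotor(gm.graph, free_memory=budget)
gm.graph = solver.solve()

# 4. train or benchmark `gm` as usual; activations are recomputed according to the schedule
```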
Here are some results for your reference:
### ResNet 50
![](./imgs/resnet50_benchmark.png)
### GPT2 Medium
![](./imgs/gpt2_benchmark.png)
We also provide the demo `demo_resnet152.py` to demonstrate the benefit of auto activation checkpointing with large batch sizes; its usage is as follows:
```bash
python demo_resnet152.py --help
usage: ResNet152 Auto Activation Throughput Benchmark [-h] [--num_steps NUM_STEPS]
optional arguments:
-h, --help show this help message and exit
--num_steps NUM_STEPS
number of test steps for benchmark, default 5
# run with default settings
python demo_resnet152.py
```
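For reference, `demo_resnet152.py` does not sweep budgets; it fixes the solver budget at 95% of the currently free GPU memory and sweeps large batch sizes instead. A condensed sketch of its inner loop is shown below (names follow the script added in this PR; the benchmarking itself is elided):
```python
from copy import deepcopy

import torch
import torchvision.models as tm

from colossalai.auto_parallel.checkpoint import CheckpointSolverRotor
from colossalai.fx import metainfo_trace, symbolic_trace

gm = symbolic_trace(tm.resnet152())
raw_graph = deepcopy(gm.graph)
for batch_size in (512, 1024, 2048):
    # re-annotate for the current batch size, then solve with 95% of free GPU memory as the budget
    gm = metainfo_trace(gm, torch.empty(batch_size, 3, 224, 224, device='meta'))
    solver = CheckpointSolverRotor(gm.graph, free_memory=torch.cuda.mem_get_info()[0] * 0.95)
    gm.graph = solver.solve()
    # ... benchmark gm here, recording peak memory and throughput (see demo_resnet152.py) ...
    gm.graph = deepcopy(raw_graph)    # restore the unmodified graph before the next batch size
```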
Here are some results on our end for your reference:
```bash
===============test summary================
batch_size: 512, peak memory: 73314.392 MB, throughput: 254.286 images/s
batch_size: 1024, peak memory: 73316.216 MB, throughput: 397.608 images/s
batch_size: 2048, peak memory: 72927.837 MB, throughput: 277.429 images/s
```
The above tests will output the test summary and a plot of the benchmarking results.

File diff suppressed because one or more lines are too long

View File

@ -1,16 +1,33 @@
import time
from copy import deepcopy
from functools import partial
from typing import Callable, Tuple
import numpy as np
import torch
import torch.nn as nn
import torchvision.models as tm
from transformers import GPT2Config, GPT2LMHeadModel
from colossalai.auto_parallel.checkpoint import CheckpointSolverRotor
from colossalai.fx import metainfo_trace
def bench(gm: torch.fx.GraphModule, criterion: torch.nn.Module, data_gen: Callable, num_steps: int = 5):
def bench(gm: torch.fx.GraphModule,
criterion: torch.nn.Module,
data_gen: Callable,
num_steps: int = 5) -> Tuple[int, int]:
"""Benchmarking a given graph module
Args:
gm (torch.fx.GraphModule): The graph module to benchmark.
criterion (torch.nn.Module): Loss function.
data_gen (Callable): Data generator.
num_steps (int, optional): Number of test steps. Defaults to 5.
Returns:
Tuple[int, int]: peak memory in MB and step time in MS.
"""
gm.train()
gm.cuda()
step_time = float('inf')
@ -39,7 +56,8 @@ def bench(gm: torch.fx.GraphModule, criterion: torch.nn.Module, data_gen: Callab
del args, label, output, loss
gm.to("cpu")
torch.cuda.empty_cache()
return (torch.cuda.max_memory_allocated(device="cuda") - cached) / 1024**2, step_time * 1.0e3
peak_mem = (torch.cuda.max_memory_allocated(device="cuda") - cached) / 1024**2
return peak_mem, step_time * 1.0e3
def bench_rotor(gm: torch.fx.GraphModule,
@ -47,19 +65,92 @@ def bench_rotor(gm: torch.fx.GraphModule,
data_gen: Callable,
num_steps: int = 5,
sample_points: int = 20,
free_memory: int = torch.cuda.mem_get_info()[0]):
free_memory: int = torch.cuda.mem_get_info()[0],
start_factor: int = 4) -> Tuple[np.array, list, list]:
"""Auto Checkpoint Rotor Algorithm benchmarking
Benchmarks the Auto Checkpoint Rotor Algorithm for a given graph module and data.
Args:
gm (torch.fx.GraphModule): The graph module to benchmark.
criterion (torch.nn.Module): Loss function.
data_gen (Callable): Data generator.
num_steps (int, optional): Number of test steps. Defaults to 5.
sample_points (int, optional): Number of sample points. Defaults to 20.
free_memory (int, optional): Max memory budget in Byte. Defaults to torch.cuda.mem_get_info()[0].
start_factor (int, optional): Start memory budget factor for benchmark, the start memory budget
will be free_memory / start_factor. Defaults to 4.
Returns:
Tuple[np.array, list, list]: return budgets vector (MB), peak memory vector (MB), step time vector (MS).
"""
peak_hist, step_hist = [], []
for budget in np.linspace(free_memory // 5, free_memory, sample_points):
raw_graph = deepcopy(gm.graph)
for budget in np.linspace(free_memory // start_factor, free_memory, sample_points):
gm = metainfo_trace(gm, *data_gen()[0])
solver = CheckpointSolverRotor(gm.graph, free_memory=budget)
try:
gm.graph = solver.solve()
peak_memory, step_time = bench(gm,
criterion,
partial(data_gen, batch_size=2048, shape=(3, 224, 224)),
num_steps=num_steps)
gm.graph = solver.solve(verbose=False)
peak_memory, step_time = bench(gm, criterion, data_gen, num_steps=num_steps)
except:
peak_memory, step_time = budget / 1024**2, float('inf')
peak_hist.append(peak_memory)
step_hist.append(step_time)
return peak_hist, step_hist
gm.graph = deepcopy(raw_graph)
return np.linspace(free_memory // start_factor, free_memory, sample_points) / 1024**2, peak_hist, step_hist
class GPTLMModel(nn.Module):
"""
GPT Model
"""
def __init__(self,
hidden_size=768,
num_layers=12,
num_attention_heads=12,
max_seq_len=1024,
vocab_size=50257,
checkpoint=False):
super().__init__()
self.checkpoint = checkpoint
self.model = GPT2LMHeadModel(
GPT2Config(n_embd=hidden_size,
n_layer=num_layers,
n_head=num_attention_heads,
n_positions=max_seq_len,
n_ctx=max_seq_len,
vocab_size=vocab_size))
if checkpoint:
self.model.gradient_checkpointing_enable()
def forward(self, input_ids, attention_mask):
# Only return lm_logits
return self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=not self.checkpoint)[0]
class GPTLMLoss(nn.Module):
"""
GPT Loss
"""
def __init__(self):
super().__init__()
self.loss_fn = nn.CrossEntropyLoss()
def forward(self, logits, labels):
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
def gpt2_medium(checkpoint=False):
return GPTLMModel(hidden_size=1024, num_layers=24, num_attention_heads=16, checkpoint=checkpoint)
def gpt2_xl(checkpoint=False):
return GPTLMModel(hidden_size=1600, num_layers=48, num_attention_heads=32, checkpoint=checkpoint)
def gpt2_6b(checkpoint=False):
return GPTLMModel(hidden_size=4096, num_layers=30, num_attention_heads=16, checkpoint=checkpoint)

View File

@ -0,0 +1,108 @@
import time
from argparse import ArgumentParser
from functools import partial
import matplotlib.pyplot as plt
import torch
import torch.multiprocessing as mp
import torchvision.models as tm
from bench_utils import GPTLMLoss, bench_rotor, gpt2_medium
import colossalai
from colossalai.auto_parallel.checkpoint import CheckpointSolverRotor
from colossalai.fx import metainfo_trace, symbolic_trace
from colossalai.utils import free_port
def data_gen(batch_size, seq_len, vocab_size, device='cuda:0'):
"""
Generate random data for benchmarking
"""
input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)
attention_mask = torch.ones_like(input_ids, device=device)
return (input_ids, attention_mask), attention_mask
def _gpt2_benchmark(rank, world_size, port, batch_size, num_steps, sample_points, free_memory, start_factor):
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
model = gpt2_medium()
# trace and benchmark
data, mask = data_gen(batch_size, 1024, 50257, device='meta')[0]
gm = symbolic_trace(model, meta_args={'input_ids': data, 'attention_mask': mask})
gm = metainfo_trace(gm, data, mask)
budgets, peak_hist, step_hist = bench_rotor(gm,
GPTLMLoss(),
partial(data_gen, batch_size=batch_size, seq_len=1024,
vocab_size=50257),
num_steps=num_steps,
sample_points=sample_points,
free_memory=free_memory,
start_factor=start_factor)
# print summary
print("==============test summary==============")
for budget, peak, step in zip(budgets, peak_hist, step_hist):
print(f'memory budget: {budget:.3f} MB, peak memory: {peak:.3f} MB, step time: {step:.3f} MS')
# plot valid results
fig, axs = plt.subplots(1, 2, figsize=(16, 8))
valid_idx = step_hist.index(next(step for step in step_hist if step != float("inf")))
# plot peak memory vs. budget memory
axs[0].plot(budgets[valid_idx:], peak_hist[valid_idx:])
axs[0].plot([budgets[valid_idx], budgets[-1]], [budgets[valid_idx], budgets[-1]], linestyle='--')
axs[0].set_xlabel("Budget Memory (MB)")
axs[0].set_ylabel("Peak Memory (MB)")
axs[0].set_title("Peak Memory vs. Budget Memory")
# plot relative step time vs. budget memory
axs[1].plot(peak_hist[valid_idx:], [step_time / step_hist[-1] for step_time in step_hist[valid_idx:]])
axs[1].plot([peak_hist[valid_idx], peak_hist[-1]], [1.0, 1.0], linestyle='--')
axs[1].set_xlabel("Peak Memory (MB)")
axs[1].set_ylabel("Relative Step Time")
axs[1].set_title("Step Time vs. Peak Memory")
axs[1].set_ylim(0.8, 1.5)
# save plot
fig.savefig("gpt2_benchmark.png")
def gpt2_benchmark(batch_size, num_steps, sample_points, free_memory, start_factor):
world_size = 1
run_func_module = partial(_gpt2_benchmark,
world_size=world_size,
port=free_port(),
batch_size=batch_size,
num_steps=num_steps,
sample_points=sample_points,
free_memory=free_memory,
start_factor=start_factor)
mp.spawn(run_func_module, nprocs=world_size)
if __name__ == "__main__":
parser = ArgumentParser("GPT2 medium Auto Activation Benchmark")
parser.add_argument("--batch_size", type=int, default=8, help="batch size for benchmark, default 8")
parser.add_argument("--num_steps", type=int, default=5, help="number of test steps for benchmark, default 5")
parser.add_argument(
"--sample_points",
type=int,
default=15,
help=
"number of sample points for benchmark from start memory budget to maximum memory budget (free_memory), default 15"
)
parser.add_argument("--free_memory",
type=int,
default=56000,
help="maximum memory budget in MB for benchmark, default 56000 MB")
parser.add_argument(
"--start_factor",
type=int,
default=10,
help=
"start memory budget factor for benchmark, the start memory budget will be free_memory / start_factor, default 10"
)
args = parser.parse_args()
gpt2_benchmark(args.batch_size, args.num_steps, args.sample_points, args.free_memory * 1024**2, args.start_factor)

View File

@ -0,0 +1,74 @@
import time
from argparse import ArgumentParser
from copy import deepcopy
from functools import partial
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.multiprocessing as mp
import torchvision.models as tm
from bench_utils import bench
import colossalai
from colossalai.auto_parallel.checkpoint import CheckpointSolverRotor
from colossalai.fx import metainfo_trace, symbolic_trace
from colossalai.utils import free_port
def data_gen(batch_size, shape, device='cuda'):
"""
Generate random data for benchmarking
"""
data = torch.empty(batch_size, *shape, device=device)
label = torch.empty(batch_size, dtype=torch.long, device=device).random_(1000)
return (data,), label
def _resnet152_benchmark(rank, world_size, port, num_steps):
"""Resnet152 benchmark
This benchmark tests the throughput of ResNet152 with our activation checkpoint solver, given a memory budget of 95% of
the maximum GPU memory and batch sizes of [512, 1024, 2048].
"""
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
model = tm.resnet152()
gm = symbolic_trace(model)
raw_graph = deepcopy(gm.graph)
peak_mems, throughputs, batch_sizes = [], [], [512, 1024, 2048]
for batch_size in batch_sizes:
batch_size = int(batch_size)
gm = metainfo_trace(gm, torch.empty(batch_size, 3, 224, 224, device='meta'))
solver = CheckpointSolverRotor(gm.graph, free_memory=torch.cuda.mem_get_info()[0] * 0.95)
gm.graph = solver.solve()
peak_mem, step_time = bench(gm,
torch.nn.CrossEntropyLoss(),
partial(data_gen, batch_size=batch_size, shape=(3, 224, 224)),
num_steps=num_steps)
peak_mems.append(peak_mem)
throughputs.append(batch_size / step_time * 1.0e3)
gm.graph = deepcopy(raw_graph)
# print results
print("===============test summary================")
for batch_size, peak_mem, throughput in zip(batch_sizes, peak_mems, throughputs):
print(f'batch_size: {int(batch_size)}, peak memory: {peak_mem:.3f} MB, throughput: {throughput:.3f} images/s')
plt.plot(batch_sizes, throughputs)
plt.xlabel("batch size")
plt.ylabel("through put (images/s)")
plt.title("Resnet152 benchmark")
plt.savefig("resnet152_benchmark.png")
def resnet152_benchmark(num_steps):
world_size = 1
run_func_module = partial(_resnet152_benchmark, world_size=world_size, port=free_port(), num_steps=num_steps)
mp.spawn(run_func_module, nprocs=world_size)
if __name__ == "__main__":
parser = ArgumentParser("ResNet152 Auto Activation Through Put Benchmark")
parser.add_argument("--num_steps", type=int, default=5, help="number of test steps for benchmark, default 5")
args = parser.parse_args()
resnet152_benchmark(args.num_steps)

View File

@ -0,0 +1,107 @@
import time
from argparse import ArgumentParser
from functools import partial
import matplotlib.pyplot as plt
import torch
import torch.multiprocessing as mp
import torchvision.models as tm
from bench_utils import bench_rotor
import colossalai
from colossalai.auto_parallel.checkpoint import CheckpointSolverRotor
from colossalai.fx import metainfo_trace, symbolic_trace
from colossalai.utils import free_port
def data_gen(batch_size, shape, device='cuda'):
"""
Generate random data for benchmarking
"""
data = torch.empty(batch_size, *shape, device=device)
label = torch.empty(batch_size, dtype=torch.long, device=device).random_(1000)
return (data,), label
def _resnet50_benchmark(rank, world_size, port, batch_size, num_steps, sample_points, free_memory, start_factor):
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
model = tm.resnet50()
# trace and benchmark
gm = symbolic_trace(model)
gm = metainfo_trace(gm, torch.empty(batch_size, 3, 224, 224, device='meta'))
budgets, peak_hist, step_hist = bench_rotor(gm,
torch.nn.CrossEntropyLoss(),
partial(data_gen, batch_size=batch_size, shape=(3, 224, 224)),
num_steps=num_steps,
sample_points=sample_points,
free_memory=free_memory,
start_factor=start_factor)
# print summary
print("==============test summary==============")
for budget, peak, step in zip(budgets, peak_hist, step_hist):
print(f'memory budget: {budget:.3f} MB, peak memory: {peak:.3f} MB, step time: {step:.3f} MS')
# plot valid results
fig, axs = plt.subplots(1, 2, figsize=(16, 8))
valid_idx = step_hist.index(next(step for step in step_hist if step != float("inf")))
# plot peak memory vs. budget memory
axs[0].plot(budgets[valid_idx:], peak_hist[valid_idx:])
axs[0].plot([budgets[valid_idx], budgets[-1]], [budgets[valid_idx], budgets[-1]], linestyle='--')
axs[0].set_xlabel("Budget Memory (MB)")
axs[0].set_ylabel("Peak Memory (MB)")
axs[0].set_title("Peak Memory vs. Budget Memory")
# plot relative step time vs. budget memory
axs[1].plot(peak_hist[valid_idx:], [step_time / step_hist[-1] for step_time in step_hist[valid_idx:]])
axs[1].plot([peak_hist[valid_idx], peak_hist[-1]], [1.0, 1.0], linestyle='--')
axs[1].set_xlabel("Peak Memory (MB)")
axs[1].set_ylabel("Relative Step Time")
axs[1].set_title("Step Time vs. Peak Memory")
axs[1].set_ylim(0.8, 1.5)
# save plot
fig.savefig("resnet50_benchmark.png")
def resnet50_benchmark(batch_size, num_steps, sample_points, free_memory, start_factor):
world_size = 1
run_func_module = partial(_resnet50_benchmark,
world_size=world_size,
port=free_port(),
batch_size=batch_size,
num_steps=num_steps,
sample_points=sample_points,
free_memory=free_memory,
start_factor=start_factor)
mp.spawn(run_func_module, nprocs=world_size)
if __name__ == "__main__":
parser = ArgumentParser("ResNet50 Auto Activation Benchmark")
parser.add_argument("--batch_size", type=int, default=128, help="batch size for benchmark, default 128")
parser.add_argument("--num_steps", type=int, default=5, help="number of test steps for benchmark, default 5")
parser.add_argument(
"--sample_points",
type=int,
default=15,
help=
"number of sample points for benchmark from start memory budget to maximum memory budget (free_memory), default 15"
)
parser.add_argument("--free_memory",
type=int,
default=11000,
help="maximum memory budget in MB for benchmark, default 11000 MB")
parser.add_argument(
"--start_factor",
type=int,
default=4,
help=
"start memory budget factor for benchmark, the start memory budget will be free_memory / start_factor, default 4"
)
args = parser.parse_args()
resnet50_benchmark(args.batch_size, args.num_steps, args.sample_points, args.free_memory * 1024**2,
args.start_factor)

Binary file not shown (new image, 65 KiB)

Binary file not shown (new image, 71 KiB)