[tutorial] modify hands-on of auto activation checkpoint (#1920)

* [sc] SC tutorial for auto checkpoint

* [sc] polish examples

* [sc] polish readme

* [sc] polish readme and help information

* [sc] polish readme and help information

* [sc] modify auto checkpoint benchmark

* [sc] remove imgs
pull/1922/head
Boyuan Yao 2022-11-12 18:21:03 +08:00 committed by GitHub
parent ff16773ded
commit 24cbee0ebe
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 135 additions and 192 deletions


@ -19,79 +19,38 @@ colossalai run --nproc_per_node 4 auto_parallel_with_resnet.py
## Auto Checkpoint Benchmarking
We prepare three demos for you to test the performance of auto checkpoint, the test `demo_resnet50.py` and `demo_gpt2_medium.py` will show you the ability of solver to search checkpoint strategy that could fit in the given budget.
We prepare two benchmarks for you to test the performance of auto checkpoint.
The first test, `auto_ckpt_solver_test.py`, shows the solver's ability to search for a checkpoint strategy that fits within a given memory budget (tested on GPT2 Medium and ResNet 50). It outputs the benchmark summary and plots of peak memory vs. budget memory and relative step time vs. peak memory.
The second test, `auto_ckpt_batchsize_test.py`, shows the advantage of fitting training with a larger batch size into limited GPU memory with the help of our activation checkpoint solver (tested on ResNet152). It outputs the benchmark summary.
The usage of the above two tests:
```bash
python demo_resnet50.py --help
usage: ResNet50 Auto Activation Benchmark [-h] [--batch_size BATCH_SIZE] [--num_steps NUM_STEPS] [--sample_points SAMPLE_POINTS] [--free_memory FREE_MEMORY]
[--start_factor START_FACTOR]
# run auto_ckpt_solver_test.py on gpt2 medium
python auto_ckpt_solver_test.py --model gpt2
optional arguments:
-h, --help show this help message and exit
--batch_size BATCH_SIZE
batch size for benchmark, default 128
--num_steps NUM_STEPS
number of test steps for benchmark, default 5
--sample_points SAMPLE_POINTS
number of sample points for benchmark from start memory budget to maximum memory budget (free_memory), default 15
--free_memory FREE_MEMORY
maximum memory budget in MB for benchmark, default 11000 MB
--start_factor START_FACTOR
start memory budget factor for benchmark, the start memory budget will be free_memory / start_factor, default 4
# run auto_ckpt_solver_test.py on resnet50
python auto_ckpt_solver_test.py --model resnet50
# run with default settings
python demo_resnet50.py
python demo_gpt2_medium.py --help
usage: GPT2 medium Auto Activation Benchmark [-h] [--batch_size BATCH_SIZE] [--num_steps NUM_STEPS] [--sample_points SAMPLE_POINTS] [--free_memory FREE_MEMORY]
[--start_factor START_FACTOR]
optional arguments:
-h, --help show this help message and exit
--batch_size BATCH_SIZE
batch size for benchmark, default 8
--num_steps NUM_STEPS
number of test steps for benchmark, default 5
--sample_points SAMPLE_POINTS
number of sample points for benchmark from start memory budget to maximum memory budget (free_memory), default 15
--free_memory FREE_MEMORY
maximum memory budget in MB for benchmark, default 56000 MB
--start_factor START_FACTOR
start memory budget factor for benchmark, the start memory budget will be free_memory / start_factor, default 10
# run with default settings
python demo_gpt2_medium.py
# run auto_ckpt_batchsize_test.py
python auto_ckpt_batchsize_test.py
```
Here are some results for your reference:
## Auto Checkpoint Solver Test
### ResNet 50
![](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/tutorial/resnet50_benchmark.png)
### GPT2 Medium
![](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/tutorial/gpt2_benchmark.png)
We also prepare the demo `demo_resnet152.py` to manifest the benefit of auto activation with large batch, the usage is listed as follows
```bash
python demo_resnet152.py --help
usage: ResNet152 Auto Activation Through Put Benchmark [-h] [--num_steps NUM_STEPS]
optional arguments:
-h, --help show this help message and exit
--num_steps NUM_STEPS
number of test steps for benchmark, default 5
# run with default settings
python demo_resnet152.py
```
here are some results on our end for your reference
## Auto Checkpoint Batch Size Test
```bash
===============test summary================
batch_size: 512, peak memory: 73314.392 MB, through put: 254.286 images/s
batch_size: 1024, peak memory: 73316.216 MB, through put: 397.608 images/s
batch_size: 2048, peak memory: 72927.837 MB, through put: 277.429 images/s
```
The above tests will output the test summary and a plot of the benchmarking results.
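
The help text above describes the memory budget sweep as starting at `free_memory / start_factor` and ending at `free_memory`, with `sample_points` samples in between. A minimal sketch of that sweep follows. It assumes evenly spaced budgets, which this diff does not confirm, and `sample_budgets` is a hypothetical helper rather than part of `bench_utils`.

```python
def sample_budgets(free_memory_mb, start_factor, sample_points):
    """Return budgets from free_memory / start_factor up to free_memory.

    Assumption: budgets are evenly spaced; the real bench_rotor may differ.
    """
    if sample_points < 2:
        raise ValueError("sample_points must be at least 2")
    start = free_memory_mb / start_factor
    step = (free_memory_mb - start) / (sample_points - 1)
    return [start + i * step for i in range(sample_points)]


# ResNet50 defaults quoted in the help text: 11000 MB, start factor 4, 15 points
budgets = sample_budgets(11000, 4, 15)
print(f"{len(budgets)} budgets from {budgets[0]:.1f} MB to {budgets[-1]:.1f} MB")
```

With the GPT2 Medium defaults (56000 MB, start factor 10), the same helper would start the sweep at 5600 MB.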


@ -8,7 +8,7 @@ import numpy as np
import torch
import torch.multiprocessing as mp
import torchvision.models as tm
from bench_utils import bench
from bench_utils import bench, data_gen_resnet
import colossalai
from colossalai.auto_parallel.checkpoint import CheckpointSolverRotor
@ -16,19 +16,14 @@ from colossalai.fx import metainfo_trace, symbolic_trace
from colossalai.utils import free_port
def data_gen(batch_size, shape, device='cuda'):
"""
Generate random data for benchmarking
"""
data = torch.empty(batch_size, *shape, device=device)
label = torch.empty(batch_size, dtype=torch.long, device=device).random_(1000)
return (data,), label
def _resnet152_benchmark(rank, world_size, port, num_steps):
"""Resnet152 benchmark
def _benchmark(rank, world_size, port):
"""Auto activation checkpoint batchsize benchmark
This benchmark tests the throughput of ResNet152 with our activation solver given the memory budget of 95% of
maximum GPU memory, and with the batch size of [512, 1024, 2048]
maximum GPU memory, and with batch sizes of [512, 1024, 2048]. With auto activation checkpoint and its
optimality guarantee, we might be able to find a better batch size for the model: a larger batch size lets us
use a larger portion of the GPU FLOPS, while the recomputation scheduled by our solver results in only a
minor performance drop. So in the end we might be able to find a better training batch size for our model
(combined with a large-batch training optimizer such as LAMB).
"""
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
model = tm.resnet152()
@ -42,33 +37,23 @@ def _resnet152_benchmark(rank, world_size, port, num_steps):
gm.graph = solver.solve()
peak_mem, step_time = bench(gm,
torch.nn.CrossEntropyLoss(),
partial(data_gen, batch_size=batch_size, shape=(3, 224, 224)),
num_steps=num_steps)
partial(data_gen_resnet, batch_size=batch_size, shape=(3, 224, 224)),
num_steps=5)
peak_mems.append(peak_mem)
through_puts.append(batch_size / step_time * 1.0e3)
gm.graph = deepcopy(raw_graph)
# print results
print("===============test summary================")
print("===============benchmark summary================")
for batch_size, peak_mem, through_put in zip(batch_sizes, peak_mems, through_puts):
print(f'batch_size: {int(batch_size)}, peak memory: {peak_mem:.3f} MB, through put: {through_put:.3f} images/s')
plt.plot(batch_sizes, through_puts)
plt.xlabel("batch size")
plt.ylabel("through put (images/s)")
plt.title("Resnet152 benchmark")
plt.savefig("resnet152_benchmark.png")
def resnet152_benchmark(num_steps):
def auto_activation_checkpoint_batchsize_benchmark():
world_size = 1
run_func_module = partial(_resnet152_benchmark, world_size=world_size, port=free_port(), num_steps=num_steps)
run_func_module = partial(_benchmark, world_size=world_size, port=free_port())
mp.spawn(run_func_module, nprocs=world_size)
if __name__ == "__main__":
parser = ArgumentParser("ResNet152 Auto Activation Through Put Benchmark")
parser.add_argument("--num_steps", type=int, default=5, help="number of test steps for benchmark, default 5")
args = parser.parse_args()
resnet152_benchmark(args.num_steps)
auto_activation_checkpoint_batchsize_benchmark()
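
The throughput figure in the summary comes from `batch_size / step_time * 1.0e3`. That only yields images per second if `bench` reports the step time in milliseconds, which the `images/s` label suggests but this diff does not show. A small sketch of the conversion under that assumption, using the hypothetical helper name `throughput_per_second`:

```python
def throughput_per_second(batch_size, step_time_ms):
    """Convert a per-step time in milliseconds into samples per second.

    Assumption: step_time_ms is the wall time of one full training step.
    """
    if step_time_ms <= 0:
        raise ValueError("step_time_ms must be positive")
    return batch_size / step_time_ms * 1.0e3


# A step over 1024 images that takes 2048 ms
rate = throughput_per_second(1024, 2048.0)
print(f"through put: {rate:.3f} images/s")
```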


@ -0,0 +1,89 @@
import time
from argparse import ArgumentParser
from functools import partial
import matplotlib.pyplot as plt
import torch
import torch.multiprocessing as mp
import torchvision.models as tm
from bench_utils import GPTLMLoss, bench_rotor, data_gen_gpt2, data_gen_resnet, gpt2_medium
import colossalai
from colossalai.auto_parallel.checkpoint import CheckpointSolverRotor
from colossalai.fx import metainfo_trace, symbolic_trace
from colossalai.utils import free_port
def _benchmark(rank, world_size, port, args):
"""
Auto activation checkpoint solver benchmark, we provide benchmark on two models: gpt2_medium and resnet50.
The benchmark will sample in a range of memory budget for each model and output the benchmark summary and
data visualization of peak memory vs. budget memory and relative step time vs. peak memory.
"""
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
if args.model == 'resnet50':
model = tm.resnet50()
data_gen = partial(data_gen_resnet, batch_size=128, shape=(3, 224, 224))
gm = symbolic_trace(model)
gm = metainfo_trace(gm, torch.empty(128, 3, 224, 224, device='meta'))
loss = torch.nn.CrossEntropyLoss()
else:
model = gpt2_medium()
data_gen = partial(data_gen_gpt2, batch_size=8, seq_len=1024, vocab_size=50257)
data, mask = data_gen(device='meta')[0]
gm = symbolic_trace(model, meta_args={'input_ids': data, 'attention_mask': mask})
gm = metainfo_trace(gm, data, mask)
loss = GPTLMLoss()
free_memory = 11000 * 1024**2 if args.model == 'resnet50' else 56000 * 1024**2
start_factor = 4 if args.model == 'resnet50' else 10
# trace and benchmark
budgets, peak_hist, step_hist = bench_rotor(gm,
loss,
data_gen,
num_steps=5,
sample_points=15,
free_memory=free_memory,
start_factor=start_factor)
# print summary
print("==============benchmark summary==============")
for budget, peak, step in zip(budgets, peak_hist, step_hist):
print(f'memory budget: {budget:.3f} MB, peak memory: {peak:.3f} MB, step time: {step:.3f} MS')
# plot valid results
fig, axs = plt.subplots(1, 2, figsize=(16, 8))
valid_idx = step_hist.index(next(step for step in step_hist if step != float("inf")))
# plot peak memory vs. budget memory
axs[0].plot(budgets[valid_idx:], peak_hist[valid_idx:])
axs[0].plot([budgets[valid_idx], budgets[-1]], [budgets[valid_idx], budgets[-1]], linestyle='--')
axs[0].set_xlabel("Budget Memory (MB)")
axs[0].set_ylabel("Peak Memory (MB)")
axs[0].set_title("Peak Memory vs. Budget Memory")
# plot relative step time vs. peak memory
axs[1].plot(peak_hist[valid_idx:], [step_time / step_hist[-1] for step_time in step_hist[valid_idx:]])
axs[1].plot([peak_hist[valid_idx], peak_hist[-1]], [1.0, 1.0], linestyle='--')
axs[1].set_xlabel("Peak Memory (MB)")
axs[1].set_ylabel("Relative Step Time")
axs[1].set_title("Step Time vs. Peak Memory")
axs[1].set_ylim(0.8, 1.5)
# save plot
fig.savefig(f"{args.model}_benchmark.png")
def auto_activation_checkpoint_benchmark(args):
world_size = 1
run_func_module = partial(_benchmark, world_size=world_size, port=free_port(), args=args)
mp.spawn(run_func_module, nprocs=world_size)
if __name__ == "__main__":
parser = ArgumentParser("Auto Activation Checkpoint Solver Benchmark")
parser.add_argument("--model", type=str, default='gpt2', choices=['gpt2', 'resnet50'])
args = parser.parse_args()
auto_activation_checkpoint_benchmark(args)
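
The plotting code finds the first usable sample with `step_hist.index(next(...))`, which raises `StopIteration` when every entry is `float("inf")`. The sketch below is a defensive variant, assuming that `inf` marks a budget for which no run succeeded (the diff implies this but does not state it). `first_valid_index` is a hypothetical helper, not part of the benchmark script.

```python
import math


def first_valid_index(step_hist):
    """Return the index of the first finite step time, or None if there is none."""
    for idx, step in enumerate(step_hist):
        if math.isfinite(step):
            return idx
    return None


# Hypothetical history: the two smallest budgets were infeasible
history = [float("inf"), float("inf"), 310.5, 295.2]
idx = first_valid_index(history)
if idx is None:
    print("no valid sample, skipping plot")
else:
    print(f"plotting from sample {idx}")
```

Returning `None` lets the caller skip plotting instead of crashing after the summary has already been printed.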


@ -1,108 +0,0 @@
import time
from argparse import ArgumentParser
from functools import partial
import matplotlib.pyplot as plt
import torch
import torch.multiprocessing as mp
import torchvision.models as tm
from bench_utils import GPTLMLoss, bench_rotor, gpt2_medium
import colossalai
from colossalai.auto_parallel.checkpoint import CheckpointSolverRotor
from colossalai.fx import metainfo_trace, symbolic_trace
from colossalai.utils import free_port
def data_gen(batch_size, seq_len, vocab_size, device='cuda:0'):
"""
Generate random data for benchmarking
"""
input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)
attention_mask = torch.ones_like(input_ids, device=device)
return (input_ids, attention_mask), attention_mask
def _gpt2_benchmark(rank, world_size, port, batch_size, num_steps, sample_points, free_memory, start_factor):
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
model = gpt2_medium()
# trace and benchmark
data, mask = data_gen(batch_size, 1024, 50257, device='meta')[0]
gm = symbolic_trace(model, meta_args={'input_ids': data, 'attention_mask': mask})
gm = metainfo_trace(gm, data, mask)
budgets, peak_hist, step_hist = bench_rotor(gm,
GPTLMLoss(),
partial(data_gen, batch_size=batch_size, seq_len=1024,
vocab_size=50257),
num_steps=num_steps,
sample_points=sample_points,
free_memory=free_memory,
start_factor=start_factor)
# print summary
print("==============test summary==============")
for budget, peak, step in zip(budgets, peak_hist, step_hist):
print(f'memory budget: {budget:.3f} MB, peak memory: {peak:.3f} MB, step time: {step:.3f} MS')
# plot valid results
fig, axs = plt.subplots(1, 2, figsize=(16, 8))
valid_idx = step_hist.index(next(step for step in step_hist if step != float("inf")))
# plot peak memory vs. budget memory
axs[0].plot(budgets[valid_idx:], peak_hist[valid_idx:])
axs[0].plot([budgets[valid_idx], budgets[-1]], [budgets[valid_idx], budgets[-1]], linestyle='--')
axs[0].set_xlabel("Budget Memory (MB)")
axs[0].set_ylabel("Peak Memory (MB)")
axs[0].set_title("Peak Memory vs. Budget Memory")
# plot relative step time vs. budget memory
axs[1].plot(peak_hist[valid_idx:], [step_time / step_hist[-1] for step_time in step_hist[valid_idx:]])
axs[1].plot([peak_hist[valid_idx], peak_hist[-1]], [1.0, 1.0], linestyle='--')
axs[1].set_xlabel("Peak Memory (MB)")
axs[1].set_ylabel("Relative Step Time")
axs[1].set_title("Step Time vs. Peak Memory")
axs[1].set_ylim(0.8, 1.5)
# save plot
fig.savefig("gpt2_benchmark.png")
def gpt2_benchmark(batch_size, num_steps, sample_points, free_memory, start_factor):
world_size = 1
run_func_module = partial(_gpt2_benchmark,
world_size=world_size,
port=free_port(),
batch_size=batch_size,
num_steps=num_steps,
sample_points=sample_points,
free_memory=free_memory,
start_factor=start_factor)
mp.spawn(run_func_module, nprocs=world_size)
if __name__ == "__main__":
parser = ArgumentParser("GPT2 medium Auto Activation Benchmark")
parser.add_argument("--batch_size", type=int, default=8, help="batch size for benchmark, default 8")
parser.add_argument("--num_steps", type=int, default=5, help="number of test steps for benchmark, default 5")
parser.add_argument(
"--sample_points",
type=int,
default=15,
help=
"number of sample points for benchmark from start memory budget to maximum memory budget (free_memory), default 15"
)
parser.add_argument("--free_memory",
type=int,
default=56000,
help="maximum memory budget in MB for benchmark, default 56000 MB")
parser.add_argument(
"--start_factor",
type=int,
default=10,
help=
"start memory budget factor for benchmark, the start memory budget will be free_memory / start_factor, default 10"
)
args = parser.parse_args()
gpt2_benchmark(args.batch_size, args.num_steps, args.sample_points, args.free_memory * 1024**2, args.start_factor)


@ -154,3 +154,21 @@ def gpt2_xl(checkpoint=False):
def gpt2_6b(checkpoint=False):
return GPTLMModel(hidden_size=4096, num_layers=30, num_attention_heads=16, checkpoint=checkpoint)
def data_gen_gpt2(batch_size, seq_len, vocab_size, device='cuda:0'):
"""
Generate random data for gpt2 benchmarking
"""
input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)
attention_mask = torch.ones_like(input_ids, device=device)
return (input_ids, attention_mask), attention_mask
def data_gen_resnet(batch_size, shape, device='cuda:0'):
"""
Generate random data for resnet benchmarking
"""
data = torch.empty(batch_size, *shape, device=device)
label = torch.empty(batch_size, dtype=torch.long, device=device).random_(1000)
return (data,), label