mirror of https://github.com/hpcaitech/ColossalAI
[SC] add GPT example for auto checkpoint (#1889)
* [sc] SC tutorial for auto checkpoint
* [sc] polish examples
* [sc] polish readme
* [sc] polish readme and help information
* [sc] polish readme and help information

Branch: pull/1910/head
parent 11ee8ae478
commit d5c5bc219e
@@ -338,7 +338,7 @@ def metainfo_trace(gm: torch.fx.GraphModule, *args, verbose: bool = False, unit:
     Returns:
         torch.fx.GraphModule: The ``GraphModule`` annotated with MetaInfo.
     """
-    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+    device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
     interp = MetaInfoProp(gm.to(device))
     if is_compatible_with_meta():
         from colossalai.fx.profiler import MetaTensor
README (the first three lines are unchanged context; the rest of the section is new):

@@ -15,3 +15,82 @@ export DATA=/path/to/data
```bash
colossalai run --nproc_per_node 4 auto_parallel_demo.py
```

## Auto Checkpoint Benchmarking

We provide three demos for testing the performance of auto checkpointing. `demo_resnet50.py` and `demo_gpt2_medium.py` demonstrate the solver's ability to search for a checkpointing strategy that fits within a given memory budget.

Usage for these two tests:

```bash
python demo_resnet50.py --help
usage: ResNet50 Auto Activation Benchmark [-h] [--batch_size BATCH_SIZE] [--num_steps NUM_STEPS] [--sample_points SAMPLE_POINTS] [--free_memory FREE_MEMORY]
                                          [--start_factor START_FACTOR]

optional arguments:
  -h, --help            show this help message and exit
  --batch_size BATCH_SIZE
                        batch size for benchmark, default 128
  --num_steps NUM_STEPS
                        number of test steps for benchmark, default 5
  --sample_points SAMPLE_POINTS
                        number of sample points for benchmark from start memory budget to maximum memory budget (free_memory), default 15
  --free_memory FREE_MEMORY
                        maximum memory budget in MB for benchmark, default 11000 MB
  --start_factor START_FACTOR
                        start memory budget factor for benchmark, the start memory budget will be free_memory / start_factor, default 4

# run with default settings
python demo_resnet50.py

python demo_gpt2_medium.py --help
usage: GPT2 medium Auto Activation Benchmark [-h] [--batch_size BATCH_SIZE] [--num_steps NUM_STEPS] [--sample_points SAMPLE_POINTS] [--free_memory FREE_MEMORY]
                                             [--start_factor START_FACTOR]

optional arguments:
  -h, --help            show this help message and exit
  --batch_size BATCH_SIZE
                        batch size for benchmark, default 8
  --num_steps NUM_STEPS
                        number of test steps for benchmark, default 5
  --sample_points SAMPLE_POINTS
                        number of sample points for benchmark from start memory budget to maximum memory budget (free_memory), default 15
  --free_memory FREE_MEMORY
                        maximum memory budget in MB for benchmark, default 56000 MB
  --start_factor START_FACTOR
                        start memory budget factor for benchmark, the start memory budget will be free_memory / start_factor, default 10

# run with default settings
python demo_gpt2_medium.py
```

Here are some results for your reference:

### ResNet 50
![](./imgs/resnet50_benchmark.png)

### GPT2 Medium
![](./imgs/gpt2_benchmark.png)

We also provide `demo_resnet152.py` to demonstrate the benefit of auto activation checkpointing at large batch sizes. Usage:

```bash
python demo_resnet152.py --help
usage: ResNet152 Auto Activation Throughput Benchmark [-h] [--num_steps NUM_STEPS]

optional arguments:
  -h, --help            show this help message and exit
  --num_steps NUM_STEPS
                        number of test steps for benchmark, default 5

# run with default settings
python demo_resnet152.py
```

Here are some results on our end for your reference:

```bash
===============test summary================
batch_size: 512, peak memory: 73314.392 MB, throughput: 254.286 images/s
batch_size: 1024, peak memory: 73316.216 MB, throughput: 397.608 images/s
batch_size: 2048, peak memory: 72927.837 MB, throughput: 277.429 images/s
```

The above tests will print a test summary and save a plot of the benchmarking results.
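All three demos drive the same pipeline: symbolically trace the model, annotate the traced graph with per-node meta information, then ask the rotor solver for an activation checkpointing schedule under a byte budget. Here is a minimal sketch distilled from the demo scripts in this commit (the 2 GiB budget is an arbitrary example value; the demos run this inside a process initialized by `colossalai.launch`):

```python
import torch
import torchvision.models as tm

from colossalai.auto_parallel.checkpoint import CheckpointSolverRotor
from colossalai.fx import metainfo_trace, symbolic_trace

model = tm.resnet50()

# 1. trace the model into a torch.fx.GraphModule; 'meta' tensors mean
#    no real GPU memory is touched during tracing
gm = symbolic_trace(model)
gm = metainfo_trace(gm, torch.empty(128, 3, 224, 224, device='meta'))

# 2. solve for a checkpointing schedule that fits the budget (in bytes)
solver = CheckpointSolverRotor(gm.graph, free_memory=2 * 1024**3)
gm.graph = solver.solve()

# 3. gm now recomputes activations according to the schedule and can be
#    trained or benchmarked like an ordinary nn.Module
```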
File diff suppressed because one or more lines are too long
bench_utils.py:

@@ -1,16 +1,33 @@
 import time
+from copy import deepcopy
 from functools import partial
 from typing import Callable, Tuple

 import numpy as np
 import torch
+import torch.nn as nn
 import torchvision.models as tm
+from transformers import GPT2Config, GPT2LMHeadModel

 from colossalai.auto_parallel.checkpoint import CheckpointSolverRotor
 from colossalai.fx import metainfo_trace


-def bench(gm: torch.fx.GraphModule, criterion: torch.nn.Module, data_gen: Callable, num_steps: int = 5):
+def bench(gm: torch.fx.GraphModule,
+          criterion: torch.nn.Module,
+          data_gen: Callable,
+          num_steps: int = 5) -> Tuple[int, int]:
+    """Benchmark a given graph module.
+
+    Args:
+        gm (torch.fx.GraphModule): The graph module to benchmark.
+        criterion (torch.nn.Module): Loss function.
+        data_gen (Callable): Data generator.
+        num_steps (int, optional): Number of test steps. Defaults to 5.
+
+    Returns:
+        Tuple[int, int]: Peak memory in MB and step time in ms.
+    """
     gm.train()
     gm.cuda()
     step_time = float('inf')
@@ -39,7 +56,8 @@ def bench(gm: torch.fx.GraphModule, criterion: torch.nn.Module, data_gen: Callab
         del args, label, output, loss
     gm.to("cpu")
     torch.cuda.empty_cache()
-    return (torch.cuda.max_memory_allocated(device="cuda") - cached) / 1024**2, step_time * 1.0e3
+    peak_mem = (torch.cuda.max_memory_allocated(device="cuda") - cached) / 1024**2
+    return peak_mem, step_time * 1.0e3


 def bench_rotor(gm: torch.fx.GraphModule,
@@ -47,19 +65,92 @@ def bench_rotor(gm: torch.fx.GraphModule,
                 data_gen: Callable,
                 num_steps: int = 5,
                 sample_points: int = 20,
-                free_memory: int = torch.cuda.mem_get_info()[0]):
+                free_memory: int = torch.cuda.mem_get_info()[0],
+                start_factor: int = 4) -> Tuple[np.ndarray, list, list]:
+    """Benchmark the auto checkpoint rotor algorithm for a given graph module and data.
+
+    Args:
+        gm (torch.fx.GraphModule): The graph module to benchmark.
+        criterion (torch.nn.Module): Loss function.
+        data_gen (Callable): Data generator.
+        num_steps (int, optional): Number of test steps. Defaults to 5.
+        sample_points (int, optional): Number of sample points. Defaults to 20.
+        free_memory (int, optional): Max memory budget in bytes. Defaults to torch.cuda.mem_get_info()[0].
+        start_factor (int, optional): Start memory budget factor; the start memory budget
+            will be free_memory / start_factor. Defaults to 4.
+
+    Returns:
+        Tuple[np.ndarray, list, list]: Budgets vector (MB), peak memory vector (MB), step time vector (ms).
+    """
     peak_hist, step_hist = [], []
-    for budget in np.linspace(free_memory // 5, free_memory, sample_points):
+    raw_graph = deepcopy(gm.graph)
+    for budget in np.linspace(free_memory // start_factor, free_memory, sample_points):
         gm = metainfo_trace(gm, *data_gen()[0])
         solver = CheckpointSolverRotor(gm.graph, free_memory=budget)
         try:
-            gm.graph = solver.solve()
-            peak_memory, step_time = bench(gm,
-                                           criterion,
-                                           partial(data_gen, batch_size=2048, shape=(3, 224, 224)),
-                                           num_steps=num_steps)
+            gm.graph = solver.solve(verbose=False)
+            peak_memory, step_time = bench(gm, criterion, data_gen, num_steps=num_steps)
         except:
             peak_memory, step_time = budget / 1024**2, float('inf')
         peak_hist.append(peak_memory)
         step_hist.append(step_time)
-    return peak_hist, step_hist
+        gm.graph = deepcopy(raw_graph)
+    return np.linspace(free_memory // start_factor, free_memory, sample_points) / 1024**2, peak_hist, step_hist
+
+
+class GPTLMModel(nn.Module):
+    """
+    GPT Model
+    """
+
+    def __init__(self,
+                 hidden_size=768,
+                 num_layers=12,
+                 num_attention_heads=12,
+                 max_seq_len=1024,
+                 vocab_size=50257,
+                 checkpoint=False):
+        super().__init__()
+        self.checkpoint = checkpoint
+        self.model = GPT2LMHeadModel(
+            GPT2Config(n_embd=hidden_size,
+                       n_layer=num_layers,
+                       n_head=num_attention_heads,
+                       n_positions=max_seq_len,
+                       n_ctx=max_seq_len,
+                       vocab_size=vocab_size))
+        if checkpoint:
+            self.model.gradient_checkpointing_enable()
+
+    def forward(self, input_ids, attention_mask):
+        # Only return lm_logits
+        return self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=not self.checkpoint)[0]
+
+
+class GPTLMLoss(nn.Module):
+    """
+    GPT Loss
+    """
+
+    def __init__(self):
+        super().__init__()
+        self.loss_fn = nn.CrossEntropyLoss()
+
+    def forward(self, logits, labels):
+        shift_logits = logits[..., :-1, :].contiguous()
+        shift_labels = labels[..., 1:].contiguous()
+        # Flatten the tokens
+        return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
+
+
+def gpt2_medium(checkpoint=False):
+    return GPTLMModel(hidden_size=1024, num_layers=24, num_attention_heads=16, checkpoint=checkpoint)
+
+
+def gpt2_xl(checkpoint=False):
+    return GPTLMModel(hidden_size=1600, num_layers=48, num_attention_heads=32, checkpoint=checkpoint)
+
+
+def gpt2_6b(checkpoint=False):
+    return GPTLMModel(hidden_size=4096, num_layers=30, num_attention_heads=16, checkpoint=checkpoint)
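The `gpt2_medium`/`gpt2_xl`/`gpt2_6b` helpers above encode their scale in the hyperparameters. A rough sanity check, using the standard approximation of 12 · n_layer · hidden² weights for the transformer stack plus embeddings (biases and LayerNorms ignored, so these are estimates only):

```python
# Rough parameter counts for the GPTLMModel configs above (estimates, not exact figures).
def approx_params(hidden, layers, vocab=50257, seq=1024):
    blocks = 12 * layers * hidden**2            # 4h^2 attention + 8h^2 MLP weights per layer
    embeddings = vocab * hidden + seq * hidden  # token + position embeddings
    return blocks + embeddings

print(f"gpt2_medium: ~{approx_params(1024, 24) / 1e6:.0f}M")   # ~355M, close to the ~345M usually quoted
print(f"gpt2_xl:     ~{approx_params(1600, 48) / 1e9:.2f}B")   # ~1.56B, matching GPT-2 XL's ~1.5B
print(f"gpt2_6b:     ~{approx_params(4096, 30) / 1e9:.2f}B")   # ~6.25B, hence the name gpt2_6b
```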
demo_gpt2_medium.py (new file):

@@ -0,0 +1,108 @@
import time
from argparse import ArgumentParser
from functools import partial

import matplotlib.pyplot as plt
import torch
import torch.multiprocessing as mp
import torchvision.models as tm
from bench_utils import GPTLMLoss, bench_rotor, gpt2_medium

import colossalai
from colossalai.auto_parallel.checkpoint import CheckpointSolverRotor
from colossalai.fx import metainfo_trace, symbolic_trace
from colossalai.utils import free_port


def data_gen(batch_size, seq_len, vocab_size, device='cuda:0'):
    """
    Generate random data for benchmarking
    """
    input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)
    attention_mask = torch.ones_like(input_ids, device=device)
    return (input_ids, attention_mask), attention_mask


def _gpt2_benchmark(rank, world_size, port, batch_size, num_steps, sample_points, free_memory, start_factor):
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    model = gpt2_medium()

    # trace and benchmark
    data, mask = data_gen(batch_size, 1024, 50257, device='meta')[0]
    gm = symbolic_trace(model, meta_args={'input_ids': data, 'attention_mask': mask})
    gm = metainfo_trace(gm, data, mask)
    budgets, peak_hist, step_hist = bench_rotor(gm,
                                                GPTLMLoss(),
                                                partial(data_gen, batch_size=batch_size, seq_len=1024,
                                                        vocab_size=50257),
                                                num_steps=num_steps,
                                                sample_points=sample_points,
                                                free_memory=free_memory,
                                                start_factor=start_factor)

    # print summary
    print("==============test summary==============")
    for budget, peak, step in zip(budgets, peak_hist, step_hist):
        print(f'memory budget: {budget:.3f} MB, peak memory: {peak:.3f} MB, step time: {step:.3f} ms')

    # plot valid results
    fig, axs = plt.subplots(1, 2, figsize=(16, 8))
    valid_idx = step_hist.index(next(step for step in step_hist if step != float("inf")))

    # plot peak memory vs. budget memory
    axs[0].plot(budgets[valid_idx:], peak_hist[valid_idx:])
    axs[0].plot([budgets[valid_idx], budgets[-1]], [budgets[valid_idx], budgets[-1]], linestyle='--')
    axs[0].set_xlabel("Budget Memory (MB)")
    axs[0].set_ylabel("Peak Memory (MB)")
    axs[0].set_title("Peak Memory vs. Budget Memory")

    # plot relative step time vs. peak memory
    axs[1].plot(peak_hist[valid_idx:], [step_time / step_hist[-1] for step_time in step_hist[valid_idx:]])
    axs[1].plot([peak_hist[valid_idx], peak_hist[-1]], [1.0, 1.0], linestyle='--')
    axs[1].set_xlabel("Peak Memory (MB)")
    axs[1].set_ylabel("Relative Step Time")
    axs[1].set_title("Step Time vs. Peak Memory")
    axs[1].set_ylim(0.8, 1.5)

    # save plot
    fig.savefig("gpt2_benchmark.png")


def gpt2_benchmark(batch_size, num_steps, sample_points, free_memory, start_factor):
    world_size = 1
    run_func_module = partial(_gpt2_benchmark,
                              world_size=world_size,
                              port=free_port(),
                              batch_size=batch_size,
                              num_steps=num_steps,
                              sample_points=sample_points,
                              free_memory=free_memory,
                              start_factor=start_factor)
    mp.spawn(run_func_module, nprocs=world_size)


if __name__ == "__main__":
    parser = ArgumentParser("GPT2 medium Auto Activation Benchmark")
    parser.add_argument("--batch_size", type=int, default=8, help="batch size for benchmark, default 8")
    parser.add_argument("--num_steps", type=int, default=5, help="number of test steps for benchmark, default 5")
    parser.add_argument(
        "--sample_points",
        type=int,
        default=15,
        help="number of sample points for benchmark from start memory budget to maximum memory budget (free_memory), default 15"
    )
    parser.add_argument("--free_memory",
                        type=int,
                        default=56000,
                        help="maximum memory budget in MB for benchmark, default 56000 MB")
    parser.add_argument(
        "--start_factor",
        type=int,
        default=10,
        help="start memory budget factor for benchmark, the start memory budget will be free_memory / start_factor, default 10"
    )
    args = parser.parse_args()

    gpt2_benchmark(args.batch_size, args.num_steps, args.sample_points, args.free_memory * 1024**2, args.start_factor)
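Note that `_gpt2_benchmark` generates its tracing inputs with `device='meta'`, so the symbolic and meta-info traces never allocate real GPU memory. A small standalone illustration of the idea:

```python
import torch

# A 'meta' tensor carries only shape and dtype; it has no backing storage,
# which is what lets GPT-2-sized inputs be traced without touching the GPU.
x = torch.randint(0, 50257, (8, 1024), device='meta')
print(x.shape, x.dtype, x.device)       # torch.Size([8, 1024]) torch.int64 meta
print(x.nelement() * x.element_size())  # 65536: the logical size in bytes, never allocated
```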
demo_resnet152.py (new file):

@@ -0,0 +1,74 @@
import time
from argparse import ArgumentParser
from copy import deepcopy
from functools import partial

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.multiprocessing as mp
import torchvision.models as tm
from bench_utils import bench

import colossalai
from colossalai.auto_parallel.checkpoint import CheckpointSolverRotor
from colossalai.fx import metainfo_trace, symbolic_trace
from colossalai.utils import free_port


def data_gen(batch_size, shape, device='cuda'):
    """
    Generate random data for benchmarking
    """
    data = torch.empty(batch_size, *shape, device=device)
    label = torch.empty(batch_size, dtype=torch.long, device=device).random_(1000)
    return (data,), label


def _resnet152_benchmark(rank, world_size, port, num_steps):
    """ResNet152 benchmark

    This benchmark tests the throughput of ResNet152 with our activation checkpoint solver,
    given a memory budget of 95% of maximum GPU memory, at batch sizes of [512, 1024, 2048].
    """
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    model = tm.resnet152()
    gm = symbolic_trace(model)
    raw_graph = deepcopy(gm.graph)
    peak_mems, throughputs, batch_sizes = [], [], [512, 1024, 2048]
    for batch_size in batch_sizes:
        batch_size = int(batch_size)
        gm = metainfo_trace(gm, torch.empty(batch_size, 3, 224, 224, device='meta'))
        solver = CheckpointSolverRotor(gm.graph, free_memory=torch.cuda.mem_get_info()[0] * 0.95)
        gm.graph = solver.solve()
        peak_mem, step_time = bench(gm,
                                    torch.nn.CrossEntropyLoss(),
                                    partial(data_gen, batch_size=batch_size, shape=(3, 224, 224)),
                                    num_steps=num_steps)
        peak_mems.append(peak_mem)
        throughputs.append(batch_size / step_time * 1.0e3)
        gm.graph = deepcopy(raw_graph)

    # print results
    print("===============test summary================")
    for batch_size, peak_mem, throughput in zip(batch_sizes, peak_mems, throughputs):
        print(f'batch_size: {int(batch_size)}, peak memory: {peak_mem:.3f} MB, throughput: {throughput:.3f} images/s')

    plt.plot(batch_sizes, throughputs)
    plt.xlabel("batch size")
    plt.ylabel("throughput (images/s)")
    plt.title("ResNet152 benchmark")
    plt.savefig("resnet152_benchmark.png")


def resnet152_benchmark(num_steps):
    world_size = 1
    run_func_module = partial(_resnet152_benchmark, world_size=world_size, port=free_port(), num_steps=num_steps)
    mp.spawn(run_func_module, nprocs=world_size)


if __name__ == "__main__":
    parser = ArgumentParser("ResNet152 Auto Activation Throughput Benchmark")
    parser.add_argument("--num_steps", type=int, default=5, help="number of test steps for benchmark, default 5")
    args = parser.parse_args()

    resnet152_benchmark(args.num_steps)
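The throughput reported above is simply batch size divided by step time (`bench` returns step time in ms, hence the `* 1.0e3`). Working backwards from the sample summary in the README as a unit check:

```python
# throughput (images/s) = batch_size / step_time_ms * 1000
batch_size, throughput = 1024, 397.608      # values from the README sample output
step_time_ms = batch_size / throughput * 1000
print(f"step time: {step_time_ms:.1f} ms")  # ~2575.4 ms per step at batch size 1024
```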
demo_resnet50.py (new file):

@@ -0,0 +1,107 @@
import time
from argparse import ArgumentParser
from functools import partial

import matplotlib.pyplot as plt
import torch
import torch.multiprocessing as mp
import torchvision.models as tm
from bench_utils import bench_rotor

import colossalai
from colossalai.auto_parallel.checkpoint import CheckpointSolverRotor
from colossalai.fx import metainfo_trace, symbolic_trace
from colossalai.utils import free_port


def data_gen(batch_size, shape, device='cuda'):
    """
    Generate random data for benchmarking
    """
    data = torch.empty(batch_size, *shape, device=device)
    label = torch.empty(batch_size, dtype=torch.long, device=device).random_(1000)
    return (data,), label


def _resnet50_benchmark(rank, world_size, port, batch_size, num_steps, sample_points, free_memory, start_factor):
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    model = tm.resnet50()

    # trace and benchmark
    gm = symbolic_trace(model)
    gm = metainfo_trace(gm, torch.empty(batch_size, 3, 224, 224, device='meta'))
    budgets, peak_hist, step_hist = bench_rotor(gm,
                                                torch.nn.CrossEntropyLoss(),
                                                partial(data_gen, batch_size=batch_size, shape=(3, 224, 224)),
                                                num_steps=num_steps,
                                                sample_points=sample_points,
                                                free_memory=free_memory,
                                                start_factor=start_factor)

    # print summary
    print("==============test summary==============")
    for budget, peak, step in zip(budgets, peak_hist, step_hist):
        print(f'memory budget: {budget:.3f} MB, peak memory: {peak:.3f} MB, step time: {step:.3f} ms')

    # plot valid results
    fig, axs = plt.subplots(1, 2, figsize=(16, 8))
    valid_idx = step_hist.index(next(step for step in step_hist if step != float("inf")))

    # plot peak memory vs. budget memory
    axs[0].plot(budgets[valid_idx:], peak_hist[valid_idx:])
    axs[0].plot([budgets[valid_idx], budgets[-1]], [budgets[valid_idx], budgets[-1]], linestyle='--')
    axs[0].set_xlabel("Budget Memory (MB)")
    axs[0].set_ylabel("Peak Memory (MB)")
    axs[0].set_title("Peak Memory vs. Budget Memory")

    # plot relative step time vs. peak memory
    axs[1].plot(peak_hist[valid_idx:], [step_time / step_hist[-1] for step_time in step_hist[valid_idx:]])
    axs[1].plot([peak_hist[valid_idx], peak_hist[-1]], [1.0, 1.0], linestyle='--')
    axs[1].set_xlabel("Peak Memory (MB)")
    axs[1].set_ylabel("Relative Step Time")
    axs[1].set_title("Step Time vs. Peak Memory")
    axs[1].set_ylim(0.8, 1.5)

    # save plot
    fig.savefig("resnet50_benchmark.png")


def resnet50_benchmark(batch_size, num_steps, sample_points, free_memory, start_factor):
    world_size = 1
    run_func_module = partial(_resnet50_benchmark,
                              world_size=world_size,
                              port=free_port(),
                              batch_size=batch_size,
                              num_steps=num_steps,
                              sample_points=sample_points,
                              free_memory=free_memory,
                              start_factor=start_factor)
    mp.spawn(run_func_module, nprocs=world_size)


if __name__ == "__main__":
    parser = ArgumentParser("ResNet50 Auto Activation Benchmark")
    parser.add_argument("--batch_size", type=int, default=128, help="batch size for benchmark, default 128")
    parser.add_argument("--num_steps", type=int, default=5, help="number of test steps for benchmark, default 5")
    parser.add_argument(
        "--sample_points",
        type=int,
        default=15,
        help="number of sample points for benchmark from start memory budget to maximum memory budget (free_memory), default 15"
    )
    parser.add_argument("--free_memory",
                        type=int,
                        default=11000,
                        help="maximum memory budget in MB for benchmark, default 11000 MB")
    parser.add_argument(
        "--start_factor",
        type=int,
        default=4,
        help="start memory budget factor for benchmark, the start memory budget will be free_memory / start_factor, default 4"
    )
    args = parser.parse_args()

    resnet50_benchmark(args.batch_size, args.num_steps, args.sample_points, args.free_memory * 1024**2,
                       args.start_factor)
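One unit detail worth noting in both rotor demos: `--free_memory` is taken in MB on the command line and converted to bytes (`* 1024**2`) before it reaches `bench_rotor`, which, like `torch.cuda.mem_get_info()`, works in bytes. A quick check of the default ResNet50 budget:

```python
# --free_memory is specified in MB; bench_rotor expects bytes.
free_memory_mb = 11000
free_memory_bytes = free_memory_mb * 1024**2
print(free_memory_bytes)  # 11534336000 bytes, about 10.74 GiB
```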
Two binary image files are added but not shown (65 KiB and 71 KiB): the benchmark plots imgs/gpt2_benchmark.png and imgs/resnet50_benchmark.png referenced in the README above.