mirror of https://github.com/hpcaitech/ColossalAI
Topics: ai, big-model, data-parallelism, deep-learning, distributed-computing, foundation-models, heterogeneous-training, hpc, inference, large-scale, model-parallelism, pipeline-parallelism
You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
107 lines
4.5 KiB
107 lines
4.5 KiB
import time |
|
from argparse import ArgumentParser |
|
from functools import partial |
|
|
|
import matplotlib.pyplot as plt |
|
import torch |
|
import torch.multiprocessing as mp |
|
import torchvision.models as tm |
|
from bench_utils import bench_rotor |
|
|
|
import colossalai |
|
from colossalai.auto_parallel.checkpoint import CheckpointSolverRotor |
|
from colossalai.fx import metainfo_trace, symbolic_trace |
|
from colossalai.utils import free_port |
|
|
|
|
|
def data_gen(batch_size, shape, device='cuda', num_classes=1000):
    """
    Generate a synthetic (inputs, labels) batch for benchmarking.

    The input tensor is allocated with ``torch.empty`` and is therefore
    uninitialized: its contents are arbitrary (not random). That is enough for
    measuring memory and step time, but not for checking numerics.

    Args:
        batch_size (int): number of samples in the batch.
        shape (tuple): per-sample input shape, e.g. ``(3, 224, 224)``.
        device (str | torch.device): device to allocate on, default ``'cuda'``.
        num_classes (int): labels are drawn uniformly from ``[0, num_classes)``,
            default 1000.

    Returns:
        tuple: ``((data,), label)``. ``data`` has shape ``(batch_size, *shape)``;
        ``label`` is a ``torch.long`` tensor of shape ``(batch_size,)``.
        The inputs are wrapped in a 1-tuple.
    """
    data = torch.empty(batch_size, *shape, device=device)
    # Tensor.random_(to) fills with integers sampled uniformly from [0, to - 1]
    label = torch.empty(batch_size, dtype=torch.long, device=device).random_(num_classes)
    return (data,), label
|
|
|
|
|
def _print_summary(budgets, peak_hist, step_hist):
    """Print one line per sampled memory budget with its peak memory and step time."""
    print("==============test summary==============")
    for budget, peak, step in zip(budgets, peak_hist, step_hist):
        print(f'memory budget: {budget:.3f} MB, peak memory: {peak:.3f} MB, step time: {step:.3f} MS')


def _plot_results(budgets, peak_hist, step_hist, filename):
    """Plot results from the first budget with a finite step time onward and save them to ``filename``.

    If no budget produced a finite step time, print a notice and write nothing.
    """
    # Entries before the first finite step time are excluded from the plots
    # (presumably budgets bench_rotor could not run under -- TODO confirm).
    valid_idx = next((idx for idx, step in enumerate(step_hist) if step != float("inf")), None)
    if valid_idx is None:
        print("no memory budget produced a finite step time, skipping plot")
        return

    fig, axs = plt.subplots(1, 2, figsize=(16, 8))

    # plot peak memory vs. budget memory, with the y = x line for reference
    axs[0].plot(budgets[valid_idx:], peak_hist[valid_idx:])
    axs[0].plot([budgets[valid_idx], budgets[-1]], [budgets[valid_idx], budgets[-1]], linestyle='--')
    axs[0].set_xlabel("Budget Memory (MB)")
    axs[0].set_ylabel("Peak Memory (MB)")
    axs[0].set_title("Peak Memory vs. Budget Memory")

    # plot step time (relative to the last, largest-budget run) vs. peak memory
    axs[1].plot(peak_hist[valid_idx:], [step_time / step_hist[-1] for step_time in step_hist[valid_idx:]])
    axs[1].plot([peak_hist[valid_idx], peak_hist[-1]], [1.0, 1.0], linestyle='--')
    axs[1].set_xlabel("Peak Memory (MB)")
    axs[1].set_ylabel("Relative Step Time")
    axs[1].set_title("Step Time vs. Peak Memory")
    axs[1].set_ylim(0.8, 1.5)

    fig.savefig(filename)
    # release the figure so repeated runs in one process do not accumulate figures
    plt.close(fig)


def _resnet50_benchmark(rank, world_size, port, batch_size, num_steps, sample_points, free_memory, start_factor):
    """Worker: benchmark rotor activation checkpointing on ResNet-50 and save a plot.

    Intended to be launched via ``mp.spawn``, which supplies ``rank`` positionally.

    Args:
        rank: process rank passed to ``colossalai.launch``.
        world_size: total number of processes passed to ``colossalai.launch``.
        port: localhost port used by ``colossalai.launch`` (NCCL backend).
        batch_size: batch size used for the meta trace and for generated inputs.
        num_steps: forwarded to ``bench_rotor``.
        sample_points: forwarded to ``bench_rotor``.
        free_memory: maximum memory budget, forwarded to ``bench_rotor``; the CLI
            entry point passes it in bytes (MB * 1024**2).
        start_factor: forwarded to ``bench_rotor``.
    """
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    model = tm.resnet50()

    # trace the model, then propagate meta information using a meta-device input
    # of the benchmark's input shape (no real memory is allocated for it)
    gm = symbolic_trace(model)
    gm = metainfo_trace(gm, torch.empty(batch_size, 3, 224, 224, device='meta'))

    budgets, peak_hist, step_hist = bench_rotor(gm,
                                                torch.nn.CrossEntropyLoss(),
                                                partial(data_gen, batch_size=batch_size, shape=(3, 224, 224)),
                                                num_steps=num_steps,
                                                sample_points=sample_points,
                                                free_memory=free_memory,
                                                start_factor=start_factor)

    _print_summary(budgets, peak_hist, step_hist)
    _plot_results(budgets, peak_hist, step_hist, "resnet50_benchmark.png")
|
|
|
|
|
def resnet50_benchmark(batch_size, num_steps, sample_points, free_memory, start_factor):
    """Run the ResNet-50 rotor benchmark in a single spawned worker process.

    Every argument is forwarded unchanged to ``_resnet50_benchmark``, together
    with a world size of 1 and a freshly chosen free port. ``free_memory`` is
    whatever the caller passes; the ``__main__`` block passes MB * 1024**2.
    """
    nprocs = 1
    worker_kwargs = dict(
        world_size=nprocs,
        port=free_port(),
        batch_size=batch_size,
        num_steps=num_steps,
        sample_points=sample_points,
        free_memory=free_memory,
        start_factor=start_factor,
    )
    worker = partial(_resnet50_benchmark, **worker_kwargs)
    mp.spawn(worker, nprocs=nprocs)
|
|
|
|
|
if __name__ == "__main__":
    # The first positional argument of ArgumentParser is ``prog``; pass the title
    # as ``description`` so the usage line still shows the real script name.
    parser = ArgumentParser(description="ResNet50 Auto Activation Benchmark")
    parser.add_argument("--batch_size", type=int, default=128, help="batch size for benchmark, default 128")
    parser.add_argument("--num_steps", type=int, default=5, help="number of test steps for benchmark, default 5")
    parser.add_argument(
        "--sample_points",
        type=int,
        default=15,
        help=
        "number of sample points for benchmark from start memory budget to maximum memory budget (free_memory), default 15"
    )
    parser.add_argument("--free_memory",
                        type=int,
                        default=11000,
                        help="maximum memory budget in MB for benchmark, default 11000 MB")
    parser.add_argument(
        "--start_factor",
        type=int,
        default=4,
        help=
        "start memory budget factor for benchmark, the start memory budget will be free_memory / start_factor, default 4"
    )
    args = parser.parse_args()

    # Reject non-positive values up front. For example, start_factor == 0 would make
    # the documented start budget (free_memory / start_factor) undefined.
    for arg_name in ("batch_size", "num_steps", "sample_points", "free_memory", "start_factor"):
        value = getattr(args, arg_name)
        if value <= 0:
            parser.error(f"--{arg_name} must be a positive integer, got {value}")

    # the CLI takes free_memory in MB; convert to bytes before passing it on
    resnet50_benchmark(args.batch_size, args.num_steps, args.sample_points, args.free_memory * 1024**2,
                       args.start_factor)
|
|
|