mirror of https://github.com/hpcaitech/ColossalAI
Topics: ai, big-model, data-parallelism, deep-learning, distributed-computing, foundation-models, heterogeneous-training, hpc, inference, large-scale, model-parallelism, pipeline-parallelism
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
81 lines
3.3 KiB
81 lines
3.3 KiB
from argparse import ArgumentParser |
|
from functools import partial |
|
|
|
import matplotlib.pyplot as plt |
|
import torch |
|
import torchvision.models as tm |
|
from bench_utils import GPTLMLoss, bench_rotor, data_gen_gpt2, data_gen_resnet, gpt2_medium |
|
|
|
import colossalai |
|
from colossalai.fx import metainfo_trace, symbolic_trace |
|
from colossalai.testing import spawn |
|
|
|
|
|
def _benchmark(rank, world_size, port, args):
    """Benchmark the auto activation checkpoint solver on one model.

    Two models are supported: ``resnet50`` (chosen when ``args.model`` equals
    ``"resnet50"``) and ``gpt2_medium`` (used for any other value). The model is
    traced, ``bench_rotor`` samples a range of memory budgets, a summary is
    printed, and two plots (peak memory vs. budget memory, relative step time
    vs. peak memory) are saved to ``{args.model}_benchmark.png``.

    If every sampled budget reports an infinite step time there is nothing
    valid to plot; a message is printed and no image is written.

    Args:
        rank: Process rank, forwarded to ``colossalai.launch``.
        world_size: Number of processes, forwarded to ``colossalai.launch``.
        port: Port, forwarded to ``colossalai.launch``.
        args: Parsed command-line namespace; only ``args.model`` is read.
    """
    colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
    if args.model == "resnet50":
        model = tm.resnet50()
        data_gen = partial(data_gen_resnet, batch_size=128, shape=(3, 224, 224))
        gm = symbolic_trace(model)
        # Meta tensors are used so tracing the metadata does not allocate real memory.
        gm = metainfo_trace(gm, torch.empty(128, 3, 224, 224, device="meta"))
        loss = torch.nn.CrossEntropyLoss()
    else:
        model = gpt2_medium()
        data_gen = partial(data_gen_gpt2, batch_size=8, seq_len=1024, vocab_size=50257)
        data, mask = data_gen(device="meta")[0]
        gm = symbolic_trace(model, meta_args={"input_ids": data, "attention_mask": mask})
        gm = metainfo_trace(gm, data, mask)
        loss = GPTLMLoss()

    # Per-model settings handed to bench_rotor (free_memory is expressed in bytes).
    free_memory = 11000 * 1024**2 if args.model == "resnet50" else 56000 * 1024**2
    start_factor = 4 if args.model == "resnet50" else 10

    # trace and benchmark
    budgets, peak_hist, step_hist = bench_rotor(
        gm, loss, data_gen, num_steps=5, sample_points=15, free_memory=free_memory, start_factor=start_factor
    )

    # print summary
    print("==============benchmark summary==============")
    for budget, peak, step in zip(budgets, peak_hist, step_hist):
        print(f"memory budget: {budget:.3f} MB, peak memory: {peak:.3f} MB, step time: {step:.3f} MS")

    # Samples with an infinite step time are treated as invalid; plotting starts
    # at the first finite one. Without any valid sample there is nothing to draw.
    valid_idx = next((idx for idx, step in enumerate(step_hist) if step != float("inf")), None)
    if valid_idx is None:
        print("no valid benchmark result found, skip plotting")
        return

    # plot valid results
    fig, axs = plt.subplots(1, 2, figsize=(16, 8))

    # plot peak memory vs. budget memory
    axs[0].plot(budgets[valid_idx:], peak_hist[valid_idx:])
    # dashed reference line: peak memory == budget memory
    axs[0].plot([budgets[valid_idx], budgets[-1]], [budgets[valid_idx], budgets[-1]], linestyle="--")
    axs[0].set_xlabel("Budget Memory (MB)")
    axs[0].set_ylabel("Peak Memory (MB)")
    axs[0].set_title("Peak Memory vs. Budget Memory")

    # plot relative step time vs. peak memory (normalized by the last sample's step time)
    axs[1].plot(peak_hist[valid_idx:], [step_time / step_hist[-1] for step_time in step_hist[valid_idx:]])
    # dashed reference line: relative step time == 1.0
    axs[1].plot([peak_hist[valid_idx], peak_hist[-1]], [1.0, 1.0], linestyle="--")
    axs[1].set_xlabel("Peak Memory (MB)")
    axs[1].set_ylabel("Relative Step Time")
    axs[1].set_title("Step Time vs. Peak Memory")
    axs[1].set_ylim(0.8, 1.5)

    # save plot
    fig.savefig(f"{args.model}_benchmark.png")
|
|
|
|
|
def auto_activation_checkpoint_benchmark(args):
    """Run ``_benchmark`` through ``spawn`` with a world size of 1.

    Args:
        args: Parsed command-line namespace, passed to ``spawn`` as the
            ``args`` keyword argument.
    """
    num_procs = 1  # the benchmark is launched with a single process
    spawn(_benchmark, num_procs, args=args)
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: choose the model to benchmark, then run it.
    cli = ArgumentParser("Auto Activation Checkpoint Solver Benchmark")
    cli.add_argument("--model", type=str, default="gpt2", choices=["gpt2", "resnet50"])
    auto_activation_checkpoint_benchmark(cli.parse_args())
|
|
|