ColossalAI/examples/tutorial/auto_parallel/auto_ckpt_solver_test.py

90 lines
3.8 KiB
Python
Raw Normal View History

import time
from argparse import ArgumentParser
from functools import partial
import matplotlib.pyplot as plt
import torch
import torch.multiprocessing as mp
import torchvision.models as tm
from bench_utils import GPTLMLoss, bench_rotor, data_gen_gpt2, data_gen_resnet, gpt2_medium
import colossalai
from colossalai.auto_parallel.checkpoint import CheckpointSolverRotor
from colossalai.fx import metainfo_trace, symbolic_trace
from colossalai.utils import free_port
def _benchmark(rank, world_size, port, args):
    """
    Auto activation checkpoint solver benchmark, we provide benchmark on two models: gpt2_medium and resnet50.
    The benchmark will sample in a range of memory budget for each model and output the benchmark summary and
    data visualization of peak memory vs. budget memory and relative step time vs. peak memory.

    Args:
        rank: rank of this process, forwarded to ``colossalai.launch``.
        world_size: number of processes, forwarded to ``colossalai.launch``.
        port: port on ``localhost`` handed to ``colossalai.launch``.
        args: parsed command-line arguments; only ``args.model`` is read. ``'resnet50'`` selects the
            ResNet-50 setup, every other value selects the GPT-2 medium setup.

    Side effects:
        Prints one summary line per sampled budget and, when at least one sample has a finite step
        time, writes ``{args.model}_benchmark.png`` to the current working directory.
    """
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

    # build the model, its data generator and a traced graph module annotated with meta information
    if args.model == 'resnet50':
        model = tm.resnet50()
        data_gen = partial(data_gen_resnet, batch_size=128, shape=(3, 224, 224))
        gm = symbolic_trace(model)
        # the meta tensor has the same shape as the batches produced by data_gen
        gm = metainfo_trace(gm, torch.empty(128, 3, 224, 224, device='meta'))
        loss = torch.nn.CrossEntropyLoss()
    else:
        model = gpt2_medium()
        data_gen = partial(data_gen_gpt2, batch_size=8, seq_len=1024, vocab_size=50257)
        # tracing only needs meta tensors, so the sample inputs are generated on the meta device
        data, mask = data_gen(device='meta')[0]
        gm = symbolic_trace(model, meta_args={'input_ids': data, 'attention_mask': mask})
        gm = metainfo_trace(gm, data, mask)
        loss = GPTLMLoss()

    # per-model sampling range handed to bench_rotor
    # NOTE(review): free_memory looks like a byte count (N * 1024**2) -- confirm against bench_rotor
    free_memory = 11000 * 1024**2 if args.model == 'resnet50' else 56000 * 1024**2
    start_factor = 4 if args.model == 'resnet50' else 10

    # trace and benchmark
    budgets, peak_hist, step_hist = bench_rotor(gm,
                                                loss,
                                                data_gen,
                                                num_steps=5,
                                                sample_points=15,
                                                free_memory=free_memory,
                                                start_factor=start_factor)

    # print summary
    print("==============benchmark summary==============")
    for budget, peak, step in zip(budgets, peak_hist, step_hist):
        print(f'memory budget: {budget:.3f} MB, peak memory: {peak:.3f} MB, step time: {step:.3f} MS')

    # locate the first valid sample: a step time of inf is treated as an invalid result
    # (presumably bench_rotor reports inf for budgets it could not run -- confirm against bench_rotor)
    valid_idx = next((idx for idx, step in enumerate(step_hist) if step != float("inf")), None)
    if valid_idx is None:
        # nothing to plot; leave before a figure is allocated instead of dying with StopIteration
        print("no memory budget produced a finite step time, skip plotting")
        return

    # plot valid results
    fig, axs = plt.subplots(1, 2, figsize=(16, 8))

    # plot peak memory vs. budget memory
    axs[0].plot(budgets[valid_idx:], peak_hist[valid_idx:])
    # dashed diagonal marks peak memory == budget memory
    axs[0].plot([budgets[valid_idx], budgets[-1]], [budgets[valid_idx], budgets[-1]], linestyle='--')
    axs[0].set_xlabel("Budget Memory (MB)")
    axs[0].set_ylabel("Peak Memory (MB)")
    axs[0].set_title("Peak Memory vs. Budget Memory")

    # plot relative step time vs. peak memory, normalised by the step time of the last sample
    axs[1].plot(peak_hist[valid_idx:], [step_time / step_hist[-1] for step_time in step_hist[valid_idx:]])
    # dashed horizontal line marks a relative step time of 1.0
    axs[1].plot([peak_hist[valid_idx], peak_hist[-1]], [1.0, 1.0], linestyle='--')
    axs[1].set_xlabel("Peak Memory (MB)")
    axs[1].set_ylabel("Relative Step Time")
    axs[1].set_title("Step Time vs. Peak Memory")
    axs[1].set_ylim(0.8, 1.5)

    # save plot
    fig.savefig(f"{args.model}_benchmark.png")
def auto_activation_checkpoint_benchmark(args):
    """Run the checkpoint solver benchmark in a single spawned worker process.

    Args:
        args: parsed command-line arguments, passed through unchanged to ``_benchmark``.
    """
    num_procs = 1
    # pick the port once in the parent so the worker receives a fixed value
    port = free_port()
    worker = partial(_benchmark, world_size=num_procs, port=port, args=args)
    # mp.spawn supplies the process index as the first positional argument (rank)
    mp.spawn(worker, nprocs=num_procs)
if __name__ == "__main__":
    # command-line entry point: choose the model to benchmark (gpt2 when omitted) and run it
    parser = ArgumentParser(prog="Auto Activation Checkpoint Solver Benchmark")
    parser.add_argument(
        "--model",
        default="gpt2",
        choices=["gpt2", "resnet50"],
        type=str,
    )
    cli_args = parser.parse_args()
    auto_activation_checkpoint_benchmark(cli_args)