mirror of https://github.com/hpcaitech/ColossalAI
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
66 lines
2.3 KiB
66 lines
2.3 KiB
2 years ago
|
import time
|
||
|
from functools import partial
|
||
|
from typing import Callable, Tuple
|
||
|
|
||
|
import numpy as np
|
||
|
import torch
|
||
|
import torchvision.models as tm
|
||
|
|
||
|
from colossalai.auto_parallel.checkpoint import CheckpointSolverRotor
|
||
|
from colossalai.fx import metainfo_trace
|
||
|
|
||
|
|
||
|
def bench(gm: torch.fx.GraphModule, criterion: torch.nn.Module, data_gen: Callable, num_steps: int = 5):
|
||
|
gm.train()
|
||
|
gm.cuda()
|
||
|
step_time = float('inf')
|
||
|
torch.cuda.synchronize()
|
||
|
torch.cuda.empty_cache()
|
||
|
torch.cuda.reset_peak_memory_stats()
|
||
|
cached = torch.cuda.max_memory_allocated(device="cuda")
|
||
|
try:
|
||
|
for _ in range(num_steps):
|
||
|
args, label = data_gen()
|
||
|
output, loss = None, None
|
||
|
|
||
|
torch.cuda.synchronize(device="cuda")
|
||
|
start = time.time()
|
||
|
output = gm(*args)
|
||
|
loss = criterion(output, label)
|
||
|
loss.backward()
|
||
|
torch.cuda.synchronize(device="cuda")
|
||
|
step_time = min(step_time, time.time() - start)
|
||
|
|
||
|
for child in gm.children():
|
||
|
for param in child.parameters():
|
||
|
param.grad = None
|
||
|
del args, label, output, loss
|
||
|
except:
|
||
|
del args, label, output, loss
|
||
|
gm.to("cpu")
|
||
|
torch.cuda.empty_cache()
|
||
|
return (torch.cuda.max_memory_allocated(device="cuda") - cached) / 1024**2, step_time * 1.0e3
|
||
|
|
||
|
|
||
|
def bench_rotor(gm: torch.fx.GraphModule,
|
||
|
criterion: torch.nn.Module,
|
||
|
data_gen: Callable,
|
||
|
num_steps: int = 5,
|
||
|
sample_points: int = 20,
|
||
|
free_memory: int = torch.cuda.mem_get_info()[0]):
|
||
|
peak_hist, step_hist = [], []
|
||
|
for budget in np.linspace(free_memory // 5, free_memory, sample_points):
|
||
|
gm = metainfo_trace(gm, *data_gen()[0])
|
||
|
solver = CheckpointSolverRotor(gm.graph, free_memory=budget)
|
||
|
try:
|
||
|
gm.graph = solver.solve()
|
||
|
peak_memory, step_time = bench(gm,
|
||
|
criterion,
|
||
|
partial(data_gen, batch_size=2048, shape=(3, 224, 224)),
|
||
|
num_steps=num_steps)
|
||
|
except:
|
||
|
peak_memory, step_time = budget / 1024**2, float('inf')
|
||
|
peak_hist.append(peak_memory)
|
||
|
step_hist.append(step_time)
|
||
|
return peak_hist, step_hist
|