mirror of https://github.com/hpcaitech/ColossalAI
[sc] add examples for auto checkpoint. (#1880)
parent 51597f6a28
commit 6d559ea614
File diff suppressed because one or more lines are too long
@@ -0,0 +1,65 @@
import time
from functools import partial
from typing import Callable, Tuple

import numpy as np
import torch
import torchvision.models as tm

from colossalai.auto_parallel.checkpoint import CheckpointSolverRotor
from colossalai.fx import metainfo_trace


def bench(gm: torch.fx.GraphModule, criterion: torch.nn.Module, data_gen: Callable, num_steps: int = 5):
    """Run `num_steps` training steps on a traced model and report the peak GPU memory
    used beyond the parameter baseline (in MB) and the fastest observed step time (in ms)."""
    gm.train()
    gm.cuda()
    step_time = float('inf')
    torch.cuda.synchronize()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    # Baseline allocation (parameters and buffers) before any forward/backward pass.
    cached = torch.cuda.max_memory_allocated(device="cuda")
    try:
        for _ in range(num_steps):
            args, label = data_gen()
            output, loss = None, None

            # Time one full forward + backward step; keep the fastest one observed.
            torch.cuda.synchronize(device="cuda")
            start = time.time()
            output = gm(*args)
            loss = criterion(output, label)
            loss.backward()
            torch.cuda.synchronize(device="cuda")
            step_time = min(step_time, time.time() - start)

            # Release gradients and per-step tensors so memory stats stay comparable.
            for child in gm.children():
                for param in child.parameters():
                    param.grad = None
            del args, label, output, loss
    except Exception:
        # Most likely a CUDA OOM; drop references so the allocator can free the tensors.
        del args, label, output, loss
    gm.to("cpu")
    torch.cuda.empty_cache()
    return (torch.cuda.max_memory_allocated(device="cuda") - cached) / 1024**2, step_time * 1.0e3


def bench_rotor(gm: torch.fx.GraphModule,
                criterion: torch.nn.Module,
                data_gen: Callable,
                num_steps: int = 5,
                sample_points: int = 20,
                free_memory: int = torch.cuda.mem_get_info()[0]):
    """Sweep memory budgets from free_memory / 5 up to free_memory over `sample_points` points;
    for each budget, re-solve the graph with CheckpointSolverRotor and benchmark the result.
    Note that the default `free_memory` is captured when the module is imported."""
    peak_hist, step_hist = [], []
    for budget in np.linspace(free_memory // 5, free_memory, sample_points):
        # Re-annotate the graph with meta information, then solve under the current budget.
        gm = metainfo_trace(gm, *data_gen()[0])
        solver = CheckpointSolverRotor(gm.graph, free_memory=budget)
        try:
            gm.graph = solver.solve()
            peak_memory, step_time = bench(gm,
                                           criterion,
                                           partial(data_gen, batch_size=2048, shape=(3, 224, 224)),
                                           num_steps=num_steps)
        except Exception:
            # Solver failure or OOM: record the budget itself (in MB) and an infinite step time.
            peak_memory, step_time = budget / 1024**2, float('inf')
        peak_hist.append(peak_memory)
        step_hist.append(step_time)
    return peak_hist, step_hist
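A minimal usage sketch (not part of the commit) showing how these helpers might be driven. The data_gen helper, the resnet50 model choice, and the torch.fx.symbolic_trace call are illustrative assumptions; the benchmark only requires that data_gen return an (args, label) pair and accept batch_size/shape keyword arguments, which is the calling convention bench_rotor's partial(...) implies.

# Usage sketch (illustrative, not part of the commit). Assumes a CUDA device and a
# hypothetical data_gen matching the calling convention used above.
def data_gen(batch_size: int = 128, shape: Tuple[int, int, int] = (3, 224, 224)):
    # Synthetic ImageNet-style batch: returns (positional-args tuple, label tensor).
    data = torch.rand(batch_size, *shape, device="cuda")
    label = torch.randint(low=0, high=1000, size=(batch_size,), device="cuda")
    return (data,), label


if __name__ == "__main__":
    model = tm.resnet50()
    criterion = torch.nn.CrossEntropyLoss()
    # Tracing with torch.fx here is an assumption; ColossalAI also ships its own tracer.
    gm = torch.fx.symbolic_trace(model)
    peak_hist, step_hist = bench_rotor(gm, criterion, data_gen, num_steps=5, sample_points=10)
    for peak_memory, step_time in zip(peak_hist, step_hist):
        print(f"peak memory: {peak_memory:.1f} MB, step time: {step_time:.1f} ms")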