mirror of https://github.com/hpcaitech/ColossalAI
[fx] add vanilla activation checkpoint search with test on resnet and densenet (#1433)
* [fx] activation checkpointing using Chen strategies.
* [fx] add test for ckpt_solver_chen
* [fx] add vanilla activation checkpoint search with test on resnet and densenet
* [fx] add vanilla activation checkpoint search with test on resnet and densenet
* [fx] add a namespace code for solver_chen.
pull/1441/head
parent
30b4dd17c0
commit
3b26516c69
|
@ -0,0 +1 @@
|
||||||
|
from .ckpt_solver_chen import chen_greedy, chen_sqrtn
|
|
@ -0,0 +1,62 @@
|
||||||
|
import torch
|
||||||
|
from torch.fx import GraphModule
|
||||||
|
|
||||||
|
# Public API of this module: the two checkpoint solvers named here.
__all__ = ['chen_greedy', 'chen_sqrtn']
|
||||||
|
|
||||||
|
|
||||||
|
def chen_greedy(gm: GraphModule, B: int):
    """
    Greedily annotate the nodes of ``gm`` with activation checkpoints under a memory budget.

    This is the simple implementation of Algorithm 3 in https://arxiv.org/abs/1604.06174.

    Every node of ``gm.graph`` must already carry ``param_size`` and
    ``activation_size`` attributes (the usage below obtains them by running
    ``MetaInfoProp`` first). The summed ``param_size`` is subtracted from ``B``;
    the nodes are then swept in graph order while a running activation total is
    kept. Whenever that total exceeds the remaining budget, the current node is
    given an ``activation_checkpoint`` attribute ('0', '1', ... in order) and
    the running total is reset to the summed size of all nodes marked so far.

    Usage:
        B = 5 * 1024 * 1024 * 1024    # An approximate memory budget of 5GB
        model = resnet18()
        input_sample = torch.rand(4, 3, 224, 224)
        gm = symbolic_trace(model)
        MetaInfoProp(gm).run(input_sample)
        gm = chen_greedy(gm, B)

    Args:
        gm (GraphModule): The module to add checkpoints
        B (int): The approximate memory budget for this module.

    Returns:
        GraphModule: the same ``gm`` object, annotated in place and recompiled.

    Raises:
        AssertionError: if ``B`` does not exceed the summed ``param_size`` of all nodes.
    """
    gm.graph.lint()    # make sure nodes are in topological order

    # Budget left for activations once the parameters are accounted for.
    available = B - sum(node.param_size for node in gm.graph.nodes)
    if available <= 0:
        # Raised explicitly (rather than via ``assert``) so the check is not
        # stripped when Python runs with -O; the exception type is unchanged.
        raise AssertionError(
            f'The memory budget {B / 1024 ** 3:.2f} GB is not enough for model parameters of {gm}')

    running_total = 0    # activation size accumulated since the last reset
    checkpointed = 0    # summed activation size of the nodes marked so far
    next_label = 0
    for node in gm.graph.nodes:
        running_total += node.activation_size
        if running_total > available:
            checkpointed += node.activation_size
            # Restart from the marked nodes' total, not from zero.
            running_total = checkpointed
            node.activation_checkpoint = str(next_label)
            next_label += 1

    gm.recompile()
    return gm
|
||||||
|
|
||||||
|
|
||||||
|
def chen_sqrtn(gm: GraphModule):
    """
    Mark roughly every sqrt(n)-th node of ``gm`` as an activation checkpoint.

    This is the theoretical optimal strategy in https://arxiv.org/abs/1604.06174.

    With ``k = int(sqrt(number of nodes))``, the node at 1-based position ``p``
    is given ``activation_checkpoint = str(p // k)`` whenever ``p`` is a
    multiple of ``k``. All other nodes are left untouched.

    Usage:
        model = resnet18()
        input_sample = torch.rand(4, 3, 224, 224)
        gm = symbolic_trace(model)
        MetaInfoProp(gm).run(input_sample)
        gm = chen_sqrtn(gm)

    Args:
        gm (GraphModule): The module to add checkpoints

    Returns:
        GraphModule: the same ``gm`` object, annotated in place and recompiled.
    """
    gm.graph.lint()    # make sure nodes are in topological order

    graph_nodes = gm.graph.nodes
    stride = int(len(graph_nodes)**0.5)    # take approximately sqrt(n) checkpoints
    for position, node in enumerate(graph_nodes, start=1):
        # stride is only 0 for an empty graph, in which case this body never runs.
        segment, remainder = divmod(position, stride)
        if remainder == 0:
            node.activation_checkpoint = str(segment)

    gm.recompile()
    return gm
|
|
@ -0,0 +1,40 @@
|
||||||
|
from colossalai.fx.passes.algorithms import chen_greedy, chen_sqrtn
|
||||||
|
import torch
|
||||||
|
import torchvision.models as tm
|
||||||
|
from colossalai.fx import ColoTracer
|
||||||
|
from torch.fx import GraphModule
|
||||||
|
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
|
||||||
|
from functools import partial
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
# Solvers exercised by the test below; chen_greedy is pre-bound to B = 64 * 1024 * 1024.
SOLVERS = [partial(chen_greedy, B=1024 * 1024 * 64), chen_sqrtn]
|
||||||
|
|
||||||
|
|
||||||
|
def _is_activation_checkpoint_available(gm: GraphModule):
|
||||||
|
for n in gm.graph.nodes:
|
||||||
|
if hasattr(n, 'activation_checkpoint') and getattr(n, 'activation_checkpoint') is not None:
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def test_ckpt_solver():
    """
    Run every solver in ``SOLVERS`` against resnet18 and densenet121.

    For each (solver, model) pair the model is traced with ``ColoTracer``,
    wrapped in a ``GraphModule``, passed through ``MetaInfoProp`` on a random
    1x3x224x224 input and handed to the solver. The test then checks that at
    least one node was annotated with an activation checkpoint, and that the
    annotated module still produces the same output as the original model.
    """
    architectures = (tm.resnet18, tm.densenet121)

    torch.backends.cudnn.deterministic = True

    colo_tracer = ColoTracer()
    sample = torch.rand(1, 3, 224, 224)

    for solver in SOLVERS:
        for model_cls in architectures:
            net = model_cls()
            traced_graph = colo_tracer.trace(root=net)
            gm = GraphModule(net, traced_graph, type(net).__name__)
            MetaInfoProp(gm).run(sample)    # the usage examples run this before the solver
            gm = solver(gm)

            annotated = _is_activation_checkpoint_available(gm)
            assert annotated, f"Solver {solver} did not annotate {model_cls} with any activation checkpoints"
            # The annotation must not change what the module computes.
            assert torch.allclose(gm(sample), net(sample))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
test_ckpt_solver()
|
Loading…
Reference in New Issue