mirror of https://github.com/hpcaitech/ColossalAI
* [fx] activation checkpointing using Chen strategies.
* [fx] add test for ckpt_solver_chen
* [fx] add vanilla activation checkpoint search with test on resnet and densenet
* [fx] add a namespace code for solver_chen.
(pull/1441/head)
Super Daniel authored 2 years ago, committed by GitHub
3 changed files with 103 additions and 0 deletions
colossalai/fx/passes/algorithms/__init__.py
@@ -0,0 +1 @@
from .ckpt_solver_chen import chen_greedy, chen_sqrtn
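For orientation, here is a minimal end-to-end sketch of how these two exported solvers are meant to be applied. It mirrors the test file at the bottom of this diff; `ColoTracer` and `MetaInfoProp` are taken from `colossalai.fx` exactly as the test uses them:

    import torch
    import torchvision.models as tm
    from torch.fx import GraphModule

    from colossalai.fx import ColoTracer
    from colossalai.fx.passes.algorithms import chen_greedy, chen_sqrtn
    from colossalai.fx.passes.meta_info_prop import MetaInfoProp

    model = tm.resnet18()
    data = torch.rand(1, 3, 224, 224)

    # Trace the model into an FX graph, then annotate each node with its
    # parameter/activation sizes, which the solvers read.
    graph = ColoTracer().trace(root=model)
    gm = GraphModule(model, graph, model.__class__.__name__)
    MetaInfoProp(gm).run(data)

    # Apply either solver; both return the same GraphModule with
    # 'activation_checkpoint' annotations attached to its nodes.
    gm = chen_sqrtn(gm)                      # sqrt(n) partitioning
    # gm = chen_greedy(gm, B=64 * 1024**2)   # or: greedy under a 64 MB budget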
colossalai/fx/passes/algorithms/ckpt_solver_chen.py
@@ -0,0 +1,62 @@
import torch
from torch.fx import GraphModule

__all__ = ['chen_greedy', 'chen_sqrtn']


def chen_greedy(gm: GraphModule, B: int):
    """
    This is a simple implementation of Algorithm 3 (the greedy strategy) in
    https://arxiv.org/abs/1604.06174.

    Usage:
        B = 5 * 1024 * 1024 * 1024    # An approximate memory budget of 5 GB
        model = resnet18()
        input_sample = torch.rand(4, 3, 224, 224)
        gm = symbolic_trace(model)
        MetaInfoProp(gm).run(input_sample)
        gm = chen_greedy(gm, B)

    Args:
        gm (GraphModule): The module to which checkpoints are added.
        B (int): The approximate memory budget for this module.
    """
    gm.graph.lint()    # make sure nodes are in topological order
    temp = 0
    x = 0
    idx = 0
    budget = B
    # Reserve the memory occupied by parameters; the remainder is the
    # budget available for activations.
    for n in gm.graph.nodes:
        B -= getattr(n, 'param_size')
        assert B > 0, f'The memory budget {budget / 1024 ** 3:.2f} GB is not enough for the parameters of {gm}'
    # Greedily accumulate activation sizes; whenever the running total
    # exceeds the remaining budget, mark the current node as a checkpoint
    # and restart the segment total from the kept checkpoints' size.
    for n in gm.graph.nodes:
        temp += getattr(n, 'activation_size')
        if temp > B:
            x += getattr(n, 'activation_size')
            temp = x
            setattr(n, 'activation_checkpoint', str(idx))
            idx += 1
    gm.recompile()
    return gm


def chen_sqrtn(gm: GraphModule):
    """
    This is the theoretically optimal sqrt(n) strategy in
    https://arxiv.org/abs/1604.06174.

    Usage:
        model = resnet18()
        input_sample = torch.rand(4, 3, 224, 224)
        gm = symbolic_trace(model)
        MetaInfoProp(gm).run(input_sample)
        gm = chen_sqrtn(gm)

    Args:
        gm (GraphModule): The module to which checkpoints are added.
    """
    gm.graph.lint()    # make sure nodes are in topological order
    k = int(len(gm.graph.nodes)**0.5)    # take approximately sqrt(n) checkpoints
    for idx, n in enumerate(gm.graph.nodes):
        if (idx + 1) % k == 0:
            setattr(n, 'activation_checkpoint', str((idx + 1) // k))
    gm.recompile()
    return gm
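To make the two placement rules concrete, here is a self-contained toy sketch (editorial, not part of the commit) that re-implements just the index arithmetic of both solvers on a linear chain; the helper names, the uniform activation sizes, and the budget are made-up values for illustration:

    def sqrtn_marks(n: int):
        # Indices the chen_sqrtn rule would label: every k-th node, k ~ sqrt(n).
        k = int(n**0.5)
        return [idx for idx in range(n) if (idx + 1) % k == 0]


    def greedy_marks(activation_sizes, B):
        # Indices the chen_greedy rule would label under activation budget B.
        temp, x, marks = 0, 0, []
        for idx, size in enumerate(activation_sizes):
            temp += size
            if temp > B:
                x += size    # x accumulates the sizes of kept checkpoints
                temp = x     # restart the running segment total from x
                marks.append(idx)
        return marks


    print(sqrtn_marks(16))                # [3, 7, 11, 15]: 4 segments of ~sqrt(16) nodes
    print(greedy_marks([10] * 16, B=35))  # [3, 6, 8, 9, 10, 11, 12, 13, 14, 15]

The second printout also shows a quirk of the committed greedy loop: once the kept checkpoints alone (the running `x`) exceed the budget, every remaining node gets marked.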
@@ -0,0 +1,40 @@
import torch
import torchvision.models as tm
from functools import partial

from torch.fx import GraphModule

from colossalai.fx import ColoTracer
from colossalai.fx.passes.algorithms import chen_greedy, chen_sqrtn
from colossalai.fx.passes.meta_info_prop import MetaInfoProp

SOLVERS = [partial(chen_greedy, B=1024 * 1024 * 64), chen_sqrtn]


def _is_activation_checkpoint_available(gm: GraphModule):
    for n in gm.graph.nodes:
        if hasattr(n, 'activation_checkpoint') and getattr(n, 'activation_checkpoint') is not None:
            return True
    return False


def test_ckpt_solver():
    MODEL_LIST = [tm.resnet18, tm.densenet121]

    torch.backends.cudnn.deterministic = True

    tracer = ColoTracer()
    data = torch.rand(1, 3, 224, 224)

    for solver in SOLVERS:
        for model_cls in MODEL_LIST:
            model = model_cls()
            graph = tracer.trace(root=model)
            gm = GraphModule(model, graph, model.__class__.__name__)
            MetaInfoProp(gm).run(data)
            gm = solver(gm)
            assert _is_activation_checkpoint_available(
                gm), f"Solver {solver} did not annotate {model_cls} with any activation checkpoints"
            assert torch.allclose(gm(data), model(data))


if __name__ == '__main__':
    test_ckpt_solver()
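Beyond the boolean check above, a natural follow-up assertion (a hypothetical helper, not in this commit) is to count how many distinct checkpoint regions a solver produced, since both solvers write string labels into the `activation_checkpoint` attribute of each marked node:

    from torch.fx import GraphModule

    def count_checkpoint_regions(gm: GraphModule) -> int:
        """Number of distinct 'activation_checkpoint' labels attached by a solver."""
        labels = {getattr(n, 'activation_checkpoint', None) for n in gm.graph.nodes}
        labels.discard(None)
        return len(labels)

    # e.g. after `gm = chen_sqrtn(gm)` on a traced resnet18:
    # assert count_checkpoint_regions(gm) >= 1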