diff --git a/colossalai/fx/passes/algorithms/__init__.py b/colossalai/fx/passes/algorithms/__init__.py
new file mode 100644
index 000000000..943fbd867
--- /dev/null
+++ b/colossalai/fx/passes/algorithms/__init__.py
@@ -0,0 +1 @@
+from .ckpt_solver_chen import chen_greedy, chen_sqrtn
diff --git a/colossalai/fx/passes/algorithms/ckpt_solver_chen.py b/colossalai/fx/passes/algorithms/ckpt_solver_chen.py
new file mode 100644
index 000000000..d28e6fa1a
--- /dev/null
+++ b/colossalai/fx/passes/algorithms/ckpt_solver_chen.py
@@ -0,0 +1,83 @@
+import torch
+from torch.fx import GraphModule
+
+__all__ = ['chen_greedy', 'chen_sqrtn']
+
+
+def chen_greedy(gm: GraphModule, B: int):
+    """
+    This is the simple implementation of Algorithm 3 in https://arxiv.org/abs/1604.06174.
+
+    The parameter sizes of all nodes are subtracted from ``B`` first; the remainder is the
+    activation budget. Nodes are then visited in graph order while their ``activation_size``
+    is accumulated, and each node that pushes the running total above the activation budget
+    is annotated with an ``activation_checkpoint`` index ('0', '1', ...).
+
+    Usage:
+        B = 5 * 1024 * 1024 * 1024    # An approximate memory budget of 5GB
+        model = resnet18()
+        input_sample = torch.rand(4, 3, 224, 224)
+        gm = symbolic_trace(model)
+        MetaInfoProp(gm).run(input_sample)
+        gm = chen_greedy(gm, B)
+
+    Args:
+        gm (GraphModule): The module to add checkpoints. Every node must already carry
+            ``param_size`` and ``activation_size`` attributes (the usage above runs
+            ``MetaInfoProp`` beforehand, presumably to populate them).
+        B (int): The approximate memory budget for this module.
+
+    Returns:
+        GraphModule: The same ``gm``, annotated in place and recompiled.
+
+    Raises:
+        AssertionError: If ``B`` does not exceed the total parameter size of ``gm``.
+    """
+    gm.graph.lint()    # make sure nodes are in topological order
+    act_budget = B - sum(n.param_size for n in gm.graph.nodes)
+    # Raised explicitly (rather than via ``assert``) so the check survives ``python -O``.
+    if act_budget <= 0:
+        raise AssertionError(
+            f'The memory budget {B / 1024 ** 3:.2f} GB is not enough for model parameters of {gm}')
+
+    temp = 0    # activation memory accumulated so far
+    x = 0    # summed activation size of the nodes chosen as checkpoints
+    idx = 0    # index handed to the next checkpoint
+    for n in gm.graph.nodes:
+        temp += n.activation_size
+        if temp > act_budget:
+            x += n.activation_size
+            temp = x    # restart the running total from the checkpointed activations
+            n.activation_checkpoint = str(idx)
+            idx += 1
+    gm.recompile()
+    return gm
+
+
+def chen_sqrtn(gm: GraphModule):
+    """
+    This is the theoretically optimal strategy in https://arxiv.org/abs/1604.06174.
+
+    Every ``k``-th node in graph order is annotated with an ``activation_checkpoint``
+    index, where ``k`` is the integer part of the square root of the node count.
+
+    Usage:
+        model = resnet18()
+        input_sample = torch.rand(4, 3, 224, 224)
+        gm = symbolic_trace(model)
+        MetaInfoProp(gm).run(input_sample)
+        gm = chen_sqrtn(gm)
+
+    Args:
+        gm (GraphModule): The module to add checkpoints
+
+    Returns:
+        GraphModule: The same ``gm``, annotated in place and recompiled.
+    """
+    gm.graph.lint()    # make sure nodes are in topological order
+    k = int(len(gm.graph.nodes)**0.5)    # take approximately sqrt(n) checkpoints
+    for idx, n in enumerate(gm.graph.nodes, start=1):
+        if idx % k == 0:
+            n.activation_checkpoint = str(idx // k)
+    gm.recompile()
+    return gm
diff --git a/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py b/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py
new file mode 100644
index 000000000..4bf3128c6
--- /dev/null
+++ b/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py
@@ -0,0 +1,40 @@
+from colossalai.fx.passes.algorithms import chen_greedy, chen_sqrtn
+import torch
+import torchvision.models as tm
+from colossalai.fx import ColoTracer
+from torch.fx import GraphModule
+from colossalai.fx.passes.meta_info_prop import MetaInfoProp
+from functools import partial
+import pytest
+
+SOLVERS = [partial(chen_greedy, B=1024 * 1024 * 64), chen_sqrtn]
+
+
+def _is_activation_checkpoint_available(gm: GraphModule) -> bool:
+    """Return True if at least one node of ``gm`` carries a non-None ``activation_checkpoint``."""
+    return any(getattr(n, 'activation_checkpoint', None) is not None for n in gm.graph.nodes)
+
+
+def test_ckpt_solver():
+    """Check that every solver annotates each model and leaves the traced module's output unchanged."""
+    MODEL_LIST = [tm.resnet18, tm.densenet121]
+
+    torch.backends.cudnn.deterministic = True
+
+    tracer = ColoTracer()
+    data = torch.rand(1, 3, 224, 224)
+
+    for solver in SOLVERS:
+        for model_cls in MODEL_LIST:
+            model = model_cls()
+            graph = tracer.trace(root=model)
+            gm = GraphModule(model, graph, model.__class__.__name__)
+            MetaInfoProp(gm).run(data)
+            gm = solver(gm)
+            assert _is_activation_checkpoint_available(
+                gm), f"Solver {solver} did not annotate {model_cls} with any activation checkpoints"
+            assert torch.allclose(gm(data), model(data))
+
+
+if __name__ == '__main__':
+    test_ckpt_solver()