From d40a9392ba2b7d60a1874b41b008543cd9ad80b0 Mon Sep 17 00:00:00 2001
From: Super Daniel <78588128+super-dainiu@users.noreply.github.com>
Date: Fri, 12 Aug 2022 11:28:50 +0800
Subject: [PATCH] [fx] fix the false interpretation of algorithm 3 in
 https://arxiv.org/abs/1604.06174. (#1446)

* [fx] modify the calculation of node_size in MetaInfoProp for activation checkpointing usages

* [fx] modify the calculation of node_size in MetaInfoProp for activation checkpointing usages

* [fx] modify the calculation of node_size in MetaInfoProp for activation checkpointing usages

* [fx] activation checkpointing using Chen strategies.

* [fx] add test for ckpt_solver_chen

* mend

* [fx] add vanilla activation checkpoint search with test on resnet and densenet

* [fx] add vanilla activation checkpoint search with test on resnet and densenet

* [fx] add a namespace code for solver_chen.

* [fx] fix the false interpretation of algorithm 3 in https://arxiv.org/abs/1604.06174.

* [fx] fix lowercase naming conventions.
---
 .../fx/passes/algorithms/ckpt_solver_chen.py | 64 +++++++++++++------
 .../test_ckpt_torchvision.py                 | 20 +++++-
 2 files changed, 62 insertions(+), 22 deletions(-)

diff --git a/colossalai/fx/passes/algorithms/ckpt_solver_chen.py b/colossalai/fx/passes/algorithms/ckpt_solver_chen.py
index d28e6fa1a..046b165a6 100644
--- a/colossalai/fx/passes/algorithms/ckpt_solver_chen.py
+++ b/colossalai/fx/passes/algorithms/ckpt_solver_chen.py
@@ -1,45 +1,71 @@
+from typing import Set, Tuple
 import torch
 from torch.fx import GraphModule
+import math
 
 __all__ = ['chen_greedy', 'chen_sqrtn']
 
 
-def chen_greedy(gm: GraphModule, B: int):
+def chen_greedy(gm: GraphModule) -> GraphModule:
     """
     This is the simple implementation of Algorithm 3 in https://arxiv.org/abs/1604.06174.
+    Note that this algorithm targets memory optimization only, using the techniques in Appendix A.
 
     Usage:
-        B = 5 * 1024 * 1024 * 1024    # An approximate memory budget of 5GB
         model = resnet18()
         input_sample = torch.rand(4, 3, 224, 224)
         gm = symbolic_trace(model)
         MetaInfoProp(gm).run(input_sample)
-        gm = chen_greedy(gm, B)
+        gm = chen_greedy(gm)
 
     Args:
         gm (GraphModule): The module to add checkpoints
-        B (int): The approximate memory budget for this module.
     """
+
+    def grid_search(num_grids: int = 6) -> Set:
+        """
+        Search for a ckpt strategy with b = 0, then run the allocation algorithm again with b = √(xy).
+        Grid-search b over [√2/2 * b, √2 * b] in num_grids steps to find ckpt_opt, as in Appendix A.
+        """
+        _, b_approx = run_chen_greedy(0)
+        b_min, b_max = math.floor(b_approx / math.sqrt(2)), math.ceil(b_approx * math.sqrt(2))
+        b_opt = math.inf
+        for b in range(b_min, b_max, (b_max - b_min) // num_grids):
+            ckpt, b_approx = run_chen_greedy(b)
+            if b_approx < b_opt:
+                b_opt = b_approx
+                ckpt_opt = ckpt
+        return ckpt_opt
+
+    def run_chen_greedy(b: int = 0) -> Tuple[Set, int]:
+        """
+        This is a simple implementation of Algorithm 3 in https://arxiv.org/abs/1604.06174.
+ """ + ckpt = set() + temp = 0 + x = 0 + y = 0 + for (idx, n) in enumerate(gm.graph.nodes): + temp += getattr(n, 'activation_size') + y = max(y, temp) + if temp > b: + x += getattr(n, 'activation_size') + temp = 0 + ckpt.add(idx) + return ckpt, math.floor(math.sqrt(x * y)) + gm.graph.lint() # make sure nodes are in topological order - temp = 0 - x = 0 - idx = 0 - budget = B - for n in gm.graph.nodes: - B -= getattr(n, 'param_size') - assert B > 0, f'The memory budget {budget / 1024 ** 3:.2f} GB is not enough for model parameters of {gm}' - for n in gm.graph.nodes: - temp += getattr(n, 'activation_size') - if temp > B: - x += getattr(n, 'activation_size') - temp = x - setattr(n, 'activation_checkpoint', str(idx)) - idx += 1 + ckpt = grid_search(num_grids=6) + i = 0 + for idx, n in enumerate(gm.graph.nodes): + if idx in ckpt: + setattr(n, 'activation_checkpoint', str(i)) + i += 1 gm.recompile() return gm -def chen_sqrtn(gm: GraphModule): +def chen_sqrtn(gm: GraphModule) -> GraphModule: """ This is the theoretical optimal strategy in https://arxiv.org/abs/1604.06174. diff --git a/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py b/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py index 4bf3128c6..169b4bcb6 100644 --- a/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py +++ b/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py @@ -1,13 +1,13 @@ +from ctypes import Union from colossalai.fx.passes.algorithms import chen_greedy, chen_sqrtn import torch import torchvision.models as tm from colossalai.fx import ColoTracer from torch.fx import GraphModule from colossalai.fx.passes.meta_info_prop import MetaInfoProp -from functools import partial import pytest -SOLVERS = [partial(chen_greedy, B=1024 * 1024 * 64), chen_sqrtn] +SOLVERS = [chen_greedy, chen_sqrtn] def _is_activation_checkpoint_available(gm: GraphModule): @@ -16,6 +16,13 @@ def _is_activation_checkpoint_available(gm: GraphModule): return True +def _is_all_gradient_close(m: torch.nn.Module, gm: GraphModule): + for m_p, gm_p in zip(m.parameters(), gm.parameters()): + if not torch.allclose(m_p, gm_p): + return False + return True + + def test_ckpt_solver(): MODEL_LIST = [tm.resnet18, tm.densenet121] @@ -23,17 +30,24 @@ def test_ckpt_solver(): tracer = ColoTracer() data = torch.rand(1, 3, 224, 224) + label = torch.rand(1, 1000) for solver in SOLVERS: for model_cls in MODEL_LIST: model = model_cls() + criterion = torch.nn.MSELoss() graph = tracer.trace(root=model) gm = GraphModule(model, graph, model.__class__.__name__) MetaInfoProp(gm).run(data) gm = solver(gm) assert _is_activation_checkpoint_available( gm), f"Solver {solver} did not annotate {model_cls} with any activation checkpoints" - assert torch.allclose(gm(data), model(data)) + loss = criterion(model(data), label) + loss.backward() + loss = criterion(gm(data), label) + loss.backward() + assert _is_all_gradient_close(model, + gm), f'Solver {solver} did not work correctly in backward pass on {model_cls}' if __name__ == '__main__':