diff --git a/colossalai/fx/passes/algorithms/__init__.py b/colossalai/fx/passes/algorithms/__init__.py
index 943fbd867..bf6f9eb28 100644
--- a/colossalai/fx/passes/algorithms/__init__.py
+++ b/colossalai/fx/passes/algorithms/__init__.py
@@ -1 +1 @@
-from .ckpt_solver_chen import chen_greedy, chen_sqrtn
+from .ckpt_solver_chen import chen_greedy
diff --git a/colossalai/fx/passes/algorithms/ckpt_solver_chen.py b/colossalai/fx/passes/algorithms/ckpt_solver_chen.py
index 8b404e3a6..5f665aae5 100644
--- a/colossalai/fx/passes/algorithms/ckpt_solver_chen.py
+++ b/colossalai/fx/passes/algorithms/ckpt_solver_chen.py
@@ -1,16 +1,33 @@
 from typing import List, Set, Tuple
 import torch
-from torch.fx import GraphModule
+from torch.fx import GraphModule, Node
 import math
 
-__all__ = ['chen_greedy', 'chen_sqrtn']
+__all__ = ['chen_greedy']
+CKPT_OP = ['call_module', 'call_method', 'call_function', 'get_attr']
 
 
 def _all_potential_ckpt_nodes(gm: GraphModule) -> List:
+    """
+    In most existing frameworks of activation checkpointing, the forward graph is assumed to be linearized.
+    """
+
+    def is_sink():
+        """
+        If we can free all pending memory when executing a certain node, it is a sink.
+        """
+        return not sum((v for k, v in deps.items()))
+
+    deps = {}
     ckpt_nodes = []
     for n in gm.graph.nodes:
-        if n.op == 'call_module':
+        for n_par in n._input_nodes:
+            deps[n_par] -= 1    # free memory and dependencies
+
+        # We can only put act_ckpt on these nodes
+        if n.op in CKPT_OP and is_sink():
             ckpt_nodes.append(n)
+        deps[n] = len(n.users)    # add dependencies for future executions
     return ckpt_nodes
 
 
@@ -71,32 +88,7 @@ def chen_greedy(gm: GraphModule) -> GraphModule:
     for i, seg in enumerate(ckpt):
         for idx in range(*seg):
             n = node_list[idx]
-            if n.op in ['call_module', 'call_method', 'call_function']:
-                setattr(n, 'activation_checkpoint', str(i))
-    gm.recompile()
-    return gm
-
-
-def chen_sqrtn(gm: GraphModule) -> GraphModule:
-    """
-    This is the theoretical optimal strategy in https://arxiv.org/abs/1604.06174.
-
-    Usage:
-        model = resnet18()
-        input_sample = torch.rand(4, 3, 224, 224)
-        gm = symbolic_trace(model)
-        MetaInfoProp(gm).run(input_sample)
-        gm = chen_sqrtn(gm)
-
-    Args:
-        gm (GraphModule): The module to add checkpoints
-    """
-    gm.graph.lint()    # make sure nodes are in topological order
-    k = int(len(gm.graph.nodes)**0.5)    # take approximately sqrt(n) checkpoints
-    for idx, n in enumerate(gm.graph.nodes):
-        # We should not add act_ckpt to the placeholder
-        # The last segment should not be checkpointed
-        if n.op != 'placeholder' and (idx + 1) // k < k:
-            setattr(n, 'activation_checkpoint', str((idx + 1) // k))
+            if n.op in CKPT_OP:
+                setattr(n, 'activation_checkpoint', i)
     gm.recompile()
     return gm
diff --git a/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py b/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py
index 1772c2840..e57fa5f12 100644
--- a/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py
+++ b/tests/test_fx/test_ckpt_solvers/test_ckpt_torchvision.py
@@ -1,5 +1,6 @@
 from typing import Callable
 import copy
+import re
 import torch
 import torch.multiprocessing as mp
 import torchvision.models as tm
@@ -7,7 +8,7 @@ from torch.fx import GraphModule
 import colossalai
 from colossalai.fx import ColoTracer
 from colossalai.fx.passes.meta_info_prop import MetaInfoProp
-from colossalai.fx.passes.algorithms import chen_greedy, chen_sqrtn
+from colossalai.fx.passes.algorithms import chen_greedy
 from colossalai.utils import free_port
 from colossalai.core import global_context as gpc
 import pytest
@@ -20,7 +21,7 @@ except:
     from colossalai.fx.codegen import python_code_with_activation_checkpoint
     with_codegen = False
 
-SOLVERS = [chen_greedy, chen_sqrtn]
+SOLVERS = [chen_greedy]
 
 
 def _is_activation_checkpoint_available(gm: GraphModule):
@@ -36,6 +37,16 @@ def _is_all_gradient_close(m: torch.nn.Module, gm: GraphModule):
     return True
 
 
+def _is_graph_linearized(gm: GraphModule):
+    code = gm.code
+    # find patterns like r'    return output_1, output_2', which is not expected in a linearized graph
+    pattern = re.compile(r'    return [a-zA-Z0-9_]+(, [a-zA-Z0-9_]+)+')
+    if pattern.findall(code):
+        return False
+    else:
+        return True
+
+
 def check_backward_consistency(m: torch.nn.Module, gm: GraphModule, solver: Callable[[GraphModule], GraphModule],
                                model_cls: Callable[[], torch.nn.Module]):
     criterion = torch.nn.MSELoss()
@@ -66,12 +77,13 @@ def _run_ckpt_solver(rank):
             codegen = ActivationCheckpointCodeGen()
             gm.graph.set_codegen(codegen)
             gm = solver(gm)
+            assert _is_graph_linearized(gm), f"Solver {solver} did not solve {model_cls} in a linearized manner."
             assert _is_activation_checkpoint_available(
                 gm), f"Solver {solver} did not annotate {model_cls} with any activation checkpoints"
             check_backward_consistency(m, gm, solver, model_cls)
+    gpc.destroy()
 
 
-@pytest.mark.skip
 @pytest.mark.skipif(not with_codegen, reason='torch version is lower than 1.12.0')
 def test_ckpt_solver():
     mp.spawn(_run_ckpt_solver, nprocs=1)
@@ -94,12 +106,13 @@ def _run_ckpt_solver_torch11(rank):
             MetaInfoProp(gm).run(data)
             gm.graph._python_code = python_code_with_activation_checkpoint.__get__(graph)
             gm = solver(gm)
+            assert _is_graph_linearized(gm), f"Solver {solver} did not solve {model_cls} in a linearized manner."
             assert _is_activation_checkpoint_available(
                 gm), f"Solver {solver} did not annotate {model_cls} with any activation checkpoints"
             check_backward_consistency(m, gm, solver, model_cls)
+    gpc.destroy()
 
 
-@pytest.mark.skip
 @pytest.mark.skipif(with_codegen, reason='torch version is equal to or higher than 1.12.0')
 def test_ckpt_solver_torch11():
     mp.spawn(_run_ckpt_solver_torch11, nprocs=1)
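
Note on the new candidate search in ckpt_solver_chen.py: the sink test relies only on torch.fx node attributes (_input_nodes, users, op), so the dependency-counting idea can be reproduced outside ColossalAI. Below is a minimal standalone sketch, not part of this PR; ToyNet and the sinks list are made-up names for illustration, and Node._input_nodes is technically a private torch.fx attribute (the solver itself uses it too).

import torch
from torch.fx import symbolic_trace

CKPT_OP = ['call_module', 'call_method', 'call_function', 'get_attr']


class ToyNet(torch.nn.Module):
    """A toy residual block: the tensor `y` stays live across `lin2`."""

    def __init__(self):
        super().__init__()
        self.lin1 = torch.nn.Linear(4, 4)
        self.lin2 = torch.nn.Linear(4, 4)

    def forward(self, x):
        y = self.lin1(x)
        return self.lin2(y) + y


gm = symbolic_trace(ToyNet())
deps, sinks = {}, []
for n in gm.graph.nodes:            # fx graph nodes iterate in topological order
    for n_par in n._input_nodes:    # executing n consumes one pending use of each input
        deps[n_par] -= 1
    if n.op in CKPT_OP and not sum(deps.values()):
        sinks.append(n)             # nothing is live across this node: a valid split point
    deps[n] = len(n.users)          # n's output stays alive until every user has run

print(sinks)    # [lin1, add]: lin2 is rejected because the residual y is still live across it

This is why the solver can treat the graph as linear even for models with skip connections such as resnet18: checkpoint segments may only begin and end at sinks.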
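
The _is_graph_linearized heuristic in the test can be sanity-checked in the same spirit; a quick illustration of what the regex accepts and rejects (the example strings mimic fx-generated code and are not from this PR):

import re

pattern = re.compile(r'    return [a-zA-Z0-9_]+(, [a-zA-Z0-9_]+)+')
assert not pattern.findall('    return output_1')          # one live output: linearized
assert pattern.findall('    return output_1, output_2')    # several live outputs: not linearized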