diff --git a/colossalai/auto_parallel/checkpoint/ckpt_solver_base.py b/colossalai/auto_parallel/checkpoint/ckpt_solver_base.py index 63eff31b2..ecccef8d7 100644 --- a/colossalai/auto_parallel/checkpoint/ckpt_solver_base.py +++ b/colossalai/auto_parallel/checkpoint/ckpt_solver_base.py @@ -5,8 +5,12 @@ from typing import Any, List import torch from torch.fx import Graph, Node +from colossalai.auto_parallel.passes.runtime_apply_pass import ( + runtime_apply, + runtime_apply_for_iterable_object, + runtime_comm_spec_apply, +) from colossalai.fx.codegen.activation_checkpoint_codegen import ActivationCheckpointCodeGen -from colossalai.fx.profiler.memory_utils import is_inplace __all___ = ['CheckpointSolverBase'] @@ -131,7 +135,23 @@ class CheckpointSolverBase(ABC): bool """ - return not sum([v for _, v in deps.items()]) and not any(map(is_inplace, n.users)) + def _is_inplace(n: Node): + """Get the inplace argument from torch.fx.Node + """ + inplace = False + if n.op == "call_function": + inplace = n.kwargs.get("inplace", False) + elif n.op == "call_module": + inplace = getattr(n.graph.owning_module.get_submodule(n.target), "inplace", False) + return inplace + + def _is_shape_consistency(n: Node): + """Check if this node is a shape-consistency node (i.e. ``runtime_apply``, ``runtime_apply_for_iterable_object`` or ``runtime_comm_spec_apply``) + """ + return n.target in [runtime_apply, runtime_apply_for_iterable_object, runtime_comm_spec_apply] + + return not sum([v for _, v in deps.items()]) and not any(map(_is_inplace, n.users)) and not any( + map(_is_shape_consistency, n.users)) # make sure that item in cnode is valid if self.cnode: