From b0d21d0c4f56fa1d54d0f53a0020802a75441909 Mon Sep 17 00:00:00 2001 From: Super Daniel <78588128+super-dainiu@users.noreply.github.com> Date: Tue, 3 Jan 2023 14:54:22 +0800 Subject: [PATCH] [autockpt] linearize / merge shape-consistency nodes. (#2271) * [autockpt] make it work. * [autockpt] linearize / merge shape-consistency nodes. --- .../checkpoint/ckpt_solver_base.py | 24 +++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/colossalai/auto_parallel/checkpoint/ckpt_solver_base.py b/colossalai/auto_parallel/checkpoint/ckpt_solver_base.py index 63eff31b2..ecccef8d7 100644 --- a/colossalai/auto_parallel/checkpoint/ckpt_solver_base.py +++ b/colossalai/auto_parallel/checkpoint/ckpt_solver_base.py @@ -5,8 +5,12 @@ from typing import Any, List import torch from torch.fx import Graph, Node +from colossalai.auto_parallel.passes.runtime_apply_pass import ( + runtime_apply, + runtime_apply_for_iterable_object, + runtime_comm_spec_apply, +) from colossalai.fx.codegen.activation_checkpoint_codegen import ActivationCheckpointCodeGen -from colossalai.fx.profiler.memory_utils import is_inplace __all___ = ['CheckpointSolverBase'] @@ -131,7 +135,23 @@ class CheckpointSolverBase(ABC): bool """ - return not sum([v for _, v in deps.items()]) and not any(map(is_inplace, n.users)) + def _is_inplace(n: Node): + """Get the inplace argument from torch.fx.Node + """ + inplace = False + if n.op == "call_function": + inplace = n.kwargs.get("inplace", False) + elif n.op == "call_module": + inplace = getattr(n.graph.owning_module.get_submodule(n.target), "inplace", False) + return inplace + + def _is_shape_consistency(n: Node): + """Check if this node is shape-consistency node (i.e. ``runtime_apply`` or ``runtime_apply_for_iterable_object``) + """ + return n.target in [runtime_apply, runtime_apply_for_iterable_object, runtime_comm_spec_apply] + + return not sum([v for _, v in deps.items()]) and not any(map(_is_inplace, n.users)) and not any( + map(_is_shape_consistency, n.users)) # make sure that item in cnode is valid if self.cnode: