[autockpt] linearize / merge shape-consistency nodes. (#2271)

* [autockpt] make it work.

* [autockpt] linearize / merge shape-consistency nodes.
pull/2279/head^2
Super Daniel 2 years ago committed by GitHub
parent 5c2ef9fc76
commit b0d21d0c4f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -5,8 +5,12 @@ from typing import Any, List
import torch import torch
from torch.fx import Graph, Node from torch.fx import Graph, Node
from colossalai.auto_parallel.passes.runtime_apply_pass import (
runtime_apply,
runtime_apply_for_iterable_object,
runtime_comm_spec_apply,
)
from colossalai.fx.codegen.activation_checkpoint_codegen import ActivationCheckpointCodeGen from colossalai.fx.codegen.activation_checkpoint_codegen import ActivationCheckpointCodeGen
from colossalai.fx.profiler.memory_utils import is_inplace
__all___ = ['CheckpointSolverBase'] __all___ = ['CheckpointSolverBase']
@ -131,7 +135,23 @@ class CheckpointSolverBase(ABC):
bool bool
""" """
return not sum([v for _, v in deps.items()]) and not any(map(is_inplace, n.users)) def _is_inplace(n: Node):
"""Get the inplace argument from torch.fx.Node
"""
inplace = False
if n.op == "call_function":
inplace = n.kwargs.get("inplace", False)
elif n.op == "call_module":
inplace = getattr(n.graph.owning_module.get_submodule(n.target), "inplace", False)
return inplace
def _is_shape_consistency(n: Node):
"""Check if this node is shape-consistency node (i.e. ``runtime_apply`` or ``runtime_apply_for_iterable_object``)
"""
return n.target in [runtime_apply, runtime_apply_for_iterable_object, runtime_comm_spec_apply]
return not sum([v for _, v in deps.items()]) and not any(map(_is_inplace, n.users)) and not any(
map(_is_shape_consistency, n.users))
# make sure that item in cnode is valid # make sure that item in cnode is valid
if self.cnode: if self.cnode:

Loading…
Cancel
Save