mirror of https://github.com/hpcaitech/ColossalAI
[autockpt] linearize / merge shape-consistency nodes. (#2271)
* [autockpt] make it work. * [autockpt] linearize / merge shape-consistency nodes.pull/2279/head^2
parent
5c2ef9fc76
commit
b0d21d0c4f
|
@ -5,8 +5,12 @@ from typing import Any, List
|
|||
import torch
|
||||
from torch.fx import Graph, Node
|
||||
|
||||
from colossalai.auto_parallel.passes.runtime_apply_pass import (
|
||||
runtime_apply,
|
||||
runtime_apply_for_iterable_object,
|
||||
runtime_comm_spec_apply,
|
||||
)
|
||||
from colossalai.fx.codegen.activation_checkpoint_codegen import ActivationCheckpointCodeGen
|
||||
from colossalai.fx.profiler.memory_utils import is_inplace
|
||||
|
||||
__all___ = ['CheckpointSolverBase']
|
||||
|
||||
|
@ -131,7 +135,23 @@ class CheckpointSolverBase(ABC):
|
|||
bool
|
||||
"""
|
||||
|
||||
return not sum([v for _, v in deps.items()]) and not any(map(is_inplace, n.users))
|
||||
def _is_inplace(n: Node):
|
||||
"""Get the inplace argument from torch.fx.Node
|
||||
"""
|
||||
inplace = False
|
||||
if n.op == "call_function":
|
||||
inplace = n.kwargs.get("inplace", False)
|
||||
elif n.op == "call_module":
|
||||
inplace = getattr(n.graph.owning_module.get_submodule(n.target), "inplace", False)
|
||||
return inplace
|
||||
|
||||
def _is_shape_consistency(n: Node):
|
||||
"""Check if this node is shape-consistency node (i.e. ``runtime_apply`` or ``runtime_apply_for_iterable_object``)
|
||||
"""
|
||||
return n.target in [runtime_apply, runtime_apply_for_iterable_object, runtime_comm_spec_apply]
|
||||
|
||||
return not sum([v for _, v in deps.items()]) and not any(map(_is_inplace, n.users)) and not any(
|
||||
map(_is_shape_consistency, n.users))
|
||||
|
||||
# make sure that item in cnode is valid
|
||||
if self.cnode:
|
||||
|
|
Loading…
Reference in New Issue