|
|
@ -5,8 +5,12 @@ from typing import Any, List
|
|
|
|
import torch
|
|
|
|
import torch
|
|
|
|
from torch.fx import Graph, Node
|
|
|
|
from torch.fx import Graph, Node
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from colossalai.auto_parallel.passes.runtime_apply_pass import (
|
|
|
|
|
|
|
|
runtime_apply,
|
|
|
|
|
|
|
|
runtime_apply_for_iterable_object,
|
|
|
|
|
|
|
|
runtime_comm_spec_apply,
|
|
|
|
|
|
|
|
)
|
|
|
|
from colossalai.fx.codegen.activation_checkpoint_codegen import ActivationCheckpointCodeGen
|
|
|
|
from colossalai.fx.codegen.activation_checkpoint_codegen import ActivationCheckpointCodeGen
|
|
|
|
from colossalai.fx.profiler.memory_utils import is_inplace
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Public API of this module. The dunder must be spelled exactly ``__all__``
# (two underscores on each side) for ``from <module> import *`` to honor it.
__all__ = ['CheckpointSolverBase']
|
|
|
|
|
|
|
|
|
|
|
@ -131,7 +135,23 @@ class CheckpointSolverBase(ABC):
|
|
|
|
bool
|
|
|
|
bool
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
return not sum([v for _, v in deps.items()]) and not any(map(is_inplace, n.users))
|
|
|
|
def _is_inplace(n: Node):
    """Tell whether a ``torch.fx.Node`` is flagged as an in-place operation.

    Args:
        n (Node): the node to inspect.

    Returns:
        bool: for a ``call_function`` node, the truthiness of its ``inplace``
        keyword argument; for a ``call_module`` node, the truthiness of the
        ``inplace`` attribute of the submodule named by ``n.target``. Any
        other kind of node yields ``False``.
    """
    if n.op == "call_function":
        return bool(n.kwargs.get("inplace", False))

    if n.op == "call_module":
        owner = n.graph.owning_module
        # A graph that is not attached to a module has no submodule to look
        # up; report "not in-place" instead of failing on ``None``.
        if owner is None:
            return False
        return bool(getattr(owner.get_submodule(n.target), "inplace", False))

    return False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _is_shape_consistency(n: Node):
    """Tell whether ``n`` targets one of the runtime-apply helpers.

    The node matches when its ``target`` is ``runtime_apply``,
    ``runtime_apply_for_iterable_object`` or ``runtime_comm_spec_apply``.
    """
    shape_consistency_targets = (
        runtime_apply,
        runtime_apply_for_iterable_object,
        runtime_comm_spec_apply,
    )
    return n.target in shape_consistency_targets
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return not sum([v for _, v in deps.items()]) and not any(map(_is_inplace, n.users)) and not any(
|
|
|
|
|
|
|
|
map(_is_shape_consistency, n.users))
|
|
|
|
|
|
|
|
|
|
|
|
# make sure that item in cnode is valid
|
|
|
|
# make sure that item in cnode is valid
|
|
|
|
if self.cnode:
|
|
|
|
if self.cnode:
|
|
|
|