@@ -5,8 +5,12 @@ from typing import Any, List
 
 import torch
 from torch.fx import Graph, Node
 
+from colossalai.auto_parallel.passes.runtime_apply_pass import (
+    runtime_apply,
+    runtime_apply_for_iterable_object,
+    runtime_comm_spec_apply,
+)
 from colossalai.fx.codegen.activation_checkpoint_codegen import ActivationCheckpointCodeGen
-from colossalai.fx.profiler.memory_utils import is_inplace
 
|
|
|
|
 __all__ = ['CheckpointSolverBase']
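The three ``runtime_*`` functions imported above are the call-function targets that the shape-consistency runtime pass leaves in the traced graph; the second hunk matches node targets against them by identity in ``_is_shape_consistency``. The old ``is_inplace`` import from the profiler is dropped because the solver now carries its own local ``_is_inplace`` helper, shown below.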
@@ -131,7 +135,23 @@ class CheckpointSolverBase(ABC):
                 bool
             """
-            return not sum([v for _, v in deps.items()]) and not any(map(is_inplace, n.users))
+
+            def _is_inplace(n: Node):
+                """Get the inplace argument from a ``torch.fx.Node``."""
+                inplace = False
+                if n.op == "call_function":
+                    inplace = n.kwargs.get("inplace", False)
+                elif n.op == "call_module":
+                    inplace = getattr(n.graph.owning_module.get_submodule(n.target), "inplace", False)
+                return inplace
+
|
|
|
|
+            def _is_shape_consistency(n: Node):
+                """Check if this node is a shape-consistency node (i.e. ``runtime_apply``,
+                ``runtime_apply_for_iterable_object`` or ``runtime_comm_spec_apply``)."""
|
|
|
|
+                return n.target in [runtime_apply, runtime_apply_for_iterable_object, runtime_comm_spec_apply]
+
+            return not sum([v for _, v in deps.items()]) and not any(map(_is_inplace, n.users)) and not any(
+                map(_is_shape_consistency, n.users))
 
         # make sure that item in cnode is valid
         if self.cnode:
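For a concrete feel of what the new in-place check catches, here is a minimal, self-contained sketch (not part of the patch): it copies the ``_is_inplace`` logic into a standalone helper and runs it over a traced toy module. The ``Toy`` class and the helper name ``is_inplace_node`` are illustrative assumptions only.

import torch.nn as nn
import torch.nn.functional as F
from torch.fx import Node, symbolic_trace


def is_inplace_node(n: Node) -> bool:
    # Same logic as the patch's nested _is_inplace helper.
    inplace = False
    if n.op == "call_function":
        inplace = n.kwargs.get("inplace", False)
    elif n.op == "call_module":
        inplace = getattr(n.graph.owning_module.get_submodule(n.target), "inplace", False)
    return inplace


class Toy(nn.Module):
    # Hypothetical module mixing in-place and out-of-place ops.
    def __init__(self):
        super().__init__()
        self.act = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.act(x)                 # call_module; submodule has .inplace == True
        x = F.relu(x, inplace=True)     # call_function; inplace passed as a kwarg
        return x + 1                    # call_function; no inplace kwarg


gm = symbolic_trace(Toy())
for node in gm.graph.nodes:
    print(node.op, node.target, is_inplace_node(node))

Both ReLU nodes should report True, while the placeholder, the add and the output report False. Nodes with such users are exactly what the final return expression above excludes from being freeable, alongside users that are shape-consistency nodes.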