mirror of https://github.com/hpcaitech/ColossalAI
[fx] refactor memory utils and extend shard utils. (#1754)
* [fx] change memory.py to memory_utils.py. * [fx] add shard utils. * [fx] fix import. * [fx] check code style. * [fx] add comment. * [autoparallel] first move. * [fx] add time computations.pull/1766/head
parent
63f250bbd4
commit
0584654c79
|
@ -1,7 +1,9 @@
|
||||||
|
import math
|
||||||
from typing import List, Set, Tuple
|
from typing import List, Set, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch.fx import GraphModule, Node
|
from torch.fx import GraphModule, Node
|
||||||
import math
|
|
||||||
from colossalai.fx.profiler import calculate_fwd_in, calculate_fwd_tmp
|
from colossalai.fx.profiler import calculate_fwd_in, calculate_fwd_tmp
|
||||||
|
|
||||||
__all__ = ['chen_greedy']
|
__all__ = ['chen_greedy']
|
||||||
|
|
|
@ -1,15 +1,17 @@
|
||||||
|
import math
|
||||||
import sys
|
import sys
|
||||||
from typing import List, Tuple
|
from typing import List, Tuple
|
||||||
from colossalai.fx.profiler.memory import calculate_fwd_in
|
|
||||||
from torch.fx import Node
|
from torch.fx import Node
|
||||||
from colossalai.fx.graph_module import ColoGraphModule
|
|
||||||
from colossalai.fx.profiler import activation_size, parameter_size, calculate_fwd_out, calculate_fwd_tmp
|
|
||||||
import math
|
|
||||||
from .linearize import linearize
|
|
||||||
from .operation import ForwardCheck, ForwardEnable, ForwardNograd, Backward, Loss, Chain, Sequence, Function
|
|
||||||
from colossalai.fx.codegen.activation_checkpoint_codegen import _find_nested_ckpt_regions
|
from colossalai.fx.codegen.activation_checkpoint_codegen import _find_nested_ckpt_regions
|
||||||
|
from colossalai.fx.graph_module import ColoGraphModule
|
||||||
|
from colossalai.fx.profiler import activation_size, calculate_fwd_out, calculate_fwd_tmp, parameter_size
|
||||||
from colossalai.logging import get_dist_logger
|
from colossalai.logging import get_dist_logger
|
||||||
|
|
||||||
|
from .linearize import linearize
|
||||||
|
from .operation import Backward, Chain, ForwardCheck, ForwardEnable, ForwardNograd, Function, Loss, Sequence
|
||||||
|
|
||||||
# global vairable to indicate whether the solver is failed
|
# global vairable to indicate whether the solver is failed
|
||||||
SOLVER_FAILED = False
|
SOLVER_FAILED = False
|
||||||
|
|
||||||
|
@ -372,8 +374,8 @@ def solver_rotor(gm: ColoGraphModule,
|
||||||
|
|
||||||
# build module if module not found
|
# build module if module not found
|
||||||
except ModuleNotFoundError:
|
except ModuleNotFoundError:
|
||||||
import subprocess
|
|
||||||
import os
|
import os
|
||||||
|
import subprocess
|
||||||
logger.info("dynamic_programs_C_version hasn't been built! Building library...", ranks=[0])
|
logger.info("dynamic_programs_C_version hasn't been built! Building library...", ranks=[0])
|
||||||
this_dir = os.path.dirname(os.path.abspath(__file__))
|
this_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
result = subprocess.Popen(
|
result = subprocess.Popen(
|
||||||
|
|
|
@ -3,11 +3,12 @@ from typing import Any, Dict, List, NamedTuple, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.fx
|
import torch.fx
|
||||||
from colossalai.fx._compatibility import compatibility
|
|
||||||
from colossalai.fx.profiler import (GraphInfo, profile_function, profile_method, profile_module)
|
|
||||||
from torch.fx.node import Argument, Node, Target
|
from torch.fx.node import Argument, Node, Target
|
||||||
from torch.utils._pytree import tree_flatten
|
from torch.utils._pytree import tree_flatten
|
||||||
|
|
||||||
|
from colossalai.fx._compatibility import compatibility
|
||||||
|
from colossalai.fx.profiler import GraphInfo, profile_function, profile_method, profile_module
|
||||||
|
|
||||||
|
|
||||||
@compatibility(is_backward_compatible=True)
|
@compatibility(is_backward_compatible=True)
|
||||||
class ConcreteInfoProp(torch.fx.Interpreter):
|
class ConcreteInfoProp(torch.fx.Interpreter):
|
||||||
|
|
|
@ -3,12 +3,21 @@ from typing import Any, Dict, List, NamedTuple, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.fx
|
import torch.fx
|
||||||
from colossalai.fx._compatibility import compatibility
|
|
||||||
from colossalai.fx.profiler import (GraphInfo, activation_size, calculate_fwd_in, calculate_fwd_out, calculate_fwd_tmp,
|
|
||||||
profile_function, profile_method, profile_module)
|
|
||||||
from torch.fx.node import Argument, Node, Target
|
from torch.fx.node import Argument, Node, Target
|
||||||
from torch.utils._pytree import tree_map
|
from torch.utils._pytree import tree_map
|
||||||
|
|
||||||
|
from colossalai.fx._compatibility import compatibility
|
||||||
|
from colossalai.fx.profiler import (
|
||||||
|
GraphInfo,
|
||||||
|
activation_size,
|
||||||
|
calculate_fwd_in,
|
||||||
|
calculate_fwd_out,
|
||||||
|
calculate_fwd_tmp,
|
||||||
|
profile_function,
|
||||||
|
profile_method,
|
||||||
|
profile_module,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@compatibility(is_backward_compatible=True)
|
@compatibility(is_backward_compatible=True)
|
||||||
class TensorMetadata(NamedTuple):
|
class TensorMetadata(NamedTuple):
|
||||||
|
|
|
@ -1,12 +1,18 @@
|
||||||
from .._compatibility import is_compatible_with_meta
|
from .._compatibility import is_compatible_with_meta
|
||||||
|
|
||||||
if is_compatible_with_meta():
|
if is_compatible_with_meta():
|
||||||
from .memory import calculate_fwd_in, calculate_fwd_out, calculate_fwd_tmp
|
|
||||||
from .opcount import flop_mapping
|
from .opcount import flop_mapping
|
||||||
from .profiler import profile_function, profile_method, profile_module
|
from .profiler import profile_function, profile_method, profile_module
|
||||||
|
from .shard_utils import (
|
||||||
|
calculate_bwd_time,
|
||||||
|
calculate_fwd_in,
|
||||||
|
calculate_fwd_out,
|
||||||
|
calculate_fwd_time,
|
||||||
|
calculate_fwd_tmp,
|
||||||
|
)
|
||||||
from .tensor import MetaTensor
|
from .tensor import MetaTensor
|
||||||
else:
|
else:
|
||||||
from .experimental import meta_profiler_function, meta_profiler_module, profile_function, profile_method, profile_module, calculate_fwd_in, calculate_fwd_tmp, calculate_fwd_out
|
from .experimental import meta_profiler_function, meta_profiler_module, profile_function, profile_method, profile_module, calculate_fwd_in, calculate_fwd_tmp, calculate_fwd_out
|
||||||
|
|
||||||
from .dataflow import GraphInfo
|
from .dataflow import GraphInfo
|
||||||
from .memory import activation_size, is_inplace, parameter_size
|
from .memory_utils import activation_size, is_inplace, parameter_size
|
||||||
|
|
|
@ -6,7 +6,7 @@ from typing import Dict, List
|
||||||
from torch.fx import Graph, Node
|
from torch.fx import Graph, Node
|
||||||
|
|
||||||
from .._compatibility import compatibility
|
from .._compatibility import compatibility
|
||||||
from .memory import activation_size, is_inplace
|
from .memory_utils import activation_size, is_inplace
|
||||||
|
|
||||||
|
|
||||||
class Phase(Enum):
|
class Phase(Enum):
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
from .memory import calculate_fwd_in, calculate_fwd_out, calculate_fwd_tmp
|
|
||||||
from .profiler import profile_function, profile_method, profile_module
|
from .profiler import profile_function, profile_method, profile_module
|
||||||
from .profiler_function import *
|
from .profiler_function import *
|
||||||
from .profiler_module import *
|
from .profiler_module import *
|
||||||
from .registry import meta_profiler_function, meta_profiler_module
|
from .registry import meta_profiler_function, meta_profiler_module
|
||||||
|
from .shard_utils import calculate_fwd_in, calculate_fwd_out, calculate_fwd_tmp
|
||||||
|
|
|
@ -5,7 +5,7 @@ import torch
|
||||||
from torch.fx.node import Argument, Target
|
from torch.fx.node import Argument, Target
|
||||||
|
|
||||||
from ..._compatibility import compatibility
|
from ..._compatibility import compatibility
|
||||||
from ..memory import activation_size
|
from ..memory_utils import activation_size
|
||||||
from .constants import INPLACE_METHOD, INPLACE_OPS, NON_INPLACE_METHOD
|
from .constants import INPLACE_METHOD, INPLACE_OPS, NON_INPLACE_METHOD
|
||||||
from .registry import meta_profiler_function, meta_profiler_module
|
from .registry import meta_profiler_function, meta_profiler_module
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,71 @@
|
||||||
|
from typing import Dict, List, Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch.fx import GraphModule, Node
|
||||||
|
|
||||||
|
from .._compatibility import compatibility, is_compatible_with_meta
|
||||||
|
|
||||||
|
__all__ = ['activation_size', 'parameter_size', 'is_inplace']
|
||||||
|
|
||||||
|
|
||||||
|
@compatibility(is_backward_compatible=True)
|
||||||
|
def activation_size(out: Union[torch.Tensor, Dict, List, Tuple, int]) -> int:
|
||||||
|
"""Calculate activation size of a node.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
activation (Union[torch.Tensor, Dict, List, Tuple, int]): The activation of a `torch.nn.Module` or `torch.nn.functional`
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
int: The activation size
|
||||||
|
"""
|
||||||
|
act_size = 0
|
||||||
|
if isinstance(out, torch.Tensor):
|
||||||
|
if out.is_quantized:
|
||||||
|
act_size += out.numel() * torch._empty_affine_quantized([], dtype=out.dtype).element_size()
|
||||||
|
else:
|
||||||
|
act_size += out.numel() * torch.tensor([], dtype=out.dtype).element_size()
|
||||||
|
elif isinstance(out, dict):
|
||||||
|
value_list = [v for _, v in out.items()]
|
||||||
|
act_size += activation_size(value_list)
|
||||||
|
elif isinstance(out, tuple) or isinstance(out, list) or isinstance(out, set):
|
||||||
|
for element in out:
|
||||||
|
act_size += activation_size(element)
|
||||||
|
return act_size
|
||||||
|
|
||||||
|
|
||||||
|
@compatibility(is_backward_compatible=True)
|
||||||
|
def parameter_size(mod: torch.nn.Module) -> int:
|
||||||
|
"""Calculate parameter size of a node.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
mod (torch.nn.Module): The target `torch.nn.Module`
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
int: The parameter size
|
||||||
|
"""
|
||||||
|
param_size = 0
|
||||||
|
for param in mod.parameters():
|
||||||
|
param_size += param.numel() * torch.tensor([], dtype=param.dtype).element_size()
|
||||||
|
return param_size
|
||||||
|
|
||||||
|
|
||||||
|
def is_inplace(n: Node):
|
||||||
|
"""Get the inplace argument from torch.fx.Node
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node (Node): torch.fx.Node
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: indicates whether this op is inplace
|
||||||
|
"""
|
||||||
|
inplace = False
|
||||||
|
if n.op == "call_function":
|
||||||
|
inplace = n.kwargs.get("inplace", False)
|
||||||
|
if is_compatible_with_meta():
|
||||||
|
from .constants import ALIAS_ATEN
|
||||||
|
if n.target in ALIAS_ATEN:
|
||||||
|
inplace = True
|
||||||
|
elif n.op == "call_module":
|
||||||
|
inplace = getattr(n.graph.owning_module.get_submodule(n.target), "inplace", False)
|
||||||
|
|
||||||
|
return inplace
|
|
@ -11,7 +11,7 @@ from torch.utils._pytree import tree_map
|
||||||
from .._compatibility import compatibility
|
from .._compatibility import compatibility
|
||||||
from .constants import ALIAS_ATEN, OUTPUT_SAVED_MOD, OUTPUT_SAVED_OPS
|
from .constants import ALIAS_ATEN, OUTPUT_SAVED_MOD, OUTPUT_SAVED_OPS
|
||||||
from .dataflow import GraphInfo, Phase, autograd_graph_analysis, is_phase
|
from .dataflow import GraphInfo, Phase, autograd_graph_analysis, is_phase
|
||||||
from .memory import activation_size, parameter_size
|
from .memory_utils import activation_size, parameter_size
|
||||||
from .opcount import flop_mapping
|
from .opcount import flop_mapping
|
||||||
from .tensor import MetaTensor
|
from .tensor import MetaTensor
|
||||||
|
|
||||||
|
|
|
@ -1,58 +1,18 @@
|
||||||
from typing import Dict, List, Tuple, Union
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch.fx import GraphModule, Node
|
from torch.fx import Node
|
||||||
|
|
||||||
from .._compatibility import compatibility, is_compatible_with_meta
|
from .._compatibility import compatibility, is_compatible_with_meta
|
||||||
|
from .memory_utils import activation_size
|
||||||
|
|
||||||
if is_compatible_with_meta():
|
if is_compatible_with_meta():
|
||||||
from .constants import OUTPUT_SAVED_MOD, OUTPUT_SAVED_OPS
|
from .constants import OUTPUT_SAVED_MOD, OUTPUT_SAVED_OPS
|
||||||
|
|
||||||
__all__ = [
|
__all__ = ["calculate_fwd_in", "calculate_fwd_tmp", "calculate_fwd_out"]
|
||||||
'activation_size', 'parameter_size', 'is_inplace', "calculate_fwd_in", "calculate_fwd_tmp", "calculate_fwd_out"
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@compatibility(is_backward_compatible=True)
|
|
||||||
def activation_size(out: Union[torch.Tensor, Dict, List, Tuple, int]) -> int:
|
|
||||||
"""Calculate activation size of a node.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
activation (Union[torch.Tensor, Dict, List, Tuple, int]): The activation of a `torch.nn.Module` or `torch.nn.functional`
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
int: The activation size
|
|
||||||
"""
|
|
||||||
act_size = 0
|
|
||||||
if isinstance(out, torch.Tensor):
|
|
||||||
act_size += out.numel() * torch.tensor([], dtype=out.dtype).element_size()
|
|
||||||
elif isinstance(out, dict):
|
|
||||||
value_list = [v for _, v in out.items()]
|
|
||||||
act_size += activation_size(value_list)
|
|
||||||
elif isinstance(out, tuple) or isinstance(out, list) or isinstance(out, set):
|
|
||||||
for element in out:
|
|
||||||
act_size += activation_size(element)
|
|
||||||
return act_size
|
|
||||||
|
|
||||||
|
|
||||||
@compatibility(is_backward_compatible=True)
|
|
||||||
def parameter_size(mod: torch.nn.Module) -> int:
|
|
||||||
"""Calculate parameter size of a node.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
mod (torch.nn.Module): The target `torch.nn.Module`
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
int: The parameter size
|
|
||||||
"""
|
|
||||||
param_size = 0
|
|
||||||
for param in mod.parameters():
|
|
||||||
param_size += param.numel() * torch.tensor([], dtype=param.dtype).element_size()
|
|
||||||
return param_size
|
|
||||||
|
|
||||||
|
|
||||||
|
@compatibility(is_backward_compatible=False)
|
||||||
def calculate_fwd_in(n: Node) -> int:
|
def calculate_fwd_in(n: Node) -> int:
|
||||||
"""A helper function to calculate `fwd_in`
|
"""A helper function to calculate `fwd_in` (with sharding spec)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
n (Node): a node from the graph
|
n (Node): a node from the graph
|
||||||
|
@ -60,11 +20,13 @@ def calculate_fwd_in(n: Node) -> int:
|
||||||
Returns:
|
Returns:
|
||||||
fwd_in (int): the result of `fwd_in`
|
fwd_in (int): the result of `fwd_in`
|
||||||
"""
|
"""
|
||||||
|
# TODO(super-dainiu): should divide the memory by sharding spec
|
||||||
return activation_size(n.meta["fwd_in"])
|
return activation_size(n.meta["fwd_in"])
|
||||||
|
|
||||||
|
|
||||||
|
@compatibility(is_backward_compatible=False)
|
||||||
def calculate_fwd_tmp(n: Node) -> int:
|
def calculate_fwd_tmp(n: Node) -> int:
|
||||||
"""A helper function to calculate `fwd_tmp`
|
"""A helper function to calculate `fwd_tmp` (with sharding spec)
|
||||||
Currently, `torch.nn.ReLU` behaves weirdly, so we have to patch it for accuracy.
|
Currently, `torch.nn.ReLU` behaves weirdly, so we have to patch it for accuracy.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -74,6 +36,7 @@ def calculate_fwd_tmp(n: Node) -> int:
|
||||||
fwd_tmp (int): the result of `fwd_tmp`
|
fwd_tmp (int): the result of `fwd_tmp`
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# TODO(super-dainiu): should divide the memory by sharding spec
|
||||||
def is_relu_like_node(n: Node) -> bool:
|
def is_relu_like_node(n: Node) -> bool:
|
||||||
"""Check if a node is a ReLU-like node.
|
"""Check if a node is a ReLU-like node.
|
||||||
ReLU-like nodes have the following properties:
|
ReLU-like nodes have the following properties:
|
||||||
|
@ -107,8 +70,9 @@ def calculate_fwd_tmp(n: Node) -> int:
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
@compatibility(is_backward_compatible=False)
|
||||||
def calculate_fwd_out(n: Node) -> int:
|
def calculate_fwd_out(n: Node) -> int:
|
||||||
"""A helper function to calculate `fwd_out`
|
"""A helper function to calculate `fwd_out` (with sharding spec)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
n (Node): a node from the graph
|
n (Node): a node from the graph
|
||||||
|
@ -117,6 +81,7 @@ def calculate_fwd_out(n: Node) -> int:
|
||||||
fwd_out (int): the result of `fwd_out`
|
fwd_out (int): the result of `fwd_out`
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# TODO(super-dainiu): should divide the memory by sharding spec
|
||||||
def intersect(a, b):
|
def intersect(a, b):
|
||||||
return {k: a[k] for k in a if k in b}
|
return {k: a[k] for k in a if k in b}
|
||||||
|
|
||||||
|
@ -127,23 +92,23 @@ def calculate_fwd_out(n: Node) -> int:
|
||||||
return activation_size(intersect(fwd_in, fwd_out))
|
return activation_size(intersect(fwd_in, fwd_out))
|
||||||
|
|
||||||
|
|
||||||
def is_inplace(n: Node):
|
def calculate_fwd_time(n: Node) -> float:
|
||||||
"""Get the inplace argument from torch.fx.Node
|
"""A helper function to calculate `fwd_time` (with sharding spec)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
node (Node): torch.fx.Node
|
n (Node): a node from the graph
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
bool: indicates whether this op is inplace
|
fwd_time (float): the result of `fwd_time`
|
||||||
"""
|
"""
|
||||||
inplace = False
|
# TODO(super-dainiu): should divide the time by the number of GPUs as well as TFLOPs
|
||||||
if n.op == "call_function":
|
return n.meta["fwd_flop"]
|
||||||
inplace = n.kwargs.get("inplace", False)
|
|
||||||
if is_compatible_with_meta():
|
|
||||||
from .constants import ALIAS_ATEN
|
|
||||||
if n.target in ALIAS_ATEN:
|
|
||||||
inplace = True
|
|
||||||
elif n.op == "call_module":
|
|
||||||
inplace = getattr(n.graph.owning_module.get_submodule(n.target), "inplace", False)
|
|
||||||
|
|
||||||
return inplace
|
|
||||||
|
def calculate_bwd_time(n: Node) -> float:
|
||||||
|
"""A helper function to calculate `bwd_time` (with sharding spec)
|
||||||
|
Args:
|
||||||
|
n (Node): a node from the graph
|
||||||
|
Returns:
|
||||||
|
bwd_time (float): the result of `bwd_time`
|
||||||
|
"""
|
||||||
|
# TODO(super-dainiu): should divide the time by the number of GPUs as well as TFLOPs
|
||||||
|
return n.meta["bwd_flop"]
|
|
@ -1,7 +1,5 @@
|
||||||
from colossalai.fx.profiler.memory import activation_size
|
|
||||||
import torch
|
import torch
|
||||||
from torch.fx import Node, Graph
|
from torch.fx import Graph, Node
|
||||||
from torch.fx.graph import _Namespace
|
|
||||||
from torch.utils._pytree import tree_map
|
from torch.utils._pytree import tree_map
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -3,13 +3,14 @@ from typing import Optional, Tuple, Union
|
||||||
import torch
|
import torch
|
||||||
import torch.fx
|
import torch.fx
|
||||||
import torchvision.models as tm
|
import torchvision.models as tm
|
||||||
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
|
|
||||||
from colossalai.fx.profiler import (calculate_fwd_out, calculate_fwd_tmp, is_compatible_with_meta, parameter_size)
|
|
||||||
from colossalai.fx.tracer.tracer import ColoTracer
|
|
||||||
from colossalai.testing.pytest_wrapper import run_on_environment_flag
|
|
||||||
from gpt_utils import gpt2_medium, gpt2_xl
|
from gpt_utils import gpt2_medium, gpt2_xl
|
||||||
from torch.fx import symbolic_trace
|
from torch.fx import symbolic_trace
|
||||||
|
|
||||||
|
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
|
||||||
|
from colossalai.fx.profiler import calculate_fwd_out, calculate_fwd_tmp, is_compatible_with_meta, parameter_size
|
||||||
|
from colossalai.fx.tracer.tracer import ColoTracer
|
||||||
|
from colossalai.testing.pytest_wrapper import run_on_environment_flag
|
||||||
|
|
||||||
if is_compatible_with_meta():
|
if is_compatible_with_meta():
|
||||||
from colossalai.fx.profiler import MetaTensor
|
from colossalai.fx.profiler import MetaTensor
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue