mirror of https://github.com/hpcaitech/ColossalAI
[fx] add profiler for fx nodes. (#1480)
* [fx] modify the calculation of node_size in MetaInfoProp for activation checkpointing usages * [fx] modify the calculation of node_size in MetaInfoProp for activation checkpointing usages * [fx] modify the calculation of node_size in MetaInfoProp for activation checkpointing usages * [fx] merge development into main (#1) * [fx] activation checkpointing using Chen strategies. * [fx] add test for ckpt_solver_chen * [fx] add vanilla activation checkpoint search with test on resnet and densenet * [fx] add a namespace code for solver_chen. * [fx] fix the false interpretation of algorithm 3 in https://arxiv.org/abs/1604.06174. * [fx] fix lowercase naming conventions. * [fx] simplify test for ckpt. * [fx] add rules to linearize computation graphs for searching. (#2) * [fx] modify the calculation of node_size in MetaInfoProp for activation checkpointing usages * [fx] modify the calculation of node_size in MetaInfoProp for activation checkpointing usages * [fx] modify the calculation of node_size in MetaInfoProp for activation checkpointing usages * [fx] merge development into main (#1) * [fx] activation checkpointing using Chen strategies. * [fx] add test for ckpt_solver_chen * [fx] add vanilla activation checkpoint search with test on resnet and densenet * [fx] add a namespace code for solver_chen. * [fx] fix the false interpretation of algorithm 3 in https://arxiv.org/abs/1604.06174. * [fx] fix lowercase naming conventions. * [fx] simplify test for ckpt. * [fx] fix test and algorithm bugs in activation checkpointing. * [fx] polish ckpt_test. * [fx] add rules to linearize computation graphs for searching. * [fx] remove chen_sqrt for sake of simplicity * [fx] remove chen_sqrt for sake of simplicity * [fx] remove chen_sqrt for sake of simplicity * [fx] remove chen_sqrt for sake of simplicity * [fx] fix inconsistencies. * [fx] fix MetaInfoProp. * [fx] fix MetaInfoProp. * [fx] consider MetaInfoProp for inplace operands. * [fx] consider MetaInfoProp for inplace operands. 
* [fx] consider MetaInfoProp for inplace operands. * [fx] consider MetaInfoProp for inplace operands. * [fx] consider MetaInfoProp for inplace operands. * [fx] add profiler for fx nodes. * [fx] add profiler for fx nodes. * [fx] add profiler for fx nodes. * [fx] add profiler for fx nodes. * [fx] add profiler for fx nodes. * [fx] add profiler for fx nodes. * [fx] add profiler for fx nodes. * [fx] fix error in tests. * [fx] unfix bug. * [fx] unfix bug.
pull/1493/head
parent
d39e11dffb
commit
32efe8e740
|
@ -73,10 +73,10 @@ def chen_greedy(gm: GraphModule) -> GraphModule:
|
|||
y = 0
|
||||
prev_idx = 2
|
||||
for (idx, n) in enumerate(gm.graph.nodes):
|
||||
temp += getattr(n, 'activation_size')
|
||||
temp += getattr(n, '__activation__')
|
||||
y = max(y, temp)
|
||||
if temp > b and n in ckpt_nodes:
|
||||
x += getattr(n, 'activation_size')
|
||||
x += getattr(n, '__activation__')
|
||||
temp = 0
|
||||
ckpt_intv.append((prev_idx, idx + 1))
|
||||
prev_idx = idx + 1
|
||||
|
|
|
@ -1,10 +1,12 @@
|
|||
from operator import add, getitem
|
||||
import torch
|
||||
import torch.fx
|
||||
from torch.fx.node import Node, map_aggregate
|
||||
from torch.fx.node import Node, map_aggregate, Argument, Target
|
||||
from typing import Any, Tuple, NamedTuple, Optional, Dict
|
||||
from functools import reduce
|
||||
from torch.fx._compatibility import compatibility
|
||||
from torch.fx.immutable_collections import immutable_dict, immutable_list
|
||||
from colossalai.fx.profiler import MetaProfile, profile_function, profile_module, calculate_activation_size, profile_method
|
||||
|
||||
|
||||
@compatibility(is_backward_compatible=True)
|
||||
|
@ -36,47 +38,11 @@ def _extract_tensor_metadata(result: torch.Tensor) -> TensorMetadata:
|
|||
return TensorMetadata(shape, dtype, requires_grad, stride, numel, is_tensor)
|
||||
|
||||
|
||||
def _compute_activation_size(node_metadata: any) -> int:
    """
    Recursively total ``numel * element_size`` over every ``TensorMetadata``
    reachable from ``node_metadata``.

    Args:
        node_metadata: a ``TensorMetadata``, a dict whose values are such
            metadata (or nested containers of them), or any other iterable
            of them.

    Returns:
        int: the accumulated size; 0 for an empty container.
    """
    # Leaf: a single tensor description.
    if isinstance(node_metadata, TensorMetadata):
        # An empty tensor of the recorded dtype is only built to query its element size.
        per_element = torch.tensor([], dtype=node_metadata.dtype).element_size()
        return node_metadata.numel * per_element

    # Dict: only the values carry metadata, keys are ignored.
    if isinstance(node_metadata, dict):
        return _compute_activation_size(list(node_metadata.values()))

    # Any other iterable: sum over its members.
    return sum(_compute_activation_size(member) for member in node_metadata)
|
||||
|
||||
|
||||
def _map_aggregate(arg, fn):
|
||||
"""
|
||||
Apply fn to each Node appearing arg. arg may be a list, tuple, slice, or dict with string keys.
|
||||
"""
|
||||
if isinstance(arg, torch.Size):
|
||||
return fn(arg)
|
||||
if isinstance(arg, tuple):
|
||||
return tuple(map_aggregate(elem, fn) for elem in arg)
|
||||
elif isinstance(arg, list):
|
||||
return immutable_list(map_aggregate(elem, fn) for elem in arg)
|
||||
elif isinstance(arg, dict):
|
||||
return immutable_dict((k, map_aggregate(v, fn)) for k, v in arg.items())
|
||||
elif isinstance(arg, slice):
|
||||
return slice(map_aggregate(arg.start, fn), map_aggregate(arg.stop, fn), map_aggregate(arg.step, fn))
|
||||
else:
|
||||
return fn(arg)
|
||||
|
||||
|
||||
@compatibility(is_backward_compatible=True)
|
||||
class MetaInfoProp(torch.fx.Interpreter):
|
||||
"""
|
||||
Execute an FX graph Node-by-Node and
|
||||
record the shape and type of the result
|
||||
Execute an FX graph Node-by-Node with meta tensor and
|
||||
record the shape, FLOPs, MACs and type of the result
|
||||
into the corresponding node.
|
||||
|
||||
Usage:
|
||||
|
@ -104,9 +70,32 @@ class MetaInfoProp(torch.fx.Interpreter):
|
|||
|
||||
"""
|
||||
|
||||
@compatibility(is_backward_compatible=True)
def run(self, *args, initial_env: Optional[Dict[Node, Any]] = None, enable_io_processing: bool = True) -> Any:
    """
    Run the graph after checking that every tensor in ``args`` appears with `device='meta'`.

    Args:
        *args: positional inputs of the graph
        initial_env (Optional[Dict[Node, Any]]): forwarded unchanged to the parent ``run``
        enable_io_processing (bool): forwarded unchanged to the parent ``run``

    Returns:
        Any: whatever the parent ``run`` returns

    Raises:
        AssertionError: if a ``torch.Tensor`` in ``args`` is not a meta tensor
    """
    for elem in args:
        if isinstance(elem, torch.Tensor):
            assert elem.is_meta, "Input torch.Tensor are assumed to appear with device='meta'"
    # Both options follow ``*args`` in the signature, so they must be forwarded by keyword;
    # passing them positionally would append them to the graph inputs instead.
    return super().run(*args, initial_env=initial_env, enable_io_processing=enable_io_processing)
|
||||
|
||||
@compatibility(is_backward_compatible=True)
|
||||
def run_node(self, n: Node) -> Any:
|
||||
# TODO: We might run_node(n) with meta data, and count FLOPS for each node
|
||||
result = super().run_node(n)
|
||||
"""
|
||||
Run a specific node ``n`` and return the result.
|
||||
Calls into placeholder, get_attr, call_function,
|
||||
call_method, call_module, or output depending
|
||||
on ``node.op``
|
||||
|
||||
Args:
|
||||
n (Node): The Node to execute
|
||||
|
||||
Returns:
|
||||
Any: The result of executing ``n``
|
||||
"""
|
||||
result, profile = super().run_node(n)
|
||||
profile: MetaProfile
|
||||
|
||||
def extract_tensor_meta(obj):
|
||||
if isinstance(obj, torch.Tensor):
|
||||
|
@ -114,29 +103,139 @@ class MetaInfoProp(torch.fx.Interpreter):
|
|||
else:
|
||||
return TensorMetadata(None, None, False, None, 0, False)
|
||||
|
||||
meta = _map_aggregate(result, extract_tensor_meta)
|
||||
meta = map_aggregate(result, extract_tensor_meta)
|
||||
n.meta['tensor_meta'] = meta
|
||||
|
||||
total_activation_size = 0
|
||||
total_param_size = 0
|
||||
if n.op == 'call_module':
|
||||
target_module = n.graph.owning_module.get_submodule(n.target)
|
||||
if not getattr(target_module, 'inplace', False):
|
||||
total_activation_size = _compute_activation_size(n.meta['tensor_meta'])
|
||||
for param in target_module.parameters():
|
||||
total_param_size += param.numel() * torch.tensor([], dtype=param.dtype).element_size()
|
||||
elif n.op == 'call_function':
|
||||
if 'inplace' not in n.kwargs:
|
||||
total_activation_size = _compute_activation_size(n.meta['tensor_meta'])
|
||||
else:
|
||||
total_activation_size = _compute_activation_size(n.meta['tensor_meta'])
|
||||
|
||||
setattr(n, 'node_size', total_activation_size + total_param_size)
|
||||
setattr(n, 'param_size', total_param_size)
|
||||
setattr(n, 'activation_size', total_activation_size)
|
||||
# TODO: the attribute node_size should be removed in the future
|
||||
setattr(n, 'node_size', profile.param + profile.activation)
|
||||
setattr(n, '__param__', profile.param)
|
||||
setattr(n, '__activation__', profile.activation)
|
||||
setattr(n, '__flops__', profile.flops)
|
||||
setattr(n, '__macs__', profile.macs)
|
||||
n.meta['type'] = type(result)
|
||||
return result
|
||||
|
||||
# Main Node running APIs
|
||||
@compatibility(is_backward_compatible=True)
def placeholder(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
    """
    Execute a ``placeholder`` node through the parent ``Interpreter`` and attach a
    meta profile to the value it yields.

    Args:
        target (Target): The call target for this node. See
            `Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
            details on semantics
        args (Tuple): Tuple of positional args for this invocation
        kwargs (Dict): Dict of keyword arguments for this invocation

    Returns:
        result (Any): The argument value that was retrieved
        profile (MetaProfile): The meta profile of this node
    """
    value = super().placeholder(target, args, kwargs)
    # Only the activation slot is non-zero for a placeholder; the other three slots are 0.
    node_profile = MetaProfile(0, calculate_activation_size(value), 0, 0)
    return value, node_profile
|
||||
|
||||
@compatibility(is_backward_compatible=True)
def get_attr(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
    """
    Execute a ``get_attr`` node through the parent ``Interpreter`` and pair the
    fetched attribute with an all-zero meta profile.

    Args:
        target (Target): The call target for this node. See
            `Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
            details on semantics
        args (Tuple): Tuple of positional args for this invocation
        kwargs (Dict): Dict of keyword arguments for this invocation

    Return:
        result (Any): The attribute value that was retrieved
        profile (MetaProfile): The meta profile of this node
    """
    attr_value = super().get_attr(target, args, kwargs)
    # Nothing is counted for a get_attr node: every slot of the profile is 0.
    empty_profile = MetaProfile(0, 0, 0, 0)
    return attr_value, empty_profile
|
||||
|
||||
@compatibility(is_backward_compatible=True)
def call_function(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
    """
    Execute a ``call_function`` node by wrapping ``target`` with ``profile_function``
    and invoking the wrapper with the node's arguments.

    Args:
        target (Target): The call target for this node. See
            `Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
            details on semantics
        args (Tuple): Tuple of positional args for this invocation
        kwargs (Dict): Dict of keyword arguments for this invocation

    Return
        Whatever the profiled callable returns; ``run_node`` unpacks it as
        ``(result, profile)``.
    """
    # A call_function target is a callable object, never an attribute name.
    assert not isinstance(target, str)
    profiled_fn = profile_function(target)
    return profiled_fn(*args, **kwargs)
|
||||
|
||||
@compatibility(is_backward_compatible=True)
def call_method(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
    """
    Execute a ``call_method`` node by wrapping ``target`` with ``profile_method``
    and invoking the wrapper with the node's arguments.

    Args:
        target (Target): The call target for this node. See
            `Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
            details on semantics
        args (Tuple): Tuple of positional args for this invocation
        kwargs (Dict): Dict of keyword arguments for this invocation

    Return
        Whatever the profiled callable returns; ``run_node`` unpacks it as
        ``(result, profile)``.
    """
    profiled_method = profile_method(target)
    return profiled_method(*args, **kwargs)
|
||||
|
||||
@compatibility(is_backward_compatible=True)
def call_module(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
    """
    Execute a ``call_module`` node: look the submodule up by name, wrap it with
    ``profile_module`` and invoke the wrapper with the node's arguments.

    Args:
        target (Target): The call target for this node. See
            `Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
            details on semantics
        args (Tuple): Tuple of positional args for this invocation
        kwargs (Dict): Dict of keyword arguments for this invocation

    Return
        Whatever the profiled callable returns; ``run_node`` unpacks it as
        ``(result, profile)``.
    """
    # A call_module target is the qualified name of the submodule.
    assert isinstance(target, str)
    profiled_module = profile_module(self.fetch_attr(target))
    return profiled_module(*args, **kwargs)
|
||||
|
||||
@compatibility(is_backward_compatible=True)
def output(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
    """
    Execute an ``output`` node: hand back the value referenced by the node
    together with an all-zero meta profile.

    Args:
        target (Target): The call target for this node. See
            `Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
            details on semantics
        args (Tuple): Tuple of positional args for this invocation
        kwargs (Dict): Dict of keyword arguments for this invocation

    Return:
        result (Any): The return value referenced by the output node (``args[0]``)
        profile (MetaProfile): The meta profile of this node, all zeros
    """
    returned_value = args[0]
    empty_profile = MetaProfile(0, 0, 0, 0)
    return returned_value, empty_profile
|
||||
|
||||
def propagate(self, *args):
|
||||
"""
|
||||
Run `module` via interpretation and return the result and
|
||||
|
|
|
@ -0,0 +1,4 @@
|
|||
from .registry import *
|
||||
from .profiler_function import *
|
||||
from .profiler_module import *
|
||||
from .utils import *
|
|
@ -0,0 +1,8 @@
|
|||
from .activation_function import *
|
||||
from .arithmetic import *
|
||||
from .embedding import *
|
||||
from .linear import *
|
||||
from .normalization import *
|
||||
from .pooling import *
|
||||
from .python_ops import *
|
||||
from .torch_ops import *
|
|
@ -0,0 +1,29 @@
|
|||
from typing import Tuple
|
||||
import torch
|
||||
from ..registry import meta_profiler_function
|
||||
|
||||
# TODO: different activation has different FLOPs count, currently unused.
|
||||
_multiplier = {
|
||||
torch.nn.functional.relu: 1,
|
||||
torch.nn.functional.prelu: 4,
|
||||
torch.nn.functional.sigmoid: 4,
|
||||
torch.nn.functional.tanh: 5,
|
||||
torch.nn.functional.leaky_relu: 3,
|
||||
torch.nn.functional.elu: 4,
|
||||
torch.nn.functional.relu6: 2,
|
||||
torch.nn.functional.gelu: 9,
|
||||
}
|
||||
|
||||
|
||||
@meta_profiler_function.register(torch.nn.functional.leaky_relu)
@meta_profiler_function.register(torch.nn.functional.elu)
@meta_profiler_function.register(torch.nn.functional.gelu)
@meta_profiler_function.register(torch.nn.functional.relu6)
@meta_profiler_function.register(torch.nn.functional.prelu)
@meta_profiler_function.register(torch.nn.functional.relu)
@meta_profiler_function.register(torch.nn.functional.sigmoid)
@meta_profiler_function.register(torch.nn.functional.tanh)
def torch_nn_func_non_linear_act(input: torch.Tensor, inplace: bool = False) -> Tuple[int, int]:
    """
    Cost estimate shared by the functional non-linear activations registered above.

    Args:
        input (torch.Tensor): the activation input
        inplace (bool): accepted for signature compatibility; not used in the estimate

    Returns:
        Tuple[int, int]: ``(flops, macs)`` with one FLOP per input element and no MACs
    """
    # The per-activation factors in `_multiplier` are not applied here.
    return input.numel(), 0
|
|
@ -0,0 +1,83 @@
|
|||
from typing import Any, Optional, Tuple, Union
|
||||
import torch
|
||||
from ..registry import meta_profiler_function
|
||||
|
||||
|
||||
def _prod(dims):
|
||||
p = 1
|
||||
for v in dims:
|
||||
p *= v
|
||||
return p
|
||||
|
||||
|
||||
def _elementwise_flops_compute(input, other):
|
||||
# copied from https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/profiling/flops_profiler/profiler.py#L763
|
||||
if not torch.is_tensor(input):
|
||||
if torch.is_tensor(other):
|
||||
return _prod(other.shape), 0
|
||||
else:
|
||||
return 1, 0
|
||||
elif not torch.is_tensor(other):
|
||||
return _prod(input.shape), 0
|
||||
else:
|
||||
dim_input = len(input.shape)
|
||||
dim_other = len(other.shape)
|
||||
max_dim = max(dim_input, dim_other)
|
||||
|
||||
final_shape = []
|
||||
for i in range(max_dim):
|
||||
in_i = input.shape[i] if i < dim_input else 1
|
||||
ot_i = other.shape[i] if i < dim_other else 1
|
||||
if in_i > ot_i:
|
||||
final_shape.append(in_i)
|
||||
else:
|
||||
final_shape.append(ot_i)
|
||||
flops = _prod(final_shape)
|
||||
return flops, 0
|
||||
|
||||
|
||||
@meta_profiler_function.register(torch.add)
@meta_profiler_function.register('add')    # for built-in op +
@meta_profiler_function.register('iadd')    # for built-in op +=
@meta_profiler_function.register('sub')    # for built-in op -
@meta_profiler_function.register('isub')    # for built-in op -=
@meta_profiler_function.register('mul')    # for built-in op *
@meta_profiler_function.register('imul')    # for built-in op *=
def torch_add_like_ops(input: Any, other: Any, *, out: Optional[torch.Tensor] = None) -> Tuple[int, int]:
    """
    Cost estimate for the binary element-wise ops registered above.

    Args:
        input: left operand
        other: right operand
        out: accepted for signature compatibility; not used in the estimate

    Returns:
        Tuple[int, int]: the ``(flops, macs)`` pair from ``_elementwise_flops_compute``
    """
    estimate = _elementwise_flops_compute(input, other)
    return estimate
|
||||
|
||||
|
||||
@meta_profiler_function.register(torch.abs)
def torch_elementwise_op(input: torch.Tensor, *, out: Optional[torch.Tensor] = None) -> Tuple[int, int]:
    """
    Cost estimate for a unary element-wise op.

    Args:
        input (torch.Tensor): the operand
        out: accepted for signature compatibility; not used in the estimate

    Returns:
        Tuple[int, int]: ``(flops, macs)`` with one FLOP per input element and no MACs
    """
    return input.numel(), 0
|
||||
|
||||
|
||||
@meta_profiler_function.register(torch.matmul)
@meta_profiler_function.register('matmul')    # for built-in op @
@meta_profiler_function.register(torch.Tensor.matmul)
def torch_matmul(input: torch.Tensor, other: torch.Tensor, *, out: Optional[torch.Tensor] = None) -> Tuple[int, int]:
    """
    Cost estimate for ``torch.matmul`` covering vector operands and batch broadcasting.

    The op is modelled as ``batch x (m, k) @ (k, p)``: a 1-D ``input`` has ``m == 1``,
    a 1-D ``other`` has ``p == 1`` and the batch dimensions of both operands are
    broadcast against each other.

    Args:
        input (torch.Tensor): left operand, at least 1-D
        other (torch.Tensor): right operand, at least 1-D
        out: accepted for signature compatibility; not used in the estimate

    Returns:
        Tuple[int, int]: ``(flops, macs)`` with ``flops == 2 * macs``
    """
    lhs = tuple(input.shape)
    rhs = tuple(other.shape)
    k = lhs[-1]
    m = lhs[-2] if len(lhs) > 1 else 1
    # For a vector right-hand side the last dim is the contraction dim, not an output
    # dim, so it must not be multiplied in a second time.
    p = rhs[-1] if len(rhs) > 1 else 1
    # Batch dims may live on either operand; an empty prefix broadcasts to ().
    batch = torch.broadcast_shapes(lhs[:-2], rhs[:-2])
    macs = _prod(batch) * m * k * p
    flops = 2 * macs
    return flops, macs
|
||||
|
||||
|
||||
@meta_profiler_function.register(torch.bmm)
def torch_bmm(input: torch.Tensor, other: torch.Tensor, *, out: Optional[torch.Tensor] = None) -> Tuple[int, int]:
    """
    Cost estimate for ``torch.bmm``.

    Args:
        input (torch.Tensor): left operand
        other (torch.Tensor): right operand
        out: accepted for signature compatibility; not used in the estimate

    Returns:
        Tuple[int, int]: ``(flops, macs)`` where ``macs`` is the number of elements of
        ``input`` times the last dimension of ``other`` and ``flops == 2 * macs``
    """
    out_cols = other.shape[-1]
    macs = _prod(input.shape) * out_cols
    return 2 * macs, macs
|
||||
|
||||
|
||||
@meta_profiler_function.register(torch.var_mean)
def torch_var_mean(input: torch.Tensor,
                   dim: Union[int, Tuple[int, ...]],
                   unbiased: Optional[bool] = True,
                   keepdim: Optional[bool] = False,
                   *,
                   out: Optional[torch.Tensor] = None) -> Tuple[int, int]:
    """
    Cost estimate for ``torch.var_mean``.

    Args:
        input (torch.Tensor): the operand
        dim, unbiased, keepdim: accepted for signature compatibility; not used in the estimate
        out: must be ``None``

    Returns:
        Tuple[int, int]: ``(flops, macs)`` with three FLOPs per input element and no MACs

    Raises:
        AssertionError: if ``out`` is given
    """
    assert out is None, 'saving to out is not supported yet'
    per_element_cost = 3
    return input.numel() * per_element_cost, 0
|
|
@ -0,0 +1,19 @@
|
|||
from typing import Optional, Tuple

import torch

from ..registry import meta_profiler_function
|
||||
|
||||
|
||||
@meta_profiler_function.register(torch.nn.functional.embedding)
def torch_nn_functional_embedding(
    input: torch.Tensor,
    weight: torch.Tensor,
    padding_idx: Optional[int] = None,
    max_norm: Optional[float] = None,
    norm_type: float = 2.0,
    scale_grad_by_freq: bool = False,
    sparse: bool = False,
) -> Tuple[int, int]:
    """
    Cost estimate for ``torch.nn.functional.embedding``.

    All parameters mirror the profiled function's signature and are not used in the
    estimate.

    Returns:
        Tuple[int, int]: ``(flops, macs)``, both 0
    """
    # F.embedding is a dictionary lookup, so technically it has 0 FLOPs.
    # (https://discuss.pytorch.org/t/correct-way-to-calculate-flops-in-model/67198/6)
    flops = 0
    macs = 0
    return flops, macs
|
|
@ -0,0 +1,13 @@
|
|||
from typing import Tuple
|
||||
import torch
|
||||
from ..registry import meta_profiler_function
|
||||
|
||||
|
||||
@meta_profiler_function.register(torch.nn.functional.linear)
def torch_nn_linear(input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor = None) -> Tuple[int, int]:
    """
    Cost estimate for ``torch.nn.functional.linear``.

    Args:
        input (torch.Tensor): the layer input
        weight (torch.Tensor): weight whose first dimension is the number of output features
        bias (torch.Tensor): optional bias

    Returns:
        Tuple[int, int]: ``(flops, macs)``; ``macs`` is ``input.numel() * out_features``,
        ``flops`` is twice that plus ``bias.numel()`` when a bias is present
    """
    macs = torch.numel(input) * weight.shape[0]
    bias_flops = bias.numel() if bias is not None else 0
    return 2 * macs + bias_flops, macs
|
|
@ -0,0 +1,66 @@
|
|||
from typing import List, Optional, Tuple
|
||||
import torch
|
||||
from ..registry import meta_profiler_function
|
||||
|
||||
|
||||
@meta_profiler_function.register(torch.nn.functional.instance_norm)
def torch_nn_func_instancenorm(
    input: torch.Tensor,
    running_mean: Optional[torch.Tensor] = None,
    running_var: Optional[torch.Tensor] = None,
    weight: Optional[torch.Tensor] = None,
    bias: Optional[torch.Tensor] = None,
    use_input_stats: bool = True,
    momentum: float = 0.1,
    eps: float = 1e-5,
):
    """
    Cost estimate for ``torch.nn.functional.instance_norm``.

    Only ``input`` and ``weight`` influence the estimate; the remaining parameters
    mirror the profiled function's signature.

    Returns:
        ``(flops, macs)``: 5 FLOPs per input element when ``weight`` is given, 4 otherwise;
        no MACs
    """
    per_element_cost = 4 if weight is None else 5
    return input.numel() * per_element_cost, 0
|
||||
|
||||
|
||||
@meta_profiler_function.register(torch.nn.functional.group_norm)
def torch_nn_func_groupnorm(input: torch.Tensor,
                            num_groups: int,
                            weight: Optional[torch.Tensor] = None,
                            bias: Optional[torch.Tensor] = None,
                            eps: float = 1e-5) -> Tuple[int, int]:
    """
    Cost estimate for ``torch.nn.functional.group_norm``.

    Only ``input`` and ``weight`` influence the estimate; the remaining parameters
    mirror the profiled function's signature.

    Returns:
        Tuple[int, int]: ``(flops, macs)``: 5 FLOPs per input element when ``weight`` is
        given, 4 otherwise; no MACs
    """
    per_element_cost = 4 if weight is None else 5
    return input.numel() * per_element_cost, 0
|
||||
|
||||
|
||||
@meta_profiler_function.register(torch.nn.functional.layer_norm)
def torch_nn_func_layernorm(
    input: torch.Tensor,
    normalized_shape: List[int],
    weight: Optional[torch.Tensor] = None,
    bias: Optional[torch.Tensor] = None,
    eps: float = 1e-5,
) -> Tuple[int, int]:
    """
    Cost estimate for ``torch.nn.functional.layer_norm``.

    Only ``input`` and ``weight`` influence the estimate; the remaining parameters
    mirror the profiled function's signature.

    Returns:
        Tuple[int, int]: ``(flops, macs)``: 5 FLOPs per input element when ``weight`` is
        given, 4 otherwise; no MACs
    """
    per_element_cost = 4 if weight is None else 5
    return input.numel() * per_element_cost, 0
|
||||
|
||||
|
||||
@meta_profiler_function.register(torch.nn.functional.batch_norm)
def torch_nn_func_batchnorm(
    input: torch.Tensor,
    running_mean: Optional[torch.Tensor],
    running_var: Optional[torch.Tensor],
    weight: Optional[torch.Tensor] = None,
    bias: Optional[torch.Tensor] = None,
    training: bool = False,
    momentum: float = 0.1,
    eps: float = 1e-5,
) -> Tuple[int, int]:
    """
    Cost estimate for ``torch.nn.functional.batch_norm``.

    Only ``input``, ``weight`` and ``training`` influence the estimate; the remaining
    parameters mirror the profiled function's signature.

    Returns:
        Tuple[int, int]: ``(flops, macs)`` with no MACs. In training mode the estimate is
        5 FLOPs per element with an affine weight and 4 without; in evaluation mode it
        is 2 with an affine weight and 1 without.
    """
    has_affine = weight is not None
    if training:
        # Training mode also has to compute the batch statistics, so it takes the larger
        # factor (the same factor the other normalisation estimators in this package use).
        flops = input.numel() * (5 if has_affine else 4)
    else:
        # Evaluation mode normalises with the stored running statistics only.
        flops = input.numel() * (2 if has_affine else 1)
    macs = 0
    return flops, macs
|
|
@ -0,0 +1,22 @@
|
|||
from typing import Tuple, Union
|
||||
import torch
|
||||
from ..registry import meta_profiler_function
|
||||
|
||||
|
||||
@meta_profiler_function.register(torch.nn.functional.avg_pool1d)
@meta_profiler_function.register(torch.nn.functional.avg_pool2d)
@meta_profiler_function.register(torch.nn.functional.avg_pool3d)
@meta_profiler_function.register(torch.nn.functional.max_pool1d)
@meta_profiler_function.register(torch.nn.functional.max_pool2d)
@meta_profiler_function.register(torch.nn.functional.max_pool3d)
@meta_profiler_function.register(torch.nn.functional.adaptive_avg_pool1d)
@meta_profiler_function.register(torch.nn.functional.adaptive_avg_pool2d)
@meta_profiler_function.register(torch.nn.functional.adaptive_avg_pool3d)
@meta_profiler_function.register(torch.nn.functional.adaptive_max_pool1d)
@meta_profiler_function.register(torch.nn.functional.adaptive_max_pool2d)
@meta_profiler_function.register(torch.nn.functional.adaptive_max_pool3d)
def torch_nn_func_pooling(input: torch.Tensor, *args, **kwargs) -> Tuple[int, int]:
    """
    Cost estimate shared by the functional pooling ops registered above.

    Args:
        input (torch.Tensor): the pooling input
        *args, **kwargs: accepted for signature compatibility; not used in the estimate

    Returns:
        Tuple[int, int]: ``(flops, macs)`` with one FLOP per input element and no MACs
    """
    # Every pooling variant is treated as visiting each input element exactly once
    # (https://stackoverflow.com/a/67301217).
    return input.numel(), 0
|
|
@ -0,0 +1,12 @@
|
|||
import operator
|
||||
from typing import Any, Tuple
|
||||
import torch
|
||||
from ..registry import meta_profiler_function
|
||||
from colossalai.fx.proxy import ColoProxy
|
||||
|
||||
|
||||
@meta_profiler_function.register(operator.getitem)
def operator_getitem(a: Any, b: Any) -> Tuple[int, int]:
    """
    Cost estimate for ``operator.getitem``.

    Args:
        a: the indexed object (unused)
        b: the index (unused)

    Returns:
        Tuple[int, int]: ``(flops, macs)``, both 0
    """
    return 0, 0
|
|
@ -0,0 +1,64 @@
|
|||
from typing import Any, Optional, Tuple
|
||||
import torch
|
||||
from ..registry import meta_profiler_function
|
||||
|
||||
|
||||
def _prod(dims):
|
||||
p = 1
|
||||
for v in dims:
|
||||
p *= v
|
||||
return p
|
||||
|
||||
|
||||
@meta_profiler_function.register(torch.arange)
@meta_profiler_function.register(torch.finfo)
@meta_profiler_function.register(torch.permute)
@meta_profiler_function.register(torch.Tensor.permute)
@meta_profiler_function.register(torch.Tensor.repeat)
@meta_profiler_function.register(torch.index_select)
@meta_profiler_function.register(torch.Tensor.index_select)
@meta_profiler_function.register(torch.squeeze)
@meta_profiler_function.register(torch.Tensor.squeeze)
@meta_profiler_function.register(torch.unsqueeze)
@meta_profiler_function.register(torch.Tensor.unsqueeze)
@meta_profiler_function.register(torch.cat)
@meta_profiler_function.register(torch.concat)
@meta_profiler_function.register(torch.repeat_interleave)
@meta_profiler_function.register(torch.Tensor.repeat_interleave)
@meta_profiler_function.register(torch.flatten)
@meta_profiler_function.register(torch.Tensor.flatten)
@meta_profiler_function.register(torch.roll)
@meta_profiler_function.register(torch.full)
@meta_profiler_function.register(torch.Tensor.cpu)
@meta_profiler_function.register(torch.Tensor.cuda)
def torch_zero_flops_op(*args, **kwargs) -> Tuple[int, int]:
    """
    Cost estimate shared by the ops registered above, all of which are counted as free.

    Args:
        *args, **kwargs: whatever the profiled op received; ignored

    Returns:
        Tuple[int, int]: ``(flops, macs)``, both 0
    """
    return 0, 0
|
||||
|
||||
|
||||
@meta_profiler_function.register(torch.where)
def torch_where(condition: torch.Tensor, x: Any, y: Any) -> Tuple[int, int]:
    """
    Cost estimate for ``torch.where``.

    Args:
        condition (torch.Tensor): the selection mask
        x: value taken where ``condition`` holds (unused in the estimate)
        y: value taken elsewhere (unused in the estimate)

    Returns:
        Tuple[int, int]: ``(flops, macs)`` with one FLOP per element of ``condition``
        and no MACs
    """
    # torch.where returns the broadcast of condition, x and y; this estimate is
    # based on the size of `condition` alone.
    return condition.numel(), 0
|
||||
|
||||
|
||||
@meta_profiler_function.register(torch.max)
def torch_max(input: torch.Tensor,
              dim: int = None,
              keepdim: bool = False,
              *,
              out: Optional[torch.Tensor] = None) -> Tuple[int, int]:
    """
    Cost estimate for ``torch.max``.

    Args:
        input (torch.Tensor): the operand
        dim (int): dimension to reduce over; ``None`` reduces over all elements
        keepdim (bool): accepted for signature compatibility; not used in the estimate
        out: must be ``None``

    Returns:
        Tuple[int, int]: ``(flops, macs)`` with no MACs. Without ``dim`` the FLOPs are
        ``input.numel()``; with ``dim`` they are the product of the remaining dimensions.

    Raises:
        AssertionError: if ``out`` is given
    """
    assert out is None, 'assigning value to out is not supported yet'
    macs = 0
    if dim is None:
        return input.numel(), macs
    remaining = list(input.shape)
    remaining.pop(int(dim))
    # `flops` must be a plain int: the previous `flops = _prod(shape), macs` built a
    # tuple, so the function returned ((n, 0), 0) instead of (n, 0).
    flops = _prod(remaining)
    return flops, macs
|
|
@ -0,0 +1,7 @@
|
|||
from .activation_function import *
|
||||
from .convolution import *
|
||||
from .embedding import *
|
||||
from .linear import *
|
||||
from .normalization import *
|
||||
from .pooling import *
|
||||
from .rnn import *
|
|
@ -0,0 +1,29 @@
|
|||
from typing import Tuple
|
||||
import torch
|
||||
from ..registry import meta_profiler_module
|
||||
|
||||
# TODO: different activation has different FLOPs count, currently unused.
|
||||
_multiplier = {
|
||||
torch.nn.ReLU: 1,
|
||||
torch.nn.PReLU: 4,
|
||||
torch.nn.Sigmoid: 4,
|
||||
torch.nn.Tanh: 5,
|
||||
torch.nn.LeakyReLU: 3,
|
||||
torch.nn.ELU: 4,
|
||||
torch.nn.ReLU6: 2,
|
||||
torch.nn.GELU: 9,
|
||||
}
|
||||
|
||||
|
||||
@meta_profiler_module.register(torch.nn.ELU)
@meta_profiler_module.register(torch.nn.LeakyReLU)
@meta_profiler_module.register(torch.nn.ReLU)
@meta_profiler_module.register(torch.nn.GELU)
@meta_profiler_module.register(torch.nn.Sigmoid)
@meta_profiler_module.register(torch.nn.Tanh)
@meta_profiler_module.register(torch.nn.ReLU6)
@meta_profiler_module.register(torch.nn.PReLU)
def torch_nn_non_linear_act(self: torch.nn.Module, input: torch.Tensor) -> Tuple[int, int]:
    """
    Cost estimate shared by the non-linear activation modules registered above.

    Args:
        self (torch.nn.Module): the activation module (unused in the estimate)
        input (torch.Tensor): the activation input

    Returns:
        Tuple[int, int]: ``(flops, macs)`` with one FLOP per input element and no MACs
    """
    # The per-activation factors in `_multiplier` are not applied here.
    return input.numel(), 0
|
|
@ -0,0 +1,157 @@
|
|||
import math
|
||||
from typing import Tuple
|
||||
import torch
|
||||
from ..registry import meta_profiler_module
|
||||
|
||||
|
||||
def _prod(dims):
|
||||
p = 1
|
||||
for v in dims:
|
||||
p *= v
|
||||
return p
|
||||
|
||||
|
||||
@meta_profiler_module.register(torch.nn.Conv1d)
def torch_nn_conv1d(self: torch.nn.Conv1d, input: torch.Tensor) -> Tuple[int, int]:
    """
    Cost estimate for ``torch.nn.Conv1d``.

    The output length follows the formula stated at
    https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html

    Args:
        self (torch.nn.Conv1d): the convolution module
        input (torch.Tensor): input whose last two dimensions are read as (channels, length)

    Returns:
        Tuple[int, int]: ``(flops, macs)``; ``flops`` is ``2 * macs`` plus one FLOP per
        output element when the module has a bias
    """
    c_in, l_in = input.shape[-2:]
    kernel_extent = self.dilation[0] * (self.kernel_size[0] - 1)
    l_out = math.floor((l_in + 2 * self.padding[0] - kernel_extent - 1) / self.stride[0] + 1)

    # Number of output elements: leading dims x out_channels x output length.
    num_elem = _prod(input.shape[:-2]) * self.out_channels * l_out
    macs_per_elem = _prod(self.kernel_size) * c_in // self.groups
    macs = macs_per_elem * num_elem

    bias_flops = num_elem if self.bias is not None else 0
    return 2 * macs + bias_flops, macs
|
||||
|
||||
|
||||
@meta_profiler_module.register(torch.nn.Conv2d)
def torch_nn_conv2d(self: torch.nn.Conv2d, input: torch.Tensor) -> Tuple[int, int]:
    """
    Cost estimate for ``torch.nn.Conv2d``.

    The output size follows the formula stated at
    https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html

    Args:
        self (torch.nn.Conv2d): the convolution module
        input (torch.Tensor): input whose last three dimensions are read as
            (channels, height, width)

    Returns:
        Tuple[int, int]: ``(flops, macs)``; ``flops`` is ``2 * macs`` plus one FLOP per
        output element when the module has a bias
    """
    c_in, h_in, w_in = input.shape[-3:]

    # Apply the same size formula to each spatial axis.
    spatial_out = [
        math.floor((size + 2 * self.padding[axis] - self.dilation[axis] *
                    (self.kernel_size[axis] - 1) - 1) / self.stride[axis] + 1)
        for axis, size in enumerate((h_in, w_in))
    ]

    # Number of output elements: leading dims x out_channels x output spatial size.
    num_elem = _prod(input.shape[:-3]) * self.out_channels * _prod(spatial_out)
    macs_per_elem = _prod(self.kernel_size) * c_in // self.groups
    macs = macs_per_elem * num_elem

    bias_flops = num_elem if self.bias is not None else 0
    return 2 * macs + bias_flops, macs
|
||||
|
||||
|
||||
@meta_profiler_module.register(torch.nn.Conv3d)
def torch_nn_conv3d(self: torch.nn.Conv3d, input: torch.Tensor) -> Tuple[int, int]:
    """
    Cost estimate for ``torch.nn.Conv3d``.

    The output size follows the formula stated at
    https://pytorch.org/docs/stable/generated/torch.nn.Conv3d.html

    Args:
        self (torch.nn.Conv3d): the convolution module
        input (torch.Tensor): input whose last four dimensions are read as
            (channels, depth, height, width)

    Returns:
        Tuple[int, int]: ``(flops, macs)``; ``flops`` is ``2 * macs`` plus one FLOP per
        output element when the module has a bias
    """
    c_in, d_in, h_in, w_in = input.shape[-4:]

    # Apply the same size formula to each spatial axis.
    spatial_out = [
        math.floor((size + 2 * self.padding[axis] - self.dilation[axis] *
                    (self.kernel_size[axis] - 1) - 1) / self.stride[axis] + 1)
        for axis, size in enumerate((d_in, h_in, w_in))
    ]

    # Number of output elements: leading dims x out_channels x output spatial size.
    num_elem = _prod(input.shape[:-4]) * self.out_channels * _prod(spatial_out)
    macs_per_elem = _prod(self.kernel_size) * c_in // self.groups
    macs = macs_per_elem * num_elem

    bias_flops = num_elem if self.bias is not None else 0
    return 2 * macs + bias_flops, macs
|
||||
|
||||
|
||||
@meta_profiler_module.register(torch.nn.ConvTranspose1d)
def torch_nn_convtranspose1d(self: torch.nn.ConvTranspose1d, input: torch.Tensor) -> Tuple[int, int]:
    """
    Cost estimate for ``torch.nn.ConvTranspose1d``.

    The output length follows the formula stated at
    https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose1d.html

    Args:
        self (torch.nn.ConvTranspose1d): the transposed convolution module
        input (torch.Tensor): input whose last two dimensions are read as (channels, length)

    Returns:
        Tuple[int, int]: ``(flops, macs)``; MACs are counted per *input* element, and
        ``flops`` is ``2 * macs`` plus one FLOP per output element when the module has a bias
    """
    c_in, l_in = input.shape[-2:]
    l_out = math.floor((l_in - 1) * self.stride[0] - 2 * self.padding[0] + self.dilation[0] *
                       (self.kernel_size[0] - 1) + self.output_padding[0] + 1)

    # MACs scale with the input size, see
    # https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/profiling/flops_profiler/profiler.py#L604
    macs_per_elem = _prod(self.kernel_size) * c_in // self.groups
    macs = macs_per_elem * _prod(input.shape)

    flops = 2 * macs
    if self.bias is not None:
        # The bias is added once per output element.
        flops += _prod(input.shape[:-2]) * self.out_channels * l_out
    return flops, macs
|
||||
|
||||
|
||||
@meta_profiler_module.register(torch.nn.ConvTranspose2d)
def torch_nn_convtranspose2d(self: torch.nn.ConvTranspose2d, input: torch.Tensor) -> Tuple[int, int]:
    """Estimate the FLOPs and MACs of a ``torch.nn.ConvTranspose2d`` forward pass.

    Args:
        self: the transposed-convolution module being profiled.
        input: the input tensor; its last three dimensions are read as (C_in, H_in, W_in).

    Returns:
        A ``(flops, macs)`` tuple.
    """
    # the output shape is calculated using the formula stated
    # at https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html
    _, h_in, w_in = input.shape[-3:]
    c_out = self.out_channels
    h_out = math.floor((h_in - 1) * self.stride[0] - 2 * self.padding[0] + self.dilation[0] *
                       (self.kernel_size[0] - 1) + self.output_padding[0] + 1)
    w_out = math.floor((w_in - 1) * self.stride[1] - 2 * self.padding[1] + self.dilation[1] *
                       (self.kernel_size[1] - 1) + self.output_padding[1] + 1)
    result_shape = input.shape[:-3] + (
        c_out,
        h_out,
        w_out,
    )
    # A transposed convolution scatters every *input* element through
    # kernel_size x (out_channels / groups) weights, so the per-element cost
    # depends on c_out; the input channel count is already part of num_elem.
    macs_per_elem = _prod(self.kernel_size) * c_out // self.groups
    num_elem = _prod(input.shape)
    macs = macs_per_elem * num_elem
    flops = 2 * macs
    if self.bias is not None:
        # one addition per output element
        flops += _prod(result_shape)
    return flops, macs
|
||||
|
||||
|
||||
@meta_profiler_module.register(torch.nn.ConvTranspose3d)
def torch_nn_convtranspose3d(self: torch.nn.ConvTranspose3d, input: torch.Tensor) -> Tuple[int, int]:
    """Estimate the FLOPs and MACs of a ``torch.nn.ConvTranspose3d`` forward pass.

    Args:
        self: the transposed-convolution module being profiled.
        input: the input tensor; its last four dimensions are read as (C_in, D_in, H_in, W_in).

    Returns:
        A ``(flops, macs)`` tuple.
    """
    # the output shape is calculated using the formula stated
    # at https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose3d.html
    _, d_in, h_in, w_in = input.shape[-4:]
    c_out = self.out_channels
    d_out = math.floor((d_in - 1) * self.stride[0] - 2 * self.padding[0] + self.dilation[0] *
                       (self.kernel_size[0] - 1) + self.output_padding[0] + 1)
    h_out = math.floor((h_in - 1) * self.stride[1] - 2 * self.padding[1] + self.dilation[1] *
                       (self.kernel_size[1] - 1) + self.output_padding[1] + 1)
    w_out = math.floor((w_in - 1) * self.stride[2] - 2 * self.padding[2] + self.dilation[2] *
                       (self.kernel_size[2] - 1) + self.output_padding[2] + 1)
    result_shape = input.shape[:-4] + (
        c_out,
        d_out,
        h_out,
        w_out,
    )
    # A transposed convolution scatters every *input* element through
    # kernel_size x (out_channels / groups) weights, so the per-element cost
    # depends on c_out; the input channel count is already part of num_elem.
    macs_per_elem = _prod(self.kernel_size) * c_out // self.groups
    num_elem = _prod(input.shape)
    macs = macs_per_elem * num_elem
    flops = 2 * macs
    if self.bias is not None:
        # one addition per output element
        flops += _prod(result_shape)
    return flops, macs
|
|
@ -0,0 +1,11 @@
|
|||
from typing import Tuple
|
||||
import torch
|
||||
from ..registry import meta_profiler_module
|
||||
|
||||
|
||||
@meta_profiler_module.register(torch.nn.Embedding)
def torch_nn_embedding(self: torch.nn.Embedding, input: torch.Tensor) -> Tuple[int, int]:
    """Report the cost of an ``torch.nn.Embedding`` lookup as ``(flops, macs)``.

    An embedding is a table lookup, so it is counted as performing no
    arithmetic at all
    (https://discuss.pytorch.org/t/correct-way-to-calculate-flops-in-model/67198/6).
    """
    return 0, 0
|
|
@ -0,0 +1,13 @@
|
|||
from typing import Tuple
|
||||
import torch
|
||||
from ..registry import meta_profiler_module
|
||||
|
||||
|
||||
@meta_profiler_module.register(torch.nn.Linear)
def torch_nn_linear(self: torch.nn.Linear, input: torch.Tensor) -> Tuple[int, int]:
    """Estimate the FLOPs and MACs of a ``torch.nn.Linear`` forward pass.

    Args:
        self: the linear module being profiled.
        input: the input tensor; its last dimension is the feature dimension.

    Returns:
        A ``(flops, macs)`` tuple. Every input element is multiplied with
        ``out_features`` weights; each MAC counts as two FLOPs, and a bias
        adds one FLOP per output element.
    """
    out_features = self.weight.shape[0]
    macs = torch.numel(input) * out_features
    flops = 2 * macs
    if self.bias is not None:
        # The bias vector is added to every output row, not just once, so
        # count one addition per output element (rows x out_features).
        num_rows = torch.Size(input.shape[:-1]).numel()
        flops += num_rows * out_features
    return flops, macs
|
|
@ -0,0 +1,33 @@
|
|||
from typing import Tuple, Union
|
||||
import torch
|
||||
from ..registry import meta_profiler_module
|
||||
|
||||
|
||||
@meta_profiler_module.register(torch.nn.InstanceNorm1d)
@meta_profiler_module.register(torch.nn.InstanceNorm2d)
@meta_profiler_module.register(torch.nn.InstanceNorm3d)
@meta_profiler_module.register(torch.nn.LayerNorm)
@meta_profiler_module.register(torch.nn.GroupNorm)
@meta_profiler_module.register(torch.nn.BatchNorm1d)
@meta_profiler_module.register(torch.nn.BatchNorm2d)
@meta_profiler_module.register(torch.nn.BatchNorm3d)
def torch_nn_normalize(self: Union[torch.nn.LayerNorm, torch.nn.GroupNorm, torch.nn.BatchNorm1d, torch.nn.BatchNorm2d,
                                   torch.nn.BatchNorm3d], input: torch.Tensor) -> Tuple[int, int]:
    """Estimate the FLOPs and MACs of a normalization layer.

    Args:
        self: the normalization module being profiled.
        input: the tensor being normalized.

    Returns:
        A ``(flops, macs)`` tuple; ``macs`` is always 0. The FLOPs estimate is
        a per-element multiplier: 4 when the statistics have to be computed
        from the input, 1 when stored running statistics can be reused, plus
        1 in either case when the module has an affine weight.
    """
    # adopted from https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/profiling/flops_profiler/profiler.py#L615
    has_affine = self.weight is not None
    # Only a module that keeps running statistics can skip computing
    # mean/variance, and only in eval mode. Modules without a populated
    # `running_mean` attribute always take the expensive path.
    uses_running_stats = not self.training and getattr(self, 'running_mean', None) is not None
    if uses_running_stats:
        flops = input.numel() * (2 if has_affine else 1)
    else:
        flops = input.numel() * (5 if has_affine else 4)
    macs = 0
    return flops, macs
|
||||
|
||||
|
||||
# Optionally reuse the normalization profiler for apex's fused layers.
# apex is an optional dependency, so its absence is silently tolerated.
try:
    import apex
    _apex_normalization = apex.normalization
except (ImportError, AttributeError):
    _apex_normalization = None

if _apex_normalization is not None:
    # Register every fused class this apex build provides. Each class is
    # looked up independently so that one missing class does not prevent
    # the registration of the others.
    for _apex_cls_name in ('FusedLayerNorm', 'FusedRMSNorm', 'MixedFusedLayerNorm', 'MixedFusedRMSNorm'):
        _apex_cls = getattr(_apex_normalization, _apex_cls_name, None)
        if _apex_cls is not None:
            meta_profiler_module.register(_apex_cls)(torch_nn_normalize)
|
|
@ -0,0 +1,22 @@
|
|||
from typing import Tuple
|
||||
import torch
|
||||
from ..registry import meta_profiler_module
|
||||
|
||||
|
||||
@meta_profiler_module.register(torch.nn.AvgPool1d)
@meta_profiler_module.register(torch.nn.AvgPool2d)
@meta_profiler_module.register(torch.nn.AvgPool3d)
@meta_profiler_module.register(torch.nn.MaxPool1d)
@meta_profiler_module.register(torch.nn.MaxPool2d)
@meta_profiler_module.register(torch.nn.MaxPool3d)
@meta_profiler_module.register(torch.nn.AdaptiveAvgPool1d)
@meta_profiler_module.register(torch.nn.AdaptiveMaxPool1d)
@meta_profiler_module.register(torch.nn.AdaptiveAvgPool2d)
@meta_profiler_module.register(torch.nn.AdaptiveMaxPool2d)
@meta_profiler_module.register(torch.nn.AdaptiveAvgPool3d)
@meta_profiler_module.register(torch.nn.AdaptiveMaxPool3d)
def torch_nn_pooling(self: torch.nn.Module, input: torch.Tensor) -> Tuple[int, int]:
    """Estimate the cost of a pooling layer as ``(flops, macs)``.

    Pooling is modelled as visiting each input element exactly once
    (https://stackoverflow.com/a/67301217): one FLOP per input element and
    no multiply-accumulates.
    """
    return input.numel(), 0
|
|
@ -0,0 +1,13 @@
|
|||
import torch
|
||||
from ..registry import meta_profiler_module
|
||||
from typing import Optional, Tuple
|
||||
|
||||
|
||||
# TODO: calculate rnn FLOPs
@meta_profiler_module.register(torch.nn.GRU)
@meta_profiler_module.register(torch.nn.RNN)
def torch_nn_rnn(self: torch.nn.Module, input: torch.Tensor, hx: Optional[torch.Tensor] = None) -> Tuple[int, int]:
    """Placeholder profiler for recurrent layers.

    Args:
        self: the recurrent module being profiled.
        input: the input sequence tensor.
        hx: the initial hidden state. It is optional so that a call that
            omits it still reaches the ``NotImplementedError`` below instead
            of failing with a ``TypeError`` about a missing argument.

    Raises:
        NotImplementedError: always; FLOP counting for RNNs is not implemented yet.
    """
    raise NotImplementedError(f'FLOPs calculation for {type(self).__name__} is not implemented yet.')
|
|
@ -0,0 +1,25 @@
|
|||
class ProfilerRegistry:
    """A named mapping from a profiling source to its profiler function.

    Sources are used as dictionary keys, so they must be hashable.
    """

    def __init__(self, name):
        # human-readable name of this registry, used in error messages
        self.name = name
        # maps a registered source to its profiler function
        self.store = {}

    def register(self, source):
        """Return a decorator that registers the decorated function under ``source``.

        The decorated function is returned unchanged, so several ``register``
        decorators can be stacked on one function. Registering a source twice
        overwrites the earlier entry.
        """

        def wrapper(func):
            self.store[source] = func
            return func

        return wrapper

    def get(self, source):
        """Return the function registered under ``source``.

        Raises:
            AssertionError: if ``source`` has not been registered.
        """
        assert source in self.store, f'{source} is not registered in {self.name}.'
        target = self.store[source]
        return target

    def has(self, source):
        """Return whether a function has been registered under ``source``."""
        return source in self.store


meta_profiler_function = ProfilerRegistry(name='patched_functions_for_meta_profile')
meta_profiler_module = ProfilerRegistry(name='patched_modules_for_meta_profile')
|
|
@ -0,0 +1,180 @@
|
|||
from functools import partial
|
||||
from operator import add, floordiv, getitem, mul, neg, setitem, sub, pos
|
||||
from typing import Callable, NamedTuple, Any, Dict, Tuple
|
||||
import torch
|
||||
from torch.fx.node import Argument, Target
|
||||
from torch.fx._compatibility import compatibility
|
||||
from colossalai.fx.tracer.meta_patch import meta_patched_function, meta_patched_module
|
||||
from . import meta_profiler_function, meta_profiler_module
|
||||
|
||||
# Public API of this module.
__all__ = [
    'MetaProfile',
    'profile_function',
    'profile_module',
    'profile_method',
    'calculate_activation_size',
    'calculate_param_size',
]
|
||||
|
||||
# Call targets whose result is not counted as a new activation by `profile_function`.
# TODO fill out the inplace ops
INPLACE_OPS = [
    # binary arithmetic
    add, sub, mul, floordiv,
    # unary arithmetic
    neg, pos,
    # indexing
    getitem, setitem,
    # device transfer
    torch.Tensor.cpu,
]

# Method names accepted by `profile_method`.
# TODO check that call_methods are indeed inplace
INPLACE_METHOD = ['transpose', 'permute']
|
||||
|
||||
|
||||
@compatibility(is_backward_compatible=True)
class MetaProfile(NamedTuple):
    """Pertinent profiling information about a node within a torch.fx GraphModule."""

    # parameter size of the node
    param: int
    # activation size of the node
    activation: int
    # estimated number of floating point operations
    flops: int
    # estimated number of multiply-accumulate operations
    macs: int
|
||||
|
||||
|
||||
def calculate_activation_size(activation: Any) -> int:
    """Calculate the activation size of a node in bytes.

    Args:
        activation: the output of a node. Tensors contribute
            ``numel * element_size`` bytes; dicts contribute the size of their
            values; lists, tuples and sets contribute the sum of their
            elements (recursively). Anything else (e.g. ``None``, numbers,
            strings) is not a tensor container and contributes 0.

    Returns:
        The total size in bytes of all tensors found in ``activation``.
    """
    activation_size = 0
    if isinstance(activation, torch.Tensor):
        activation_size += activation.numel() * torch.tensor([], dtype=activation.dtype).element_size()
    elif isinstance(activation, dict):
        activation_size += calculate_activation_size(list(activation.values()))
    elif isinstance(activation, (list, tuple, set)):
        # Only recurse into known containers: blindly iterating would crash
        # on scalars/None and recurse forever on strings.
        for element in activation:
            activation_size += calculate_activation_size(element)
    return activation_size
|
||||
|
||||
|
||||
def calculate_param_size(mod: torch.nn.Module) -> int:
    """Return the total size in bytes of all parameters yielded by ``mod.parameters()``."""

    def _nbytes(tensor: torch.Tensor) -> int:
        # element size is looked up through an empty tensor of the same dtype
        bytes_per_element = torch.tensor([], dtype=tensor.dtype).element_size()
        return tensor.numel() * bytes_per_element

    return sum(_nbytes(p) for p in mod.parameters())
|
||||
|
||||
|
||||
def profile_function(target: 'Target') -> Callable:
    """
    Build a wrapper around a `call_function` target (or a `torch.nn.functional`
    function) that also reports the memory cost and FLOPs of the call.

    The wrapper returns ``(result, MetaProfile)``. It runs the meta-patched
    version of `target` when one is registered (by object or by `__name__`),
    otherwise `target` itself.

    Warnings:
        You may only use tensors with `device=meta` for this wrapped function.
        Only original `torch.nn.functional` are available.

    Usage:
        input = torch.rand(100, 100, 100, 100, device='meta')
        func = torch.nn.functional.relu
        output, profile = profile_function(func)(input, inplace=False)
        print(f"Profiling function {func},")
        print(f"Param size: {profile.param / 1024**2:.3f} MB, Activation size: {profile.activation / 1024**2:.3f} MB, {profile.flops} FLOPs, {profile.macs} MACs")
    """

    def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
        registered_by_target = meta_profiler_function.has(target)
        assert registered_by_target or meta_profiler_function.has(
            target.__name__), f"Colossal-AI hasn't supported profiling for {target}, you might manually patch it."

        result = func(*args, **kwargs)

        # an in-place call produces no new activation
        is_inplace = target in INPLACE_OPS or kwargs.get('inplace', False)
        activation_size = 0 if is_inplace else calculate_activation_size(result)

        # the profiler registered for the target itself takes precedence over one registered by name
        profiler_key = target if registered_by_target else target.__name__
        flops, macs = meta_profiler_function.get(profiler_key)(*args, **kwargs)

        # call_function has no parameters
        return result, MetaProfile(0, activation_size, flops, macs)

    f.__name__ = target.__name__

    # fetch patched function
    func = target
    if meta_patched_function.has(target):
        func = meta_patched_function.get(target)
    elif meta_patched_function.has(target.__name__):
        func = meta_patched_function.get(target.__name__)
    return f
|
||||
|
||||
|
||||
def profile_method(target: 'Target') -> Callable:
    """
    Build a wrapper around a `call_method` target that also reports the
    memory cost and FLOPs of the call.

    The wrapper returns ``(result, MetaProfile)`` where every field of the
    profile is zero.

    Warnings:
        This is not fully implemented and you may follow the error message to debug.
    """

    def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
        # the first positional argument is the object the method is called on
        self_obj, *method_args = args

        # look the method up by name on the object and run it
        assert isinstance(target, str), f'{target} instance is not str.'
        bound_method = getattr(self_obj, target)
        result = bound_method(*method_args, **kwargs)
        assert target in INPLACE_METHOD, f'Please check {target} is an inplace method. If so, add target to INPLACE_METHOD={INPLACE_METHOD}.'

        # call_method has no parameters and are MOSTLY(?) inplace, and has no FLOPs or MACs.
        return result, MetaProfile(param=0, activation=0, flops=0, macs=0)

    return f
|
||||
|
||||
|
||||
def profile_module(module: torch.nn.Module) -> Callable:
    """
    Build a wrapper around a `call_module` target (a `torch.nn` module) that
    also reports the memory cost and FLOPs of the call.

    The wrapper returns ``(result, MetaProfile)``. It runs the meta-patched
    forward registered for the module's type when there is one, otherwise
    `module.forward`.

    Warnings:
        You may only use tensors with `device=meta` for this wrapped function.
        Only original `torch.nn` are available.

    Usage:
        input = torch.rand(4, 3, 224, 224, device='meta')
        mod = torch.nn.Conv2d(3, 128, 3)
        output, profile = profile_module(mod)(input)
        print(f"Profiling function {mod},")
        print(f"Param size: {profile.param / 1024**2:.3f} MB, Activation size: {profile.activation / 1024**2:.3f} MB, {profile.flops} FLOPs, {profile.macs} MACs")
    """

    def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
        module_cls = type(module)
        assert meta_profiler_module.has(
            module_cls), f"Colossal-AI hasn't supported profiling for {module}, you might manually patch it."

        param_size = calculate_param_size(module)
        result = func(*args, **kwargs)

        # a module flagged `inplace` produces no new activation
        if getattr(module, 'inplace', False):
            activation_size = 0
        else:
            activation_size = calculate_activation_size(result)

        flops, macs = meta_profiler_module.get(module_cls)(module, *args, **kwargs)
        return result, MetaProfile(param_size, activation_size, flops, macs)

    f.__name__ = module.__class__.__name__

    # fetch patched module
    func = module.forward
    if meta_patched_module.has(type(module)):
        func = partial(meta_patched_module.get(type(module)), module)
    return f
|
|
@ -68,7 +68,7 @@ def _run_ckpt_solver(rank):
|
|||
|
||||
tracer = ColoTracer(trace_act_ckpt=False)
|
||||
|
||||
data = torch.rand(2, 3, 32, 32)
|
||||
data = torch.rand(2, 3, 32, 32, device='meta')
|
||||
for solver in SOLVERS:
|
||||
for model_cls in MODEL_LIST:
|
||||
m = model_cls(num_classes=5)
|
||||
|
@ -98,7 +98,7 @@ def _run_ckpt_solver_torch11(rank):
|
|||
|
||||
tracer = ColoTracer(trace_act_ckpt=False)
|
||||
|
||||
data = torch.rand(2, 3, 32, 32)
|
||||
data = torch.rand(2, 3, 32, 32, device='meta')
|
||||
for solver in SOLVERS:
|
||||
for model_cls in MODEL_LIST:
|
||||
m = model_cls(num_classes=5)
|
||||
|
|
|
@ -32,7 +32,7 @@ class MLP(torch.nn.Module):
|
|||
|
||||
def test_comm_size_compute():
|
||||
model = MLP(MODEL_DIM)
|
||||
input_sample = torch.rand(BATCH_SIZE, MODEL_DIM)
|
||||
input_sample = torch.rand(BATCH_SIZE, MODEL_DIM, device='meta')
|
||||
gm = symbolic_trace(model)
|
||||
MetaInfoProp(gm).run(input_sample)
|
||||
annotated_model = uniform_split_pass(gm, PIPELINE_SIZE)
|
||||
|
|
|
@ -20,17 +20,20 @@ def meta_check(meta_info_spec: TensorMetadata, orig_tensor: torch.Tensor):
|
|||
|
||||
def test_meta_info_prop():
|
||||
model = torch.nn.Linear(DIM_IN, DIM_OUT)
|
||||
input_sample = torch.rand(BATCH_SIZE, DIM_IN)
|
||||
input_sample = torch.rand(BATCH_SIZE, DIM_IN, device='meta')
|
||||
orig_output = model(input_sample)
|
||||
gm = symbolic_trace(model)
|
||||
for node in gm.graph.nodes:
|
||||
assert not hasattr(node,
|
||||
'node_size'), 'The attribute Node.node_size should not exist before MetaInfoProp procedure'
|
||||
assert not hasattr(node,
|
||||
'param_size'), 'The attribute Node.param_size should not exist before MetaInfoProp procedure'
|
||||
'__param__'), 'The attribute Node.__param__ should not exist before MetaInfoProp procedure'
|
||||
assert not hasattr(
|
||||
node,
|
||||
'activation_size'), 'The attribute Node.activation_size should not exist before MetaInfoProp procedure'
|
||||
node, '__activation__'), 'The attribute Node.__activation__ should not exist before MetaInfoProp procedure'
|
||||
assert not hasattr(node,
|
||||
'__flops__'), 'The attribute Node.__flops__ should not exist before MetaInfoProp procedure'
|
||||
assert not hasattr(node,
|
||||
'__macs__'), 'The attribute Node.__macs__ should not exist before MetaInfoProp procedure'
|
||||
MetaInfoProp(gm).run(input_sample)
|
||||
for node in gm.graph.nodes:
|
||||
if node.op == 'placeholder':
|
||||
|
@ -38,9 +41,11 @@ def test_meta_info_prop():
|
|||
if node.op == 'output':
|
||||
meta_check(node.meta['tensor_meta'], orig_output)
|
||||
assert hasattr(node, 'node_size'), 'The attribute Node.node_size should exist after MetaInfoProp procedure'
|
||||
assert hasattr(node, 'param_size'), 'The attribute Node.param_size should exist after MetaInfoProp procedure'
|
||||
assert hasattr(
|
||||
node, 'activation_size'), 'The attribute Node.activation_size should exist after MetaInfoProp procedure'
|
||||
assert hasattr(node, '__param__'), 'The attribute Node.__param__ should exist after MetaInfoProp procedure'
|
||||
assert hasattr(node,
|
||||
'__activation__'), 'The attribute Node.__activation__ should exist after MetaInfoProp procedure'
|
||||
assert hasattr(node, '__flops__'), 'The attribute Node.__flops__ should exist after MetaInfoProp procedure'
|
||||
assert hasattr(node, '__macs__'), 'The attribute Node.__macs__ should exist after MetaInfoProp procedure'
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
Loading…
Reference in New Issue