[fx] add profiler for fx nodes. (#1480)

* [fx] modify the calculation of node_size in MetaInfoProp for activation checkpointing usages

* [fx] merge development into main (#1)

* [fx] activation checkpointing using Chen strategies.

* [fx] add test for ckpt_solver_chen

* [fx] add vanilla activation checkpoint search with test on resnet and densenet

* [fx] add a namespace code for solver_chen.

* [fx] fix the false interpretation of algorithm 3 in https://arxiv.org/abs/1604.06174.

* [fx] fix lowercase naming conventions.

* [fx] simplify test for ckpt.

* [fx] add rules to linearize computation graphs for searching. (#2)

* [fx] fix test and algorithm bugs in activation checkpointing.

* [fx] polish ckpt_test.

* [fx] remove chen_sqrt for the sake of simplicity

* [fx] fix inconsistencies.

* [fx] fix MetaInfoProp.

* [fx] consider MetaInfoProp for inplace operands.

* [fx] add profiler for fx nodes.

* [fx] fix error in tests.

* [fx] unfix bug.
Super Daniel, 2022-08-24 16:22:44 +08:00, committed by GitHub
parent d39e11dffb, commit 32efe8e740
25 changed files with 985 additions and 71 deletions

@@ -73,10 +73,10 @@ def chen_greedy(gm: GraphModule) -> GraphModule:
y = 0
prev_idx = 2
for (idx, n) in enumerate(gm.graph.nodes):
temp += getattr(n, 'activation_size')
temp += getattr(n, '__activation__')
y = max(y, temp)
if temp > b and n in ckpt_nodes:
x += getattr(n, 'activation_size')
x += getattr(n, '__activation__')
temp = 0
ckpt_intv.append((prev_idx, idx + 1))
prev_idx = idx + 1
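
The loop above is the heart of the greedy solver: it accumulates per-node activation memory (now read from the renamed `__activation__` attribute) until the budget `b` is exceeded, then closes a checkpoint interval. A minimal sketch of that grouping on plain integers, with hypothetical `sizes` and `b`; it drops the `ckpt_nodes` membership check and the warm-start `prev_idx = 2` of the real solver:

def greedy_intervals(sizes, b):
    intervals, temp, prev_idx = [], 0, 0
    for idx, size in enumerate(sizes):
        temp += size
        if temp > b:    # budget exceeded: close an interval, as in chen_greedy
            temp = 0
            intervals.append((prev_idx, idx + 1))
            prev_idx = idx + 1
    return intervals

assert greedy_intervals([4, 4, 4, 4], b=6) == [(0, 2), (2, 4)]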

@@ -1,10 +1,12 @@
from operator import add, getitem
import torch
import torch.fx
from torch.fx.node import Node, map_aggregate
from torch.fx.node import Node, map_aggregate, Argument, Target
from typing import Any, Tuple, NamedTuple, Optional, Dict
from functools import reduce
from torch.fx._compatibility import compatibility
from torch.fx.immutable_collections import immutable_dict, immutable_list
from colossalai.fx.profiler import MetaProfile, profile_function, profile_module, calculate_activation_size, profile_method
@compatibility(is_backward_compatible=True)
@@ -36,47 +38,11 @@ def _extract_tensor_metadata(result: torch.Tensor) -> TensorMetadata:
return TensorMetadata(shape, dtype, requires_grad, stride, numel, is_tensor)
def _compute_activation_size(node_metadata: any) -> int:
"""
Compute numel of a node with ``tensor_meta`` attribute.
"""
node_numel = 0
if isinstance(node_metadata, TensorMetadata):
node_numel += node_metadata.numel * torch.tensor([], dtype=node_metadata.dtype).element_size()
elif isinstance(node_metadata, dict):
value_list = [v for _, v in node_metadata.items()]
node_numel += _compute_activation_size(value_list)
else:
for element in node_metadata:
node_numel += _compute_activation_size(element)
return node_numel
def _map_aggregate(arg, fn):
"""
Apply fn to each Node appearing arg. arg may be a list, tuple, slice, or dict with string keys.
"""
if isinstance(arg, torch.Size):
return fn(arg)
if isinstance(arg, tuple):
return tuple(map_aggregate(elem, fn) for elem in arg)
elif isinstance(arg, list):
return immutable_list(map_aggregate(elem, fn) for elem in arg)
elif isinstance(arg, dict):
return immutable_dict((k, map_aggregate(v, fn)) for k, v in arg.items())
elif isinstance(arg, slice):
return slice(map_aggregate(arg.start, fn), map_aggregate(arg.stop, fn), map_aggregate(arg.step, fn))
else:
return fn(arg)
@compatibility(is_backward_compatible=True)
class MetaInfoProp(torch.fx.Interpreter):
"""
Execute an FX graph Node-by-Node and
record the shape and type of the result
Execute an FX graph Node-by-Node with meta tensor and
record the shape, FLOPs, MACs and type of the result
into the corresponding node.
Usage:
@@ -104,9 +70,32 @@ class MetaInfoProp(torch.fx.Interpreter):
"""
@compatibility(is_backward_compatible=True)
def run(self, *args, initial_env: Optional[Dict[Node, Any]] = None, enable_io_processing: bool = True) -> Any:
"""
Add additional check for initial args to ensure all the tensor appears with `device='meta'`
"""
for elem in args:
if isinstance(elem, torch.Tensor):
assert elem.is_meta, "Input torch.Tensor is assumed to have device='meta'"
return super().run(*args, initial_env=initial_env, enable_io_processing=enable_io_processing)
@compatibility(is_backward_compatible=True)
def run_node(self, n: Node) -> Any:
# TODO: We might run_node(n) with meta data, and count FLOPS for each node
result = super().run_node(n)
"""
Run a specific node ``n`` and return the result.
Calls into placeholder, get_attr, call_function,
call_method, call_module, or output depending
on ``node.op``
Args:
n (Node): The Node to execute
Returns:
Any: The result of executing ``n``
"""
result, profile = super().run_node(n)
profile: MetaProfile
def extract_tensor_meta(obj):
if isinstance(obj, torch.Tensor):
@@ -114,29 +103,139 @@ class MetaInfoProp(torch.fx.Interpreter):
else:
return TensorMetadata(None, None, False, None, 0, False)
meta = _map_aggregate(result, extract_tensor_meta)
meta = map_aggregate(result, extract_tensor_meta)
n.meta['tensor_meta'] = meta
total_activation_size = 0
total_param_size = 0
if n.op == 'call_module':
target_module = n.graph.owning_module.get_submodule(n.target)
if not getattr(target_module, 'inplace', False):
total_activation_size = _compute_activation_size(n.meta['tensor_meta'])
for param in target_module.parameters():
total_param_size += param.numel() * torch.tensor([], dtype=param.dtype).element_size()
elif n.op == 'call_function':
if 'inplace' not in n.kwargs:
total_activation_size = _compute_activation_size(n.meta['tensor_meta'])
else:
total_activation_size = _compute_activation_size(n.meta['tensor_meta'])
setattr(n, 'node_size', total_activation_size + total_param_size)
setattr(n, 'param_size', total_param_size)
setattr(n, 'activation_size', total_activation_size)
# TODO: the attribute node_size should be removed in the future
setattr(n, 'node_size', profile.param + profile.activation)
setattr(n, '__param__', profile.param)
setattr(n, '__activation__', profile.activation)
setattr(n, '__flops__', profile.flops)
setattr(n, '__macs__', profile.macs)
n.meta['type'] = type(result)
return result
# Main Node running APIs
@compatibility(is_backward_compatible=True)
def placeholder(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute a ``placeholder`` node. Note that this is stateful:
``Interpreter`` maintains an internal iterator over
arguments passed to ``run`` and this method returns
next() on that iterator.
Args:
target (Target): The call target for this node. See
`Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
details on semantics
args (Tuple): Tuple of positional args for this invocation
kwargs (Dict): Dict of keyword arguments for this invocation
Returns:
result (Any): The argument value that was retrieved
profile (MetaProfile): The meta profile of this node
"""
result = super().placeholder(target, args, kwargs)
# A placeholder node only has activation
return result, MetaProfile(0, calculate_activation_size(result), 0, 0)
@compatibility(is_backward_compatible=True)
def get_attr(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute a ``get_attr`` node. Will retrieve an attribute
value from the ``Module`` hierarchy of ``self.module``.
Args:
target (Target): The call target for this node. See
`Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
details on semantics
args (Tuple): Tuple of positional args for this invocation
kwargs (Dict): Dict of keyword arguments for this invocation
Returns:
result (Any): The argument value that was retrieved
profile (MetaProfile): The meta profile of this node
"""
# A get_attr node never has parameters, activations, FLOPs, or MACs
return super().get_attr(target, args, kwargs), MetaProfile(0, 0, 0, 0)
@compatibility(is_backward_compatible=True)
def call_function(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute a ``call_function`` node with meta tensor and return the result and its meta profile.
Args:
target (Target): The call target for this node. See
`Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
details on semantics
args (Tuple): Tuple of positional args for this invocation
kwargs (Dict): Dict of keyword arguments for this invocation
Returns:
result (Any): The argument value that was retrieved
profile (MetaProfile): The meta profile of this node
"""
assert not isinstance(target, str)
return profile_function(target)(*args, **kwargs)
@compatibility(is_backward_compatible=True)
def call_method(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute a ``call_method`` node with meta tensor and return the result and its meta profile.
Args:
target (Target): The call target for this node. See
`Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
details on semantics
args (Tuple): Tuple of positional args for this invocation
kwargs (Dict): Dict of keyword arguments for this invocation
Returns:
result (Any): The argument value that was retrieved
profile (MetaProfile): The meta profile of this node
"""
return profile_method(target)(*args, **kwargs)
@compatibility(is_backward_compatible=True)
def call_module(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute a ``call_module`` node with meta tensor and return the result and its meta profile.
Args:
target (Target): The call target for this node. See
`Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
details on semantics
args (Tuple): Tuple of positional args for this invocation
kwargs (Dict): Dict of keyword arguments for this invocation
Returns:
result (Any): The argument value that was retrieved
profile (MetaProfile): The meta profile of this node
"""
# Retrieve executed args and kwargs values from the environment
# Execute the method and return the result
assert isinstance(target, str)
submod = self.fetch_attr(target)
return profile_module(submod)(*args, **kwargs)
@compatibility(is_backward_compatible=True)
def output(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
"""
Execute an ``output`` node. This really just retrieves
the value referenced by the ``output`` node and returns it.
Args:
target (Target): The call target for this node. See
`Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for
details on semantics
args (Tuple): Tuple of positional args for this invocation
kwargs (Dict): Dict of keyword arguments for this invocation
Returns:
Any: The return value referenced by the output node
"""
return args[0], MetaProfile(0, 0, 0, 0)
def propagate(self, *args):
"""
Run `module` via interpretation and return the result and

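Taken together, the overridden `run` and `run_node` let MetaInfoProp annotate every node with memory and compute statistics without allocating real device memory. A hedged usage sketch (the import path is assumed from this patch's tests and may differ):

import torch
from torch.fx import symbolic_trace
from colossalai.fx.passes.meta_info_prop import MetaInfoProp    # assumed path

model = torch.nn.Linear(128, 256)
gm = symbolic_trace(model)
# inputs must be meta tensors, as enforced by the run() override above
MetaInfoProp(gm).run(torch.rand(8, 128, device='meta'))
for node in gm.graph.nodes:
    print(node.name, node.node_size, getattr(node, '__flops__'), getattr(node, '__macs__'))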
@@ -0,0 +1,4 @@
from .registry import *
from .profiler_function import *
from .profiler_module import *
from .utils import *

@@ -0,0 +1,8 @@
from .activation_function import *
from .arithmetic import *
from .embedding import *
from .linear import *
from .normalization import *
from .pooling import *
from .python_ops import *
from .torch_ops import *

@@ -0,0 +1,29 @@
from typing import Tuple
import torch
from ..registry import meta_profiler_function
# TODO: different activations have different FLOPs counts; currently unused.
_multiplier = {
torch.nn.functional.relu: 1,
torch.nn.functional.prelu: 4,
torch.nn.functional.sigmoid: 4,
torch.nn.functional.tanh: 5,
torch.nn.functional.leaky_relu: 3,
torch.nn.functional.elu: 4,
torch.nn.functional.relu6: 2,
torch.nn.functional.gelu: 9,
}
@meta_profiler_function.register(torch.nn.functional.leaky_relu)
@meta_profiler_function.register(torch.nn.functional.elu)
@meta_profiler_function.register(torch.nn.functional.gelu)
@meta_profiler_function.register(torch.nn.functional.relu6)
@meta_profiler_function.register(torch.nn.functional.prelu)
@meta_profiler_function.register(torch.nn.functional.relu)
@meta_profiler_function.register(torch.nn.functional.sigmoid)
@meta_profiler_function.register(torch.nn.functional.tanh)
def torch_nn_func_non_linear_act(input: torch.Tensor, inplace: bool = False) -> Tuple[int, int]:
flops = input.numel()
macs = 0
return flops, macs

@@ -0,0 +1,83 @@
from typing import Any, Optional, Tuple, Union
import torch
from ..registry import meta_profiler_function
def _prod(dims):
p = 1
for v in dims:
p *= v
return p
def _elementwise_flops_compute(input, other):
# copied from https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/profiling/flops_profiler/profiler.py#L763
if not torch.is_tensor(input):
if torch.is_tensor(other):
return _prod(other.shape), 0
else:
return 1, 0
elif not torch.is_tensor(other):
return _prod(input.shape), 0
else:
dim_input = len(input.shape)
dim_other = len(other.shape)
max_dim = max(dim_input, dim_other)
final_shape = []
for i in range(max_dim):
in_i = input.shape[i] if i < dim_input else 1
ot_i = other.shape[i] if i < dim_other else 1
if in_i > ot_i:
final_shape.append(in_i)
else:
final_shape.append(ot_i)
flops = _prod(final_shape)
return flops, 0
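
One caveat, shown with assumed shapes: the loop above pads the shorter shape on the right, while PyTorch broadcasting aligns shapes from the trailing dimension, so the estimate can overcount when the operands have different ranks:

# input (4, 1, 8) + other (8,):
# true broadcast result is (4, 1, 8) -> 32 elements
# the loop above builds final_shape = [8, 1, 8] -> 64 flops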
@meta_profiler_function.register(torch.add)
@meta_profiler_function.register('add') # for built-in op +
@meta_profiler_function.register('iadd') # for built-in op +=
@meta_profiler_function.register('sub') # for built-in op -
@meta_profiler_function.register('isub') # for built-in op -=
@meta_profiler_function.register('mul') # for built-in op *
@meta_profiler_function.register('imul') # for built-in op *=
def torch_add_like_ops(input: Any, other: Any, *, out: Optional[torch.Tensor] = None) -> Tuple[int, int]:
return _elementwise_flops_compute(input, other)
@meta_profiler_function.register(torch.abs)
def torch_elementwise_op(input: torch.Tensor, *, out: Optional[torch.Tensor] = None) -> Tuple[int, int]:
flops = input.numel()
macs = 0
return flops, macs
@meta_profiler_function.register(torch.matmul)
@meta_profiler_function.register('matmul') # for built-in op @
@meta_profiler_function.register(torch.Tensor.matmul)
def torch_matmul(input: torch.Tensor, other: torch.Tensor, *, out: Optional[torch.Tensor] = None) -> Tuple[int, int]:
macs = _prod(input.shape) * other.shape[-1]
flops = 2 * macs
return flops, macs
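
A quick sanity check of the formula, with illustrative shapes:

# input (8, 64, 128) @ other (128, 256)
# macs  = prod(input.shape) * other.shape[-1] = 8 * 64 * 128 * 256 = 16_777_216
# flops = 2 * macs = 33_554_432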
@meta_profiler_function.register(torch.bmm)
def torch_bmm(input: torch.Tensor, other: torch.Tensor, *, out: Optional[torch.Tensor] = None) -> Tuple[int, int]:
macs = _prod(input.shape) * other.shape[-1]
flops = 2 * macs
return flops, macs
@meta_profiler_function.register(torch.var_mean)
def torch_var_mean(input: torch.Tensor,
dim: Union[int, Tuple[int, ...]],
unbiased: Optional[bool] = True,
keepdim: Optional[bool] = False,
*,
out: Optional[torch.Tensor] = None) -> Tuple[int, int]:
assert out is None, 'saving to out is not supported yet'
flops = input.numel() * 3
macs = 0
return flops, macs

@@ -0,0 +1,19 @@
import torch
from typing import Optional
from ..registry import meta_profiler_function
@meta_profiler_function.register(torch.nn.functional.embedding)
def torch_nn_functional_embedding(
input: torch.Tensor,
weight: torch.Tensor,
padding_idx: Optional[int] = None,
max_norm: Optional[float] = None,
norm_type: float = 2.0,
scale_grad_by_freq: bool = False,
sparse: bool = False,
) -> torch.Tensor:
# F.embedding is a dictionary lookup, so technically it has 0 FLOPs. (https://discuss.pytorch.org/t/correct-way-to-calculate-flops-in-model/67198/6)
flops = 0
macs = 0
return flops, macs

@@ -0,0 +1,13 @@
from typing import Optional, Tuple
import torch
from ..registry import meta_profiler_function
@meta_profiler_function.register(torch.nn.functional.linear)
def torch_nn_linear(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> Tuple[int, int]:
out_features = weight.shape[0]
macs = torch.numel(input) * out_features
flops = 2 * macs
if bias is not None:
flops += bias.numel()
return flops, macs
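
A hedged runnable check of the count above, with illustrative shapes:

inp = torch.rand(32, 512, device='meta')
w = torch.rand(1024, 512, device='meta')    # weight is (out_features, in_features)
b = torch.rand(1024, device='meta')
flops, macs = torch_nn_linear(inp, w, b)
assert macs == 32 * 512 * 1024              # every input element meets out_features columns
assert flops == 2 * macs + 1024             # plus bias.numel(), as counted above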

@@ -0,0 +1,66 @@
from typing import List, Optional, Tuple
import torch
from ..registry import meta_profiler_function
@meta_profiler_function.register(torch.nn.functional.instance_norm)
def torch_nn_func_instancenorm(
input: torch.Tensor,
running_mean: Optional[torch.Tensor] = None,
running_var: Optional[torch.Tensor] = None,
weight: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
use_input_stats: bool = True,
momentum: float = 0.1,
eps: float = 1e-5,
):
has_affine = weight is not None
flops = input.numel() * (5 if has_affine else 4)
macs = 0
return flops, macs
@meta_profiler_function.register(torch.nn.functional.group_norm)
def torch_nn_func_groupnorm(input: torch.Tensor,
num_groups: int,
weight: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
eps: float = 1e-5) -> Tuple[int, int]:
has_affine = weight is not None
flops = input.numel() * (5 if has_affine else 4)
macs = 0
return flops, macs
@meta_profiler_function.register(torch.nn.functional.layer_norm)
def torch_nn_func_layernorm(
input: torch.Tensor,
normalized_shape: List[int],
weight: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
eps: float = 1e-5,
) -> Tuple[int, int]:
has_affine = weight is not None
flops = input.numel() * (5 if has_affine else 4)
macs = 0
return flops, macs
@meta_profiler_function.register(torch.nn.functional.batch_norm)
def torch_nn_func_batchnorm(
input: torch.Tensor,
running_mean: Optional[torch.Tensor],
running_var: Optional[torch.Tensor],
weight: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
training: bool = False,
momentum: float = 0.1,
eps: float = 1e-5,
) -> Tuple[int, int]:
has_affine = weight is not None
if training:
flops = input.numel() * (2 if has_affine else 1)
else:
flops = input.numel() * (5 if has_affine else 4)
macs = 0
return flops, macs

@@ -0,0 +1,22 @@
from typing import Tuple, Union
import torch
from ..registry import meta_profiler_function
@meta_profiler_function.register(torch.nn.functional.avg_pool1d)
@meta_profiler_function.register(torch.nn.functional.avg_pool2d)
@meta_profiler_function.register(torch.nn.functional.avg_pool3d)
@meta_profiler_function.register(torch.nn.functional.max_pool1d)
@meta_profiler_function.register(torch.nn.functional.max_pool2d)
@meta_profiler_function.register(torch.nn.functional.max_pool3d)
@meta_profiler_function.register(torch.nn.functional.adaptive_avg_pool1d)
@meta_profiler_function.register(torch.nn.functional.adaptive_avg_pool2d)
@meta_profiler_function.register(torch.nn.functional.adaptive_avg_pool3d)
@meta_profiler_function.register(torch.nn.functional.adaptive_max_pool1d)
@meta_profiler_function.register(torch.nn.functional.adaptive_max_pool2d)
@meta_profiler_function.register(torch.nn.functional.adaptive_max_pool3d)
def torch_nn_func_pooling(input: torch.Tensor, *args, **kwargs) -> Tuple[int, int]:
# all pooling could be considered as going over each input element only once (https://stackoverflow.com/a/67301217)
flops = input.numel()
macs = 0
return flops, macs

@@ -0,0 +1,12 @@
import operator
from typing import Any, Tuple
import torch
from ..registry import meta_profiler_function
from colossalai.fx.proxy import ColoProxy
@meta_profiler_function.register(operator.getitem)
def operator_getitem(a: Any, b: Any) -> Tuple[int, int]:
flops = 0
macs = 0
return flops, macs

@@ -0,0 +1,64 @@
from typing import Any, Optional, Tuple
import torch
from ..registry import meta_profiler_function
def _prod(dims):
p = 1
for v in dims:
p *= v
return p
@meta_profiler_function.register(torch.arange)
@meta_profiler_function.register(torch.finfo)
@meta_profiler_function.register(torch.permute)
@meta_profiler_function.register(torch.Tensor.permute)
@meta_profiler_function.register(torch.Tensor.repeat)
@meta_profiler_function.register(torch.index_select)
@meta_profiler_function.register(torch.Tensor.index_select)
@meta_profiler_function.register(torch.squeeze)
@meta_profiler_function.register(torch.Tensor.squeeze)
@meta_profiler_function.register(torch.unsqueeze)
@meta_profiler_function.register(torch.Tensor.unsqueeze)
@meta_profiler_function.register(torch.cat)
@meta_profiler_function.register(torch.concat)
@meta_profiler_function.register(torch.repeat_interleave)
@meta_profiler_function.register(torch.Tensor.repeat_interleave)
@meta_profiler_function.register(torch.flatten)
@meta_profiler_function.register(torch.Tensor.flatten)
@meta_profiler_function.register(torch.roll)
@meta_profiler_function.register(torch.full)
@meta_profiler_function.register(torch.Tensor.cpu)
@meta_profiler_function.register(torch.Tensor.cuda)
def torch_zero_flops_op(*args, **kwargs) -> Tuple[int, int]:
flops = 0
macs = 0
return flops, macs
@meta_profiler_function.register(torch.where)
def torch_where(condition: torch.Tensor, x: Any, y: Any) -> Tuple[int, int]:
# torch.where returns the broadcast of condition, x, and y,
# so its cost is approximated as one elementwise pass over condition
flops = condition.numel()
macs = 0
return flops, macs
@meta_profiler_function.register(torch.max)
def torch_max(input: torch.Tensor,
dim: int = None,
keepdim: bool = False,
*,
out: Optional[torch.Tensor] = None) -> Tuple[int, int]:
macs = 0
assert out is None, 'assigning value to out is not supported yet'
if dim is not None:
shape = list(input.shape)
shape.pop(int(dim))
flops = _prod(shape)
return flops, macs
else:
flops = input.numel()
return flops, macs

@@ -0,0 +1,7 @@
from .activation_function import *
from .convolution import *
from .embedding import *
from .linear import *
from .normalization import *
from .pooling import *
from .rnn import *

@@ -0,0 +1,29 @@
from typing import Tuple
import torch
from ..registry import meta_profiler_module
# TODO: different activations have different FLOPs counts; currently unused.
_multiplier = {
torch.nn.ReLU: 1,
torch.nn.PReLU: 4,
torch.nn.Sigmoid: 4,
torch.nn.Tanh: 5,
torch.nn.LeakyReLU: 3,
torch.nn.ELU: 4,
torch.nn.ReLU6: 2,
torch.nn.GELU: 9,
}
@meta_profiler_module.register(torch.nn.ELU)
@meta_profiler_module.register(torch.nn.LeakyReLU)
@meta_profiler_module.register(torch.nn.ReLU)
@meta_profiler_module.register(torch.nn.GELU)
@meta_profiler_module.register(torch.nn.Sigmoid)
@meta_profiler_module.register(torch.nn.Tanh)
@meta_profiler_module.register(torch.nn.ReLU6)
@meta_profiler_module.register(torch.nn.PReLU)
def torch_nn_non_linear_act(self: torch.nn.Module, input: torch.Tensor) -> Tuple[int, int]:
flops = input.numel()
macs = 0
return flops, macs

@@ -0,0 +1,157 @@
import math
from typing import Tuple
import torch
from ..registry import meta_profiler_module
def _prod(dims):
p = 1
for v in dims:
p *= v
return p
@meta_profiler_module.register(torch.nn.Conv1d)
def torch_nn_conv1d(self: torch.nn.Conv1d, input: torch.Tensor) -> Tuple[int, int]:
# the output shape is calculated using the formula stated
# at https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
c_in, l_in = input.shape[-2:]
c_out = self.out_channels
l_out = math.floor((l_in + 2 * self.padding[0] - self.dilation[0] *
(self.kernel_size[0] - 1) - 1) / self.stride[0] + 1)
result_shape = input.shape[:-2] + (
c_out,
l_out,
)
macs_per_elem = _prod(self.kernel_size) * c_in // self.groups
num_elem = _prod(result_shape)
macs = macs_per_elem * num_elem
flops = 2 * macs
if self.bias is not None:
flops += num_elem
return flops, macs
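
A worked instance of the Conv1d count above (illustrative configuration):

# Conv1d(in_channels=3, out_channels=16, kernel_size=5) on input (1, 3, 32):
# l_out = floor((32 + 2*0 - 1*(5 - 1) - 1) / 1 + 1) = 28
# macs  = (5 * 3 // 1) * (1 * 16 * 28) = 15 * 448 = 6_720
# flops = 2 * macs + 448 (one add per output element for bias) = 13_888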
@meta_profiler_module.register(torch.nn.Conv2d)
def torch_nn_conv2d(self: torch.nn.Conv2d, input: torch.Tensor) -> Tuple[int, int]:
# the output shape is calculated using the formula stated
# at https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
c_in, h_in, w_in = input.shape[-3:]
c_out = self.out_channels
h_out = math.floor((h_in + 2 * self.padding[0] - self.dilation[0] *
(self.kernel_size[0] - 1) - 1) / self.stride[0] + 1)
w_out = math.floor((w_in + 2 * self.padding[1] - self.dilation[1] *
(self.kernel_size[1] - 1) - 1) / self.stride[1] + 1)
result_shape = input.shape[:-3] + (
c_out,
h_out,
w_out,
)
macs_per_elem = _prod(self.kernel_size) * c_in // self.groups
num_elem = _prod(result_shape)
macs = macs_per_elem * num_elem
flops = 2 * macs
if self.bias is not None:
flops += num_elem
return flops, macs
@meta_profiler_module.register(torch.nn.Conv3d)
def torch_nn_conv3d(self: torch.nn.Conv3d, input: torch.Tensor) -> Tuple[int, int]:
# the output shape is calculated using the formula stated
# at https://pytorch.org/docs/stable/generated/torch.nn.Conv3d.html
c_in, d_in, h_in, w_in = input.shape[-4:]
c_out = self.out_channels
d_out = math.floor((d_in + 2 * self.padding[0] - self.dilation[0] *
(self.kernel_size[0] - 1) - 1) / self.stride[0] + 1)
h_out = math.floor((h_in + 2 * self.padding[1] - self.dilation[1] *
(self.kernel_size[1] - 1) - 1) / self.stride[1] + 1)
w_out = math.floor((w_in + 2 * self.padding[2] - self.dilation[2] *
(self.kernel_size[2] - 1) - 1) / self.stride[2] + 1)
result_shape = input.shape[:-4] + (
c_out,
d_out,
h_out,
w_out,
)
macs_per_elem = _prod(self.kernel_size) * c_in // self.groups
num_elem = _prod(result_shape)
macs = macs_per_elem * num_elem
flops = 2 * macs
if self.bias is not None:
flops += num_elem
return flops, macs
@meta_profiler_module.register(torch.nn.ConvTranspose1d)
def torch_nn_convtranspose1d(self: torch.nn.ConvTranspose1d, input: torch.Tensor) -> Tuple[int, int]:
# the output shape is calculated using the formula stated
# at https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose1d.html
c_in, l_in = input.shape[-2:]
c_out = self.out_channels
l_out = math.floor((l_in - 1) * self.stride[0] - 2 * self.padding[0] + self.dilation[0] *
(self.kernel_size[0] - 1) + self.output_padding[0] + 1)
result_shape = input.shape[:-2] + (
c_out,
l_out,
)
macs_per_elem = _prod(self.kernel_size) * c_in // self.groups
num_elem = _prod(
input.shape
) # see https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/profiling/flops_profiler/profiler.py#L604
macs = macs_per_elem * num_elem
flops = 2 * macs
if self.bias is not None:
flops += _prod(result_shape)
return flops, macs
@meta_profiler_module.register(torch.nn.ConvTranspose2d)
def torch_nn_convtranspose2d(self: torch.nn.ConvTranspose2d, input: torch.Tensor) -> Tuple[int, int]:
# the output shape is calculated using the formula stated
# at https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html
c_in, h_in, w_in = input.shape[-3:]
c_out = self.out_channels
h_out = math.floor((h_in - 1) * self.stride[0] - 2 * self.padding[0] + self.dilation[0] *
(self.kernel_size[0] - 1) + self.output_padding[0] + 1)
w_out = math.floor((w_in - 1) * self.stride[1] - 2 * self.padding[1] + self.dilation[1] *
(self.kernel_size[1] - 1) + self.output_padding[1] + 1)
result_shape = input.shape[:-3] + (
c_out,
h_out,
w_out,
)
macs_per_elem = _prod(self.kernel_size) * c_in // self.groups
num_elem = _prod(input.shape)
macs = macs_per_elem * num_elem
flops = 2 * macs
if self.bias is not None:
flops += _prod(result_shape)
return flops, macs
@meta_profiler_module.register(torch.nn.ConvTranspose3d)
def torch_nn_convtranspose3d(self: torch.nn.ConvTranspose3d, input: torch.Tensor) -> Tuple[int, int]:
# the output shape is calculated using the formula stated
# at https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose3d.html
c_in, d_in, h_in, w_in = input.shape[-4:]
c_out = self.out_channels
d_out = math.floor((d_in - 1) * self.stride[0] - 2 * self.padding[0] + self.dilation[0] *
(self.kernel_size[0] - 1) + self.output_padding[0] + 1)
h_out = math.floor((h_in - 1) * self.stride[1] - 2 * self.padding[1] + self.dilation[1] *
(self.kernel_size[1] - 1) + self.output_padding[1] + 1)
w_out = math.floor((w_in - 1) * self.stride[2] - 2 * self.padding[2] + self.dilation[2] *
(self.kernel_size[2] - 1) + self.output_padding[2] + 1)
result_shape = input.shape[:-4] + (
c_out,
d_out,
h_out,
w_out,
)
macs_per_elem = _prod(self.kernel_size) * c_in // self.groups
num_elem = _prod(input.shape)
macs = macs_per_elem * num_elem
flops = 2 * macs
if self.bias is not None:
flops += _prod(result_shape)
return flops, macs

@@ -0,0 +1,11 @@
from typing import Tuple
import torch
from ..registry import meta_profiler_module
@meta_profiler_module.register(torch.nn.Embedding)
def torch_nn_embedding(self: torch.nn.Embedding, input: torch.Tensor) -> Tuple[int, int]:
# nn.Embedding is a dictionary lookup, so technically it has 0 FLOPs. (https://discuss.pytorch.org/t/correct-way-to-calculate-flops-in-model/67198/6)
flops = 0
macs = 0
return flops, macs

@@ -0,0 +1,13 @@
from typing import Tuple
import torch
from ..registry import meta_profiler_module
@meta_profiler_module.register(torch.nn.Linear)
def torch_nn_linear(self: torch.nn.Linear, input: torch.Tensor) -> Tuple[int, int]:
out_features = self.weight.shape[0]
macs = torch.numel(input) * out_features
flops = 2 * macs
if self.bias is not None:
flops += self.bias.numel()
return flops, macs

@@ -0,0 +1,33 @@
from typing import Tuple, Union
import torch
from ..registry import meta_profiler_module
@meta_profiler_module.register(torch.nn.InstanceNorm1d)
@meta_profiler_module.register(torch.nn.InstanceNorm2d)
@meta_profiler_module.register(torch.nn.InstanceNorm3d)
@meta_profiler_module.register(torch.nn.LayerNorm)
@meta_profiler_module.register(torch.nn.GroupNorm)
@meta_profiler_module.register(torch.nn.BatchNorm1d)
@meta_profiler_module.register(torch.nn.BatchNorm2d)
@meta_profiler_module.register(torch.nn.BatchNorm3d)
def torch_nn_normalize(self: Union[torch.nn.LayerNorm, torch.nn.GroupNorm, torch.nn.BatchNorm1d, torch.nn.BatchNorm2d,
torch.nn.BatchNorm3d], input: torch.Tensor) -> Tuple[int, int]:
# adopted from https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/profiling/flops_profiler/profiler.py#L615
has_affine = self.weight is not None
if self.training:
flops = input.numel() * (2 if has_affine else 1)
else:
flops = input.numel() * (5 if has_affine else 4)
macs = 0
return flops, macs
try:
import apex
meta_profiler_module.register(apex.normalization.FusedLayerNorm)(torch_nn_normalize)
meta_profiler_module.register(apex.normalization.FusedRMSNorm)(torch_nn_normalize)
meta_profiler_module.register(apex.normalization.MixedFusedLayerNorm)(torch_nn_normalize)
meta_profiler_module.register(apex.normalization.MixedFusedRMSNorm)(torch_nn_normalize)
except (ImportError, AttributeError):
pass

@@ -0,0 +1,22 @@
from typing import Tuple
import torch
from ..registry import meta_profiler_module
@meta_profiler_module.register(torch.nn.AvgPool1d)
@meta_profiler_module.register(torch.nn.AvgPool2d)
@meta_profiler_module.register(torch.nn.AvgPool3d)
@meta_profiler_module.register(torch.nn.MaxPool1d)
@meta_profiler_module.register(torch.nn.MaxPool2d)
@meta_profiler_module.register(torch.nn.MaxPool3d)
@meta_profiler_module.register(torch.nn.AdaptiveAvgPool1d)
@meta_profiler_module.register(torch.nn.AdaptiveMaxPool1d)
@meta_profiler_module.register(torch.nn.AdaptiveAvgPool2d)
@meta_profiler_module.register(torch.nn.AdaptiveMaxPool2d)
@meta_profiler_module.register(torch.nn.AdaptiveAvgPool3d)
@meta_profiler_module.register(torch.nn.AdaptiveMaxPool3d)
def torch_nn_pooling(self: torch.nn.Module, input: torch.Tensor) -> Tuple[int, int]:
# all pooling could be considered as going over each input element only once (https://stackoverflow.com/a/67301217)
flops = input.numel()
macs = 0
return flops, macs

@@ -0,0 +1,13 @@
import torch
from ..registry import meta_profiler_module
from typing import Optional, Tuple
# TODO: calculate rnn FLOPs
@meta_profiler_module.register(torch.nn.GRU)
@meta_profiler_module.register(torch.nn.RNN)
def torch_nn_rnn(self: torch.nn.Module, input: torch.Tensor, hx: torch.Tensor) -> Tuple[int, int]:
raise NotImplementedError("FLOPs counting for torch.nn.RNN and torch.nn.GRU is not implemented yet")

@@ -0,0 +1,25 @@
class ProfilerRegistry:
def __init__(self, name):
self.name = name
self.store = {}
def register(self, source):
def wrapper(func):
self.store[source] = func
return func
return wrapper
def get(self, source):
assert source in self.store
target = self.store[source]
return target
def has(self, source):
return source in self.store
meta_profiler_function = ProfilerRegistry(name='patched_functions_for_meta_profile')
meta_profiler_module = ProfilerRegistry(name='patched_modules_for_meta_profile')
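
These two registries are what make the profiler extensible: a FLOPs formula is attached to an op with `register` and looked up later by `profile_function` / `profile_module`. A hedged sketch of covering an op this patch leaves out (torch.nn.functional.softmax is an illustrative choice, and the three-pass flop convention is an assumption):

import torch
from typing import Tuple

@meta_profiler_function.register(torch.nn.functional.softmax)
def torch_nn_func_softmax(input: torch.Tensor, dim=None, *args, **kwargs) -> Tuple[int, int]:
    # exp, sum and divide: roughly three elementwise passes over the input
    flops = 3 * input.numel()
    macs = 0
    return flops, macs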

@@ -0,0 +1,180 @@
from functools import partial
from operator import add, floordiv, getitem, mul, neg, setitem, sub, pos
from typing import Callable, NamedTuple, Any, Dict, Tuple
import torch
from torch.fx.node import Argument, Target
from torch.fx._compatibility import compatibility
from colossalai.fx.tracer.meta_patch import meta_patched_function, meta_patched_module
from . import meta_profiler_function, meta_profiler_module
__all__ = [
'MetaProfile', 'profile_function', 'profile_module', 'profile_method', 'calculate_activation_size',
'calculate_param_size'
]
# TODO fill out the inplace ops
INPLACE_OPS = [
add,
sub,
mul,
floordiv,
neg,
pos,
getitem,
setitem,
torch.Tensor.cpu,
]
# TODO check that call_methods are indeed inplace
INPLACE_METHOD = [
'transpose',
'permute',
]
@compatibility(is_backward_compatible=True)
class MetaProfile(NamedTuple):
# MetaProfile is a structure containing pertinent information
# about a node within a torch.fx GraphModule.
param: int
activation: int
flops: int
macs: int
def calculate_activation_size(activation: Any) -> int:
"""
Calculate activation size of a node.
"""
activation_size = 0
if isinstance(activation, torch.Tensor):
activation_size += activation.numel() * torch.tensor([], dtype=activation.dtype).element_size()
elif isinstance(activation, dict):
value_list = [v for _, v in activation.items()]
activation_size += calculate_activation_size(value_list)
else:
for element in activation:
activation_size += calculate_activation_size(element)
return activation_size
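
A hedged check of the recursion above (float32 stores 4 bytes per element):

t = torch.rand(2, 3, dtype=torch.float32, device='meta')
assert calculate_activation_size(t) == 2 * 3 * 4                # 24 bytes
assert calculate_activation_size({'out': t, 'aux': t}) == 48    # the dict branch recurses over values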
def calculate_param_size(mod: torch.nn.Module) -> int:
"""
Calculate param size of a node.
"""
param_size = 0
for param in mod.parameters():
param_size += param.numel() * torch.tensor([], dtype=param.dtype).element_size()
return param_size
def profile_function(target: 'Target') -> Callable:
"""
Wrap a `call_function` node or `torch.nn.functional` in order to
record the memory cost and FLOPs of the execution.
Warnings:
You may only use tensors with `device=meta` for this wrapped function.
Only the original `torch.nn.functional` functions are supported.
Usage:
input = torch.rand(100, 100, 100, 100, device='meta')
func = torch.nn.functional.relu
output, profile = profile_function(func)(input, inplace=False)
print(f"Profiling function {func},")
print(f"Param size: {profile.param / 1024**2:.3f} MB, Activation size: {profile.activation / 1024**2:.3f} MB, {profile.flops} FLOPs, {profile.macs} MACs")
"""
def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
assert meta_profiler_function.has(target) or meta_profiler_function.has(
target.__name__), f"Colossal-AI does not support profiling for {target} yet; you may need to patch it manually."
# call_function has no parameters
param_size = 0
activation_size = 0
result = func(*args, **kwargs)
if target not in INPLACE_OPS and not kwargs.get('inplace', False):
activation_size += calculate_activation_size(result)
if meta_profiler_function.has(target):
profiler = meta_profiler_function.get(target)
else:
profiler = meta_profiler_function.get(target.__name__)
flops, macs = profiler(*args, **kwargs)
return result, MetaProfile(param_size, activation_size, flops, macs)
f.__name__ = target.__name__
# fetch patched function
if meta_patched_function.has(target):
func = meta_patched_function.get(target)
elif meta_patched_function.has(target.__name__):
func = meta_patched_function.get(target.__name__)
else:
func = target
return f
def profile_method(target: 'Target') -> Callable:
"""
Wrap a `call_method` node to
record the memory cost and FLOPs of the execution.
Warnings:
This is not fully implemented and you may follow the error message to debug.
"""
def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
# args[0] is the `self` object for this method call
self_obj, *args_tail = args
# Execute the method and return the result
assert isinstance(target, str), f'{target} instance is not str.'
result = getattr(self_obj, target)(*args_tail, **kwargs)
assert target in INPLACE_METHOD, f'Please check whether {target} is an inplace method. If so, add it to INPLACE_METHOD={INPLACE_METHOD}.'
# call_method nodes have no parameters, are mostly (?) inplace, and count no FLOPs or MACs
param_size = 0
activation_size = 0
flops = 0
macs = 0
return result, MetaProfile(param_size, activation_size, flops, macs)
return f
def profile_module(module: torch.nn.Module) -> Callable:
"""
Wrap a `call_module` node or `torch.nn` in order to
record the memory cost and FLOPs of the execution.
Warnings:
You may only use tensors with `device=meta` for this wrapped function.
Only the original `torch.nn` modules are supported.
Usage:
input = torch.rand(4, 3, 224, 224, device='meta')
mod = torch.nn.Conv2d(3, 128, 3)
output, profile = profile_module(mod)(input)
print(f"Profiling function {mod},")
print(f"Param size: {profile.param / 1024**2:.3f} MB, Activation size: {profile.activation / 1024**2:.3f} MB, {profile.flops} FLOPs, {profile.macs} MACs")
"""
def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
assert meta_profiler_module.has(
type(module)), f"Colossal-AI does not support profiling for {module} yet; you may need to patch it manually."
param_size = calculate_param_size(module)
activation_size = 0
result = func(*args, **kwargs)
if not getattr(module, 'inplace', False):
activation_size += calculate_activation_size(result)
profiler = meta_profiler_module.get(type(module))
flops, macs = profiler(module, *args, **kwargs)
return result, MetaProfile(param_size, activation_size, flops, macs)
f.__name__ = module.__class__.__name__
# fetch patched module
if meta_patched_module.has(type(module)):
func = partial(meta_patched_module.get(type(module)), module)
else:
func = module.forward
return f

@@ -68,7 +68,7 @@ def _run_ckpt_solver(rank):
tracer = ColoTracer(trace_act_ckpt=False)
data = torch.rand(2, 3, 32, 32)
data = torch.rand(2, 3, 32, 32, device='meta')
for solver in SOLVERS:
for model_cls in MODEL_LIST:
m = model_cls(num_classes=5)
@@ -98,7 +98,7 @@ def _run_ckpt_solver_torch11(rank):
tracer = ColoTracer(trace_act_ckpt=False)
data = torch.rand(2, 3, 32, 32)
data = torch.rand(2, 3, 32, 32, device='meta')
for solver in SOLVERS:
for model_cls in MODEL_LIST:
m = model_cls(num_classes=5)

@@ -32,7 +32,7 @@ class MLP(torch.nn.Module):
def test_comm_size_compute():
model = MLP(MODEL_DIM)
input_sample = torch.rand(BATCH_SIZE, MODEL_DIM)
input_sample = torch.rand(BATCH_SIZE, MODEL_DIM, device='meta')
gm = symbolic_trace(model)
MetaInfoProp(gm).run(input_sample)
annotated_model = uniform_split_pass(gm, PIPELINE_SIZE)

@@ -20,17 +20,20 @@ def meta_check(meta_info_spec: TensorMetadata, orig_tensor: torch.Tensor):
def test_meta_info_prop():
model = torch.nn.Linear(DIM_IN, DIM_OUT)
input_sample = torch.rand(BATCH_SIZE, DIM_IN)
input_sample = torch.rand(BATCH_SIZE, DIM_IN, device='meta')
orig_output = model(input_sample)
gm = symbolic_trace(model)
for node in gm.graph.nodes:
assert not hasattr(node,
'node_size'), 'The attribute Node.node_size should not exist before MetaInfoProp procedure'
assert not hasattr(node,
'param_size'), 'The attribute Node.param_size should not exist before MetaInfoProp procedure'
'__param__'), 'The attribute Node.__param__ should not exist before MetaInfoProp procedure'
assert not hasattr(
node,
'activation_size'), 'The attribute Node.activation_size should not exist before MetaInfoProp procedure'
node, '__activation__'), 'The attribute Node.__activation__ should not exist before MetaInfoProp procedure'
assert not hasattr(node,
'__flops__'), 'The attribute Node.__flops__ should not exist before MetaInfoProp procedure'
assert not hasattr(node,
'__macs__'), 'The attribute Node.__macs__ should not exist before MetaInfoProp procedure'
MetaInfoProp(gm).run(input_sample)
for node in gm.graph.nodes:
if node.op == 'placeholder':
@@ -38,9 +41,11 @@ def test_meta_info_prop():
if node.op == 'output':
meta_check(node.meta['tensor_meta'], orig_output)
assert hasattr(node, 'node_size'), 'The attribute Node.node_size should exist after MetaInfoProp procedure'
assert hasattr(node, 'param_size'), 'The attribute Node.param_size should exist after MetaInfoProp procedure'
assert hasattr(
node, 'activation_size'), 'The attribute Node.activation_size should exist after MetaInfoProp procedure'
assert hasattr(node, '__param__'), 'The attribute Node.__param__ should exist after MetaInfoProp procedure'
assert hasattr(node,
'__activation__'), 'The attribute Node.__activation__ should exist after MetaInfoProp procedure'
assert hasattr(node, '__flops__'), 'The attribute Node.__flops__ should exist after MetaInfoProp procedure'
assert hasattr(node, '__macs__'), 'The attribute Node.__macs__ should exist after MetaInfoProp procedure'
if __name__ == '__main__':