diff --git a/colossalai/fx/passes/algorithms/ckpt_solver_chen.py b/colossalai/fx/passes/algorithms/ckpt_solver_chen.py index 5f665aae5..54d22a538 100644 --- a/colossalai/fx/passes/algorithms/ckpt_solver_chen.py +++ b/colossalai/fx/passes/algorithms/ckpt_solver_chen.py @@ -73,10 +73,10 @@ def chen_greedy(gm: GraphModule) -> GraphModule: y = 0 prev_idx = 2 for (idx, n) in enumerate(gm.graph.nodes): - temp += getattr(n, 'activation_size') + temp += getattr(n, '__activation__') y = max(y, temp) if temp > b and n in ckpt_nodes: - x += getattr(n, 'activation_size') + x += getattr(n, '__activation__') temp = 0 ckpt_intv.append((prev_idx, idx + 1)) prev_idx = idx + 1 diff --git a/colossalai/fx/passes/meta_info_prop.py b/colossalai/fx/passes/meta_info_prop.py index f8fa60249..7f7377667 100644 --- a/colossalai/fx/passes/meta_info_prop.py +++ b/colossalai/fx/passes/meta_info_prop.py @@ -1,10 +1,12 @@ +from operator import add, getitem import torch import torch.fx -from torch.fx.node import Node, map_aggregate +from torch.fx.node import Node, map_aggregate, Argument, Target from typing import Any, Tuple, NamedTuple, Optional, Dict from functools import reduce from torch.fx._compatibility import compatibility from torch.fx.immutable_collections import immutable_dict, immutable_list +from colossalai.fx.profiler import MetaProfile, profile_function, profile_module, calculate_activation_size, profile_method @compatibility(is_backward_compatible=True) @@ -36,47 +38,11 @@ def _extract_tensor_metadata(result: torch.Tensor) -> TensorMetadata: return TensorMetadata(shape, dtype, requires_grad, stride, numel, is_tensor) -def _compute_activation_size(node_metadata: any) -> int: - """ - Compute numel of a node with ``tensor_meta`` attribute. 
- """ - node_numel = 0 - - if isinstance(node_metadata, TensorMetadata): - node_numel += node_metadata.numel * torch.tensor([], dtype=node_metadata.dtype).element_size() - elif isinstance(node_metadata, dict): - value_list = [v for _, v in node_metadata.items()] - node_numel += _compute_activation_size(value_list) - else: - for element in node_metadata: - node_numel += _compute_activation_size(element) - - return node_numel - - -def _map_aggregate(arg, fn): - """ - Apply fn to each Node appearing arg. arg may be a list, tuple, slice, or dict with string keys. - """ - if isinstance(arg, torch.Size): - return fn(arg) - if isinstance(arg, tuple): - return tuple(map_aggregate(elem, fn) for elem in arg) - elif isinstance(arg, list): - return immutable_list(map_aggregate(elem, fn) for elem in arg) - elif isinstance(arg, dict): - return immutable_dict((k, map_aggregate(v, fn)) for k, v in arg.items()) - elif isinstance(arg, slice): - return slice(map_aggregate(arg.start, fn), map_aggregate(arg.stop, fn), map_aggregate(arg.step, fn)) - else: - return fn(arg) - - @compatibility(is_backward_compatible=True) class MetaInfoProp(torch.fx.Interpreter): """ - Execute an FX graph Node-by-Node and - record the shape and type of the result + Execute an FX graph Node-by-Node with meta tensor and + record the shape, FLOPs, MACs and type of the result into the corresponding node. 
@compatibility(is_backward_compatible=True)
def run(self, *args, initial_env: Optional[Dict[Node, Any]] = None, enable_io_processing: bool = True) -> Any:
    """
    Run the module via interpretation after checking that tensor inputs are meta tensors.

    Args:
        *args: Positional inputs of the module, consumed by the placeholder nodes.
        initial_env (Optional[Dict[Node, Any]]): Optional pre-populated environment,
            forwarded unchanged to ``torch.fx.Interpreter.run``.
        enable_io_processing (bool): Forwarded unchanged to ``torch.fx.Interpreter.run``.

    Returns:
        Any: Whatever ``torch.fx.Interpreter.run`` returns for this graph.

    Raises:
        AssertionError: If a top-level ``torch.Tensor`` in ``args`` is not a meta tensor.
    """
    # Only top-level arguments are inspected; tensors nested in containers are not checked.
    for elem in args:
        if isinstance(elem, torch.Tensor):
            assert elem.is_meta, "Input torch.Tensor are assumed to appear with device='meta'"
    # Both options are keyword-only on the parent class. Passing them positionally would
    # append them to ``*args`` as extra graph inputs and silently discard the caller's values.
    return super().run(*args, initial_env=initial_env, enable_io_processing=enable_io_processing)
@compatibility(is_backward_compatible=True)
def placeholder(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
    """
    Execute a ``placeholder`` node and attach a meta profile to the retrieved input.

    The value itself is obtained from the parent ``Interpreter.placeholder``.

    Args:
        target (Target): The call target for this node.
        args (Tuple): Tuple of positional args for this invocation
        kwargs (Dict): Dict of keyword arguments for this invocation

    Returns:
        result (Any): The argument value that was retrieved
        profile (MetaProfile): The meta profile of this node
    """
    value = super().placeholder(target, args, kwargs)
    # An input only contributes activation memory; the remaining profile fields are zero.
    activation = calculate_activation_size(value)
    return value, MetaProfile(0, activation, 0, 0)
@compatibility(is_backward_compatible=True)
def call_function(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
    """
    Execute a ``call_function`` node through the profiling wrapper of its target.

    Args:
        target (Target): The call target for this node; must not be a string.
        args (Tuple): Tuple of positional args for this invocation
        kwargs (Dict): Dict of keyword arguments for this invocation

    Returns:
        Whatever the wrapper produced by ``profile_function(target)`` returns.
        ``run_node`` unpacks it as a ``(result, profile)`` pair.
    """
    assert not isinstance(target, str)
    profiled = profile_function(target)
    return profiled(*args, **kwargs)
@compatibility(is_backward_compatible=True)
def call_module(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
    """
    Execute a ``call_module`` node through the profiling wrapper of its submodule.

    Args:
        target (Target): Qualified name of the submodule to call; must be a string.
        args (Tuple): Tuple of positional args for this invocation
        kwargs (Dict): Dict of keyword arguments for this invocation

    Returns:
        Whatever the wrapper produced by ``profile_module(submod)`` returns.
        ``run_node`` unpacks it as a ``(result, profile)`` pair.
    """
    assert isinstance(target, str)
    # Resolve the submodule by name, then invoke its profiled counterpart.
    module = self.fetch_attr(target)
    profiled = profile_module(module)
    return profiled(*args, **kwargs)
See + `Node `__ for + details on semantics + args (Tuple): Tuple of positional args for this invocation + kwargs (Dict): Dict of keyword arguments for this invocation + + Return: + Any: The return value referenced by the output node + """ + return args[0], MetaProfile(0, 0, 0, 0) + def propagate(self, *args): """ Run `module` via interpretation and return the result and diff --git a/colossalai/fx/profiler/__init__.py b/colossalai/fx/profiler/__init__.py new file mode 100644 index 000000000..353c600c7 --- /dev/null +++ b/colossalai/fx/profiler/__init__.py @@ -0,0 +1,4 @@ +from .registry import * +from .profiler_function import * +from .profiler_module import * +from .utils import * diff --git a/colossalai/fx/profiler/profiler_function/__init__.py b/colossalai/fx/profiler/profiler_function/__init__.py new file mode 100644 index 000000000..bf77edba8 --- /dev/null +++ b/colossalai/fx/profiler/profiler_function/__init__.py @@ -0,0 +1,8 @@ +from .activation_function import * +from .arithmetic import * +from .embedding import * +from .linear import * +from .normalization import * +from .pooling import * +from .python_ops import * +from .torch_ops import * diff --git a/colossalai/fx/profiler/profiler_function/activation_function.py b/colossalai/fx/profiler/profiler_function/activation_function.py new file mode 100644 index 000000000..0bf5d8af9 --- /dev/null +++ b/colossalai/fx/profiler/profiler_function/activation_function.py @@ -0,0 +1,29 @@ +from typing import Tuple +import torch +from ..registry import meta_profiler_function + +# TODO: different activation has different FLOPs count, currently unused. 
+_multiplier = { + torch.nn.functional.relu: 1, + torch.nn.functional.prelu: 4, + torch.nn.functional.sigmoid: 4, + torch.nn.functional.tanh: 5, + torch.nn.functional.leaky_relu: 3, + torch.nn.functional.elu: 4, + torch.nn.functional.relu6: 2, + torch.nn.functional.gelu: 9, +} + + +@meta_profiler_function.register(torch.nn.functional.leaky_relu) +@meta_profiler_function.register(torch.nn.functional.elu) +@meta_profiler_function.register(torch.nn.functional.gelu) +@meta_profiler_function.register(torch.nn.functional.relu6) +@meta_profiler_function.register(torch.nn.functional.prelu) +@meta_profiler_function.register(torch.nn.functional.relu) +@meta_profiler_function.register(torch.nn.functional.sigmoid) +@meta_profiler_function.register(torch.nn.functional.tanh) +def torch_nn_func_non_linear_act(input: torch.Tensor, inplace: bool = False) -> Tuple[int, int]: + flops = input.numel() + macs = 0 + return flops, macs diff --git a/colossalai/fx/profiler/profiler_function/arithmetic.py b/colossalai/fx/profiler/profiler_function/arithmetic.py new file mode 100644 index 000000000..b52c56557 --- /dev/null +++ b/colossalai/fx/profiler/profiler_function/arithmetic.py @@ -0,0 +1,83 @@ +from typing import Any, Optional, Tuple, Union +import torch +from ..registry import meta_profiler_function + + +def _prod(dims): + p = 1 + for v in dims: + p *= v + return p + + +def _elementwise_flops_compute(input, other): + # copied from https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/profiling/flops_profiler/profiler.py#L763 + if not torch.is_tensor(input): + if torch.is_tensor(other): + return _prod(other.shape), 0 + else: + return 1, 0 + elif not torch.is_tensor(other): + return _prod(input.shape), 0 + else: + dim_input = len(input.shape) + dim_other = len(other.shape) + max_dim = max(dim_input, dim_other) + + final_shape = [] + for i in range(max_dim): + in_i = input.shape[i] if i < dim_input else 1 + ot_i = other.shape[i] if i < dim_other else 1 + if in_i > ot_i: + 
final_shape.append(in_i) + else: + final_shape.append(ot_i) + flops = _prod(final_shape) + return flops, 0 + + +@meta_profiler_function.register(torch.add) +@meta_profiler_function.register('add') # for built-in op + +@meta_profiler_function.register('iadd') # for built-in op += +@meta_profiler_function.register('sub') # for built-in op - +@meta_profiler_function.register('isub') # for built-in op -= +@meta_profiler_function.register('mul') # for built-in op * +@meta_profiler_function.register('imul') # for built-in op *= +def torch_add_like_ops(input: Any, other: Any, *, out: Optional[torch.Tensor] = None) -> Tuple[int, int]: + return _elementwise_flops_compute(input, other) + + +@meta_profiler_function.register(torch.abs) +def torch_elementwise_op(input: torch.Tensor, *, out: Optional[torch.Tensor] = None) -> Tuple[int, int]: + flops = input.numel() + macs = 0 + return flops, macs + + +@meta_profiler_function.register(torch.matmul) +@meta_profiler_function.register('matmul') # for built-in op @ +@meta_profiler_function.register(torch.Tensor.matmul) +def torch_matmul(input: torch.Tensor, other: torch.Tensor, *, out: Optional[torch.Tensor] = None) -> Tuple[int, int]: + macs = _prod(input.shape) * other.shape[-1] + flops = 2 * macs + return flops, macs + + +@meta_profiler_function.register(torch.bmm) +def torch_bmm(input: torch.Tensor, other: torch.Tensor, *, out: Optional[torch.Tensor] = None) -> Tuple[int, int]: + macs = _prod(input.shape) * other.shape[-1] + flops = 2 * macs + return flops, macs + + +@meta_profiler_function.register(torch.var_mean) +def torch_var_mean(input: torch.Tensor, + dim: Union[int, Tuple[int, ...]], + unbiased: Optional[bool] = True, + keepdim: Optional[bool] = False, + *, + out: Optional[torch.Tensor] = None) -> Tuple[int, int]: + assert out is None, 'saving to out is not supported yet' + flops = input.numel() * 3 + macs = 0 + return flops, macs diff --git a/colossalai/fx/profiler/profiler_function/embedding.py 
@meta_profiler_function.register(torch.nn.functional.linear)
def torch_nn_linear(input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor = None) -> Tuple[int, int]:
    """
    Estimate FLOPs and MACs of ``torch.nn.functional.linear``.

    MACs are the number of input elements times ``weight.shape[0]``; FLOPs are twice
    the MACs, plus the number of bias elements when a bias is given.

    Returns:
        Tuple[int, int]: ``(flops, macs)``.
    """
    macs = input.numel() * weight.shape[0]
    bias_flops = 0 if bias is None else bias.numel()
    return 2 * macs + bias_flops, macs
@meta_profiler_function.register(torch.nn.functional.batch_norm)
def torch_nn_func_batchnorm(
    input: torch.Tensor,
    running_mean: Optional[torch.Tensor],
    running_var: Optional[torch.Tensor],
    weight: Optional[torch.Tensor] = None,
    bias: Optional[torch.Tensor] = None,
    training: bool = False,
    momentum: float = 0.1,
    eps: float = 1e-5,
) -> Tuple[int, int]:
    """
    Estimate FLOPs and MACs of ``torch.nn.functional.batch_norm``.

    In training mode the batch statistics have to be computed before normalizing, which
    is estimated as 4 FLOPs per element (5 with an affine transform). In evaluation mode
    the running statistics are reused, so only 1 FLOP per element (2 with an affine
    transform) is counted.

    Returns:
        Tuple[int, int]: ``(flops, macs)``; MACs are always 0.
    """
    has_affine = weight is not None
    if training:
        # statistics computation + normalization (+ affine)
        flops = input.numel() * (5 if has_affine else 4)
    else:
        # normalization with precomputed statistics (+ affine)
        flops = input.numel() * (2 if has_affine else 1)
    macs = 0
    return flops, macs
@meta_profiler_function.register(torch.nn.functional.avg_pool1d)
@meta_profiler_function.register(torch.nn.functional.avg_pool2d)
@meta_profiler_function.register(torch.nn.functional.avg_pool3d)
@meta_profiler_function.register(torch.nn.functional.max_pool1d)
@meta_profiler_function.register(torch.nn.functional.max_pool2d)
@meta_profiler_function.register(torch.nn.functional.max_pool3d)
@meta_profiler_function.register(torch.nn.functional.adaptive_avg_pool1d)
@meta_profiler_function.register(torch.nn.functional.adaptive_avg_pool2d)
@meta_profiler_function.register(torch.nn.functional.adaptive_avg_pool3d)
@meta_profiler_function.register(torch.nn.functional.adaptive_max_pool1d)
@meta_profiler_function.register(torch.nn.functional.adaptive_max_pool2d)
@meta_profiler_function.register(torch.nn.functional.adaptive_max_pool3d)
def torch_nn_func_pooling(input: torch.Tensor, *args, **kwargs) -> Tuple[int, int]:
    """
    Estimate FLOPs and MACs of the functional pooling operators.

    Every pooling variant is treated as visiting each input element exactly once
    (https://stackoverflow.com/a/67301217), so FLOPs equal the number of input elements.

    Returns:
        Tuple[int, int]: ``(flops, macs)``; MACs are always 0.
    """
    return input.numel(), 0
@meta_profiler_function.register(torch.max)
def torch_max(input: torch.Tensor,
              dim: int = None,
              keepdim: bool = False,
              *,
              out: Optional[torch.Tensor] = None) -> Tuple[int, int]:
    """
    Estimate FLOPs and MACs of ``torch.max``.

    Without ``dim`` one FLOP is counted per input element. With ``dim`` one FLOP is
    counted per element of the input shape with that dimension removed.

    Returns:
        Tuple[int, int]: ``(flops, macs)``; MACs are always 0.

    Raises:
        AssertionError: If ``out`` is given.
    """
    assert out is None, 'assigning value to out is not supported yet'
    macs = 0
    if dim is None:
        flops = input.numel()
    else:
        reduced_shape = list(input.shape)
        reduced_shape.pop(int(dim))
        # ``flops`` must be a plain int here, so that both branches return ``(int, int)``.
        flops = _prod(reduced_shape)
    return flops, macs
@meta_profiler_module.register(torch.nn.Conv2d)
def torch_nn_conv2d(self: torch.nn.Conv2d, input: torch.Tensor) -> Tuple[int, int]:
    """
    Estimate FLOPs and MACs of a ``torch.nn.Conv2d`` forward pass.

    The output shape follows the formula stated at
    https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html.
    String paddings are supported: ``'valid'`` means no padding and ``'same'`` keeps
    the spatial size of the input.

    Returns:
        Tuple[int, int]: ``(flops, macs)``; FLOPs are twice the MACs, plus one per
        output element when the module has a bias.
    """
    c_in, h_in, w_in = input.shape[-3:]
    c_out = self.out_channels
    if self.padding == 'same':
        h_out, w_out = h_in, w_in
    else:
        # ``padding`` is either the string 'valid' or a per-dimension tuple of ints
        padding = (0, 0) if self.padding == 'valid' else self.padding
        h_out = math.floor((h_in + 2 * padding[0] - self.dilation[0] *
                            (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1)
        w_out = math.floor((w_in + 2 * padding[1] - self.dilation[1] *
                            (self.kernel_size[1] - 1) - 1) / self.stride[1] + 1)
    result_shape = input.shape[:-3] + (
        c_out,
        h_out,
        w_out,
    )
    macs_per_elem = _prod(self.kernel_size) * c_in // self.groups
    num_elem = _prod(result_shape)
    macs = macs_per_elem * num_elem
    flops = 2 * macs
    if self.bias is not None:
        flops += num_elem
    return flops, macs
@meta_profiler_module.register(torch.nn.ConvTranspose2d)
def torch_nn_convtranspose2d(self: torch.nn.ConvTranspose2d, input: torch.Tensor) -> Tuple[int, int]:
    """
    Estimate FLOPs and MACs of a ``torch.nn.ConvTranspose2d`` forward pass.

    The output shape follows the formula stated at
    https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html.

    Returns:
        Tuple[int, int]: ``(flops, macs)``; FLOPs are twice the MACs, plus one per
        output element when the module has a bias.
    """
    c_in, h_in, w_in = input.shape[-3:]
    c_out = self.out_channels
    h_out = math.floor((h_in - 1) * self.stride[0] - 2 * self.padding[0] + self.dilation[0] *
                       (self.kernel_size[0] - 1) + self.output_padding[0] + 1)
    w_out = math.floor((w_in - 1) * self.stride[1] - 2 * self.padding[1] + self.dilation[1] *
                       (self.kernel_size[1] - 1) + self.output_padding[1] + 1)
    result_shape = input.shape[:-3] + (
        c_out,
        h_out,
        w_out,
    )
    # A transposed convolution is counted per *input* element (see
    # https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/profiling/flops_profiler/profiler.py#L604).
    # Each input element is multiplied with one kernel per output channel of its group, and
    # the input channels are already part of ``num_elem``, hence ``c_out`` rather than ``c_in``.
    macs_per_elem = _prod(self.kernel_size) * c_out // self.groups
    num_elem = _prod(input.shape)
    macs = macs_per_elem * num_elem
    flops = 2 * macs
    if self.bias is not None:
        flops += _prod(result_shape)
    return flops, macs
input.shape[-4:] + c_out = self.out_channels + d_out = math.floor((d_in - 1) * self.stride[0] - 2 * self.padding[0] + self.dilation[0] * + (self.kernel_size[0] - 1) + self.output_padding[0] + 1) + h_out = math.floor((h_in - 1) * self.stride[1] - 2 * self.padding[1] + self.dilation[1] * + (self.kernel_size[1] - 1) + self.output_padding[1] + 1) + w_out = math.floor((w_in - 1) * self.stride[2] - 2 * self.padding[2] + self.dilation[2] * + (self.kernel_size[2] - 1) + self.output_padding[2] + 1) + result_shape = input.shape[:-4] + ( + c_out, + d_out, + h_out, + w_out, + ) + macs_per_elem = _prod(self.kernel_size) * c_in // self.groups + num_elem = _prod(input.shape) + macs = macs_per_elem * num_elem + flops = 2 * macs + if self.bias is not None: + flops += _prod(result_shape) + return flops, macs diff --git a/colossalai/fx/profiler/profiler_module/embedding.py b/colossalai/fx/profiler/profiler_module/embedding.py new file mode 100644 index 000000000..dca6f9453 --- /dev/null +++ b/colossalai/fx/profiler/profiler_module/embedding.py @@ -0,0 +1,11 @@ +from typing import Tuple +import torch +from ..registry import meta_profiler_module + + +@meta_profiler_module.register(torch.nn.Embedding) +def torch_nn_embedding(self: torch.nn.Embedding, input: torch.Tensor) -> Tuple[int, int]: + # nn.Embedding is a dictionary lookup, so technically it has 0 FLOPs. 
@meta_profiler_module.register(torch.nn.InstanceNorm1d)
@meta_profiler_module.register(torch.nn.InstanceNorm2d)
@meta_profiler_module.register(torch.nn.InstanceNorm3d)
@meta_profiler_module.register(torch.nn.LayerNorm)
@meta_profiler_module.register(torch.nn.GroupNorm)
@meta_profiler_module.register(torch.nn.BatchNorm1d)
@meta_profiler_module.register(torch.nn.BatchNorm2d)
@meta_profiler_module.register(torch.nn.BatchNorm3d)
def torch_nn_normalize(self: Union[torch.nn.LayerNorm, torch.nn.GroupNorm, torch.nn.BatchNorm1d, torch.nn.BatchNorm2d,
                                   torch.nn.BatchNorm3d], input: torch.Tensor) -> Tuple[int, int]:
    """
    Estimate FLOPs and MACs of a normalization module.

    Adopted from
    https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/profiling/flops_profiler/profiler.py#L615.
    Computing the statistics and normalizing is estimated as 4 FLOPs per element (5 with
    an affine transform). When precomputed running statistics are used instead, only
    1 FLOP per element (2 with an affine transform) is counted.

    Returns:
        Tuple[int, int]: ``(flops, macs)``; MACs are always 0.
    """
    has_affine = self.weight is not None
    # Running statistics replace the computed ones only in evaluation mode, and only for
    # modules that track them. Modules without ``track_running_stats`` always compute them.
    uses_running_stats = not self.training and getattr(self, 'track_running_stats', False)
    if uses_running_stats:
        flops = input.numel() * (2 if has_affine else 1)
    else:
        flops = input.numel() * (5 if has_affine else 4)
    macs = 0
    return flops, macs
from typing import Optional, Tuple

import torch


# --- colossalai/fx/profiler/registry.py ---
class ProfilerRegistry:
    """A named mapping from a source object (function, name or module class) to its profiler."""

    def __init__(self, name):
        self.name = name
        self.store = {}

    def register(self, source):
        """Return a decorator that stores the decorated function under ``source``.

        The decorated function is returned unchanged, so decorators can be stacked
        to register one function under several sources. Registering an existing
        ``source`` again replaces the previous entry.
        """

        def wrapper(func):
            self.store[source] = func
            return func

        return wrapper

    def get(self, source):
        """Return the function registered under ``source``.

        Raises:
            KeyError: if ``source`` was never registered.
        """
        # Raise explicitly instead of asserting: an assert is stripped under
        # ``python -O`` and gives no hint about which registry was queried.
        try:
            return self.store[source]
        except KeyError:
            raise KeyError(f"{source} is not registered in {self.name}") from None

    def has(self, source):
        """Return True if something is registered under ``source``."""
        return source in self.store


meta_profiler_function = ProfilerRegistry(name='patched_functions_for_meta_profile')
meta_profiler_module = ProfilerRegistry(name='patched_modules_for_meta_profile')


# --- colossalai/fx/profiler/profiler_module/rnn.py ---
# (rnn.py obtains ``meta_profiler_module`` from the registry module defined above.)
# TODO: calculate rnn FLOPs
@meta_profiler_module.register(torch.nn.GRU)
@meta_profiler_module.register(torch.nn.RNN)
def torch_nn_rnn(self: torch.nn.Module, input: torch.Tensor, hx: Optional[torch.Tensor] = None) -> Tuple[int, int]:
    """Placeholder profiler for ``nn.RNN`` / ``nn.GRU``.

    ``hx`` is optional so that a call passing only ``input`` reaches the
    ``NotImplementedError`` below instead of failing on a missing argument.

    Raises:
        NotImplementedError: always; RNN FLOPs are not calculated yet.
    """
    raise NotImplementedError
# --- colossalai/fx/profiler/utils.py ---
__all__ = [
    'MetaProfile', 'profile_function', 'profile_module', 'profile_method', 'calculate_activation_size',
    'calculate_param_size'
]

# TODO fill out the inplace ops
INPLACE_OPS = [
    add,
    sub,
    mul,
    floordiv,
    neg,
    pos,
    getitem,
    setitem,
    torch.Tensor.cpu,
]

# TODO check that call_methods are indeed inplace
INPLACE_METHOD = [
    'transpose',
    'permute',
]


@compatibility(is_backward_compatible=True)
class MetaProfile(NamedTuple):
    """Profiling record for one node within a torch.fx GraphModule."""

    # The profilers in this module fill ``param`` and ``activation`` with byte
    # counts and ``flops`` / ``macs`` with operation counts.
    param: int
    activation: int
    flops: int
    macs: int


def calculate_activation_size(activation: Any) -> int:
    """Calculate the activation size of a node, in bytes.

    Tensors contribute ``numel * element_size``. Dicts contribute the size of
    their values and other iterables the size of their elements, recursively.
    Any other leaf (``None``, numbers, strings, ...) contributes 0.
    """
    if isinstance(activation, torch.Tensor):
        return activation.numel() * activation.element_size()
    if isinstance(activation, dict):
        return sum(calculate_activation_size(value) for value in activation.values())
    if isinstance(activation, (str, bytes)):
        # Iterating a str yields str again, which would recurse without end.
        return 0
    try:
        elements = iter(activation)
    except TypeError:
        # Non-iterable leaf such as None or an int: it holds no tensor memory.
        return 0
    return sum(calculate_activation_size(element) for element in elements)


def calculate_param_size(mod: torch.nn.Module) -> int:
    """Calculate the total size in bytes of all parameters of ``mod``."""
    return sum(param.numel() * param.element_size() for param in mod.parameters())


def profile_function(target: 'Target') -> Callable:
    """
    Wrap a `call_function` node or `torch.nn.functional` in order to
    record the memory cost and FLOPs of the execution.

    Warnings:
        You may only use tensors with `device=meta` for this wrapped function.
        Only original `torch.nn.functional` are available.

    Usage:
        input = torch.rand(100, 100, 100, 100, device='meta')
        func = torch.nn.functional.relu
        output, profile = profile_function(func)(input, inplace=False)
        print(f"Profiling function {func},")
        print(f"Param size: {profile.param / 1024**2:.3f} MB, Activation size: {profile.activation / 1024**2:.3f} MB, {profile.flops} FLOPs, {profile.macs} MACs")
    """
    # Fetch the patched function, looked up by object first and then by name;
    # fall back to the target itself when no patch is registered.
    if meta_patched_function.has(target):
        func = meta_patched_function.get(target)
    elif meta_patched_function.has(target.__name__):
        func = meta_patched_function.get(target.__name__)
    else:
        func = target

    def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
        assert meta_profiler_function.has(target) or meta_profiler_function.has(
            target.__name__), f"Colossal-AI hasn't supported profiling for {target}, you might manually patch it."

        # call_function has no parameters
        param_size = 0
        activation_size = 0
        result = func(*args, **kwargs)
        # The result is only counted when the op is neither listed in
        # INPLACE_OPS nor called with ``inplace=True``.
        if target not in INPLACE_OPS and not kwargs.get('inplace', False):
            activation_size += calculate_activation_size(result)
        if meta_profiler_function.has(target):
            profiler = meta_profiler_function.get(target)
        else:
            profiler = meta_profiler_function.get(target.__name__)
        flops, macs = profiler(*args, **kwargs)
        return result, MetaProfile(param_size, activation_size, flops, macs)

    f.__name__ = target.__name__
    return f


def profile_method(target: 'Target') -> Callable:
    """
    Wrap a `call_method` node
    record the memory cost and FLOPs of the execution.

    Warnings:
        This is not fully implemented and you may follow the error message to debug.
    """

    def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
        # args[0] is the `self` object for this method call
        self_obj, *args_tail = args

        # Execute the method and return the result
        assert isinstance(target, str), f'{target} instance is not str.'
        result = getattr(self_obj, target)(*args_tail, **kwargs)
        assert target in INPLACE_METHOD, f'Please check {target} is an inplace method. If so, add target to INPLACE_METHOD={INPLACE_METHOD}.'

        # call_method has no parameters and are MOSTLY(?) inplace, and has no FLOPs or MACs.
        param_size = 0
        activation_size = 0
        flops = 0
        macs = 0
        return result, MetaProfile(param_size, activation_size, flops, macs)

    return f


def profile_module(module: torch.nn.Module) -> Callable:
    """
    Wrap a `call_module` node or `torch.nn` in order to
    record the memory cost and FLOPs of the execution.

    Warnings:
        You may only use tensors with `device=meta` for this wrapped function.
        Only original `torch.nn` are available.

    Usage:
        input = torch.rand(4, 3, 224, 224, device='meta')
        mod = torch.nn.Conv2d(3, 128, 3)
        output, profile = profile_module(mod)(input)
        print(f"Profiling function {mod},")
        print(f"Param size: {profile.param / 1024**2:.3f} MB, Activation size: {profile.activation / 1024**2:.3f} MB, {profile.flops} FLOPs, {profile.macs} MACs")
    """
    # Fetch the patched forward for this module type, bound to the module;
    # fall back to the module's own forward when no patch is registered.
    if meta_patched_module.has(type(module)):
        func = partial(meta_patched_module.get(type(module)), module)
    else:
        func = module.forward

    def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
        assert meta_profiler_module.has(
            type(module)), f"Colossal-AI hasn't supported profiling for {module}, you might manually patch it."
        param_size = calculate_param_size(module)
        activation_size = 0
        result = func(*args, **kwargs)
        # Modules with a truthy ``inplace`` attribute are not charged for their output.
        if not getattr(module, 'inplace', False):
            activation_size += calculate_activation_size(result)
        profiler = meta_profiler_module.get(type(module))
        flops, macs = profiler(module, *args, **kwargs)
        return result, MetaProfile(param_size, activation_size, flops, macs)

    f.__name__ = module.__class__.__name__
    return f
b/tests/test_fx/test_meta_info_prop.py index 1da4f6b3b..ae827bf4f 100644 --- a/tests/test_fx/test_meta_info_prop.py +++ b/tests/test_fx/test_meta_info_prop.py @@ -20,17 +20,20 @@ def meta_check(meta_info_spec: TensorMetadata, orig_tensor: torch.Tensor): def test_meta_info_prop(): model = torch.nn.Linear(DIM_IN, DIM_OUT) - input_sample = torch.rand(BATCH_SIZE, DIM_IN) + input_sample = torch.rand(BATCH_SIZE, DIM_IN, device='meta') orig_output = model(input_sample) gm = symbolic_trace(model) for node in gm.graph.nodes: assert not hasattr(node, 'node_size'), 'The attribute Node.node_size should not exist before MetaInfoProp procedure' assert not hasattr(node, - 'param_size'), 'The attribute Node.param_size should not exist before MetaInfoProp procedure' + '__param__'), 'The attribute Node.__param__ should not exist before MetaInfoProp procedure' assert not hasattr( - node, - 'activation_size'), 'The attribute Node.activation_size should not exist before MetaInfoProp procedure' + node, '__activation__'), 'The attribute Node.__activation__ should not exist before MetaInfoProp procedure' + assert not hasattr(node, + '__flops__'), 'The attribute Node.__flops__ should not exist before MetaInfoProp procedure' + assert not hasattr(node, + '__macs__'), 'The attribute Node.__macs__ should not exist before MetaInfoProp procedure' MetaInfoProp(gm).run(input_sample) for node in gm.graph.nodes: if node.op == 'placeholder': @@ -38,9 +41,11 @@ def test_meta_info_prop(): if node.op == 'output': meta_check(node.meta['tensor_meta'], orig_output) assert hasattr(node, 'node_size'), 'The attribute Node.node_size should exist after MetaInfoProp procedure' - assert hasattr(node, 'param_size'), 'The attribute Node.param_size should exist after MetaInfoProp procedure' - assert hasattr( - node, 'activation_size'), 'The attribute Node.activation_size should exist after MetaInfoProp procedure' + assert hasattr(node, '__param__'), 'The attribute Node.__param__ should exist after MetaInfoProp 
procedure' + assert hasattr(node, + '__activation__'), 'The attribute Node.__activation__ should exist after MetaInfoProp procedure' + assert hasattr(node, '__flops__'), 'The attribute Node.__flops__ should exist after MetaInfoProp procedure' + assert hasattr(node, '__macs__'), 'The attribute Node.__macs__ should exist after MetaInfoProp procedure' if __name__ == '__main__':