[fx] provide a stable but not accurate enough version of profiler. (#1547)

* [fx] compute memory stat and flop count for MetaInfoProp.

* [fx] modify node attribute.

* [fx] modify ckpt_chen.

* [fx] fix compatibility.

* [fx] fix import error.

* [fx] skip test for MetaInfoProp.

* [fx] skip test for MetaInfoProp.

* [fx] skip test for MetaInfoProp.

* [fx] skip test for MetaInfoProp.

* [fx] skip if torch 1.11.0.

* [fx] recover MetaInfoProp support for PyTorch 1.11.

* [fx] provide a stable but not accurate enough version of profiler.

* [fx] provide a stable but not accurate enough version of profiler.

* [fx] fix compatibility in tests.

* [fx] fix compatibility in tests.

* [fx] fix compatibility in tests.

* [fx] fix compatibility in tests.

* [fx] fix compatibility in tests.

* [fx] fix compatibility in tests.

* [fx] fix compatibility in tests.

* [fx] fix compatibility in tests.

* [fx] fix compatibility in tests.

* [fx] fix compatibility in tests.

* [fx] fix import error.
pull/1583/head
Super Daniel 2022-09-07 11:21:04 +08:00 committed by GitHub
parent 7d49e7b2db
commit 4f59693207
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
38 changed files with 776 additions and 263 deletions

View File

@ -1,7 +1,9 @@
try:
from ._meta_registrations import *
from . import _meta_registrations
META_COMPATIBILITY = True
except:
import torch
META_COMPATIBILITY = False
print(f'_meta_registrations seems to be incompatible with PyTorch {torch.__version__}.')
from .initialize import (initialize, launch, launch_from_openmpi, launch_from_slurm, launch_from_torch,
get_default_parser)

View File

@ -181,6 +181,12 @@ def meta_hardswish_backward(grad_out: torch.Tensor, input: torch.Tensor):
return grad_in
@register_meta(aten.hardtanh_backward.default)
def meta_hardtanh_backward(grad_out: torch.Tensor, input: torch.Tensor, min_val: int, max_val: int):
grad_in = torch.empty_like(input)
return grad_in
@register_meta(aten.roll.default)
def meta_roll(input: torch.Tensor, shifts, dims):
return torch.empty_like(input)
@ -321,3 +327,17 @@ def meta_index_Tensor(self, indices):
else:
replacement_shape = list(index.shape)
return self.new_empty(before_shape + replacement_shape + after_shape)
@register_meta(aten.embedding_dense_backward.default)
def meta_embedding_dense_backward(grad_output: torch.Tensor, indices: torch.Tensor, num_weights, padding_idx,
scale_grad_by_freq):
return torch.empty((num_weights, grad_output.size(-1)),
dtype=grad_output.dtype,
device=grad_output.device,
layout=grad_output.layout)
@register_meta(aten.where.self)
def meta_where_self(condition: torch.Tensor, self: torch.Tensor, other: torch.Tensor):
return torch.empty_like(condition)

View File

@ -73,10 +73,10 @@ def chen_greedy(gm: GraphModule) -> GraphModule:
y = 0
prev_idx = 2
for (idx, n) in enumerate(gm.graph.nodes):
temp += getattr(n, '__activation__')
temp += getattr(n, 'fwd_out')
y = max(y, temp)
if temp > b and n in ckpt_nodes:
x += getattr(n, '__activation__')
x += getattr(n, 'fwd_out')
temp = 0
ckpt_intv.append((prev_idx, idx + 1))
prev_idx = idx + 1

View File

@ -1,13 +1,10 @@
from operator import add, getitem
import torch
import torch.fx
from torch.fx.node import Node, Argument, Target
from torch.utils._pytree import tree_map
from typing import Any, Tuple, NamedTuple, Optional, Dict
from functools import reduce
from typing import Any, Tuple, NamedTuple, Dict
from torch.fx._compatibility import compatibility
from torch.fx.immutable_collections import immutable_dict, immutable_list
from colossalai.fx.profiler import MetaProfile, MetaTensor, profile_function, profile_module, calculate_activation_size, profile_method
from colossalai.fx.profiler import profile_function, profile_module, profile_method, activation_size, parameter_size
@compatibility(is_backward_compatible=True)
@ -71,14 +68,6 @@ class MetaInfoProp(torch.fx.Interpreter):
"""
@compatibility(is_backward_compatible=True)
def run(self, *args, initial_env: Optional[Dict[Node, Any]] = None, enable_io_processing: bool = True) -> Any:
"""
Add additional check for initial args to ensure all the tensor appears with `device='meta'`
"""
args = tree_map(lambda elem: MetaTensor(elem.to('meta')) if isinstance(elem, torch.Tensor) else elem, args)
return super().run(*args, initial_env, enable_io_processing)
@compatibility(is_backward_compatible=True)
def run_node(self, n: Node) -> Any:
"""
@ -93,8 +82,7 @@ class MetaInfoProp(torch.fx.Interpreter):
Returns:
Any: The result of executing ``n``
"""
result, profile = super().run_node(n)
profile: MetaProfile
result, flop_count, mem_stat = super().run_node(n)
def extract_tensor_meta(obj):
if isinstance(obj, torch.Tensor):
@ -106,12 +94,17 @@ class MetaInfoProp(torch.fx.Interpreter):
n.meta['tensor_meta'] = meta
# TODO: the attribute node_size should be removed in the future
setattr(n, 'node_size', profile.param + profile.activation)
setattr(n, '__param__', profile.param)
setattr(n, '__activation__', profile.activation)
setattr(n, '__flops__', profile.flops)
setattr(n, '__macs__', profile.macs)
setattr(n, 'node_size', mem_stat[1])
setattr(n, 'fwd_flop', flop_count[0])
setattr(n, 'bwd_flop', flop_count[1])
setattr(n, 'fwd_tmp', mem_stat[0])
setattr(n, 'fwd_out', mem_stat[1])
setattr(n, 'bwd_tmp', mem_stat[2])
setattr(n, 'bwd_out', mem_stat[3])
n.meta['type'] = type(result)
for param in self.module.parameters():
param.grad = None
return result
# Main Node running APIs
@ -132,11 +125,12 @@ class MetaInfoProp(torch.fx.Interpreter):
Returns:
result (Any): The argument value that was retrieved
profile (MetaProfile): The meta profile of this node
flop_count (Tuple): The flop count for (fwd_flop, bwd_flop).
mem_stat (Tuple): The memory statistics for (fwd_tmp, fwd_out, bwd_tmp, bwd_out)
"""
result = super().placeholder(target, args, kwargs)
# A placeholder node only has activation
return result, MetaProfile(0, calculate_activation_size(result), 0, 0)
return result, (0, 0), (0, activation_size(result), 0, 0)
@compatibility(is_backward_compatible=True)
def get_attr(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
@ -153,10 +147,10 @@ class MetaInfoProp(torch.fx.Interpreter):
Return:
result (Any): The argument value that was retrieved
profile (MetaProfile): The meta profile of this node
flop_count (Tuple): The flop count for (fwd_flop, bwd_flop).
mem_stat (Tuple): The memory statistics for (fwd_tmp, fwd_out, bwd_tmp, bwd_out)
"""
# A get_attr node never has parameters, activations, FLOPs, or MACs
return super().get_attr(target, args, kwargs), MetaProfile(0, 0, 0, 0)
return super().get_attr(target, args, kwargs), (0, 0), (0, 0, 0, 0)
@compatibility(is_backward_compatible=True)
def call_function(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
@ -172,7 +166,8 @@ class MetaInfoProp(torch.fx.Interpreter):
Return
result (Any): The argument value that was retrieved
profile (MetaProfile): The meta profile of this node
flop_count (Tuple): The flop count for (fwd_flop, bwd_flop).
mem_stat (Tuple): The memory statistics for (fwd_tmp, fwd_out, bwd_tmp, bwd_out)
"""
assert not isinstance(target, str)
return profile_function(target)(*args, **kwargs)
@ -191,7 +186,8 @@ class MetaInfoProp(torch.fx.Interpreter):
Return
result (Any): The argument value that was retrieved
profile (MetaProfile): The meta profile of this node
flop_count (Tuple): The flop count for (fwd_flop, bwd_flop).
mem_stat (Tuple): The memory statistics for (fwd_tmp, fwd_out, bwd_tmp, bwd_out)
"""
return profile_method(target)(*args, **kwargs)
@ -209,7 +205,8 @@ class MetaInfoProp(torch.fx.Interpreter):
Return
result (Any): The argument value that was retrieved
profile (MetaProfile): The meta profile of this node
flop_count (Tuple): The flop count for (fwd_flop, bwd_flop).
mem_stat (Tuple): The memory statistics for (fwd_tmp, fwd_out, bwd_tmp, bwd_out)
"""
# Retrieve executed args and kwargs values from the environment
# Execute the method and return the result
@ -231,9 +228,11 @@ class MetaInfoProp(torch.fx.Interpreter):
kwargs (Dict): Dict of keyword arguments for this invocation
Return:
Any: The return value referenced by the output node
result (Any): The argument value that was retrieved
flop_count (Tuple): The flop count for (fwd_flop, bwd_flop).
mem_stat (Tuple): The memory statistics for (fwd_tmp, fwd_out, bwd_tmp, bwd_out)
"""
return args[0], MetaProfile(0, 0, 0, 0)
return args[0], (0, 0), (0, 0, 0, 0)
def propagate(self, *args):
"""

View File

@ -1,5 +1,9 @@
from .meta_tensor import MetaTensor
from .registry import meta_profiler_function, meta_profiler_module
from .profiler_function import *
from .profiler_module import *
from .profiler import *
from ... import META_COMPATIBILITY
if META_COMPATIBILITY:
from .opcount import flop_mapping
from .tensor import MetaTensor
from .profiler import profile_function, profile_method, profile_module, _profile
else:
from .experimental import meta_profiler_function, meta_profiler_module, profile_function, profile_method, profile_module
from .memory import parameter_size, activation_size

View File

@ -0,0 +1,4 @@
from .registry import meta_profiler_function, meta_profiler_module
from .profiler_function import *
from .profiler_module import *
from .profiler import profile_function, profile_method, profile_module

View File

@ -0,0 +1,125 @@
from typing import Callable, Any, Dict, Tuple
import torch
from torch.fx.node import Argument, Target
from . import meta_profiler_function, meta_profiler_module
from ..memory import activation_size, INPLACE_METHOD, NON_INPLACE_METHOD, INPLACE_OPS
__all__ = ['profile_function', 'profile_module', 'profile_method']
CALL_FUNCTION_MSG = \
"""
Colossal-AI hasn't supported profiling for {}, you might manually patch it with the following code.\n
from colossalai.fx.profiler.experimental import meta_profiler_function
@meta_profiler_function.register(YOUR_FUNCTION)
def profile_YOUR_FUNCTION(input: torch.Tensor, *args) -> Tuple[int, int]:
flops = ...
macs = ...
return flops, macs
"""
CALL_METHOD_MSG = 'Please check if {} is an inplace method. If so, add target to INPLACE_METHOD={}. Otherwise, add target to NON_INPLACE_METHOD={}'
CALL_MODULE_MSG = \
"""
Colossal-AI hasn't supported profiling for {}, you might manually patch it with the following code.\n
from colossalai.fx.profiler.experimental import meta_profiler_module
@meta_profiler_module.register(YOUR_MODULE)
def profile_YOUR_MODULE(self: torch.nn.Module, input: torch.Tensor) -> Tuple[int, int]:
flops = ...
macs = ...
return flops, macs
"""
def profile_function(target: 'Target') -> Callable:
"""
Wrap a `call_function` node or `torch.nn.functional` in order to
record the memory cost and FLOPs of the execution.
Unfortunately, backward memory cost and FLOPs are estimated results.
Warnings:
You may only use tensors with `device=meta` for this wrapped function.
Only original `torch.nn.functional` are available.
Examples:
>>> input = torch.rand(100, 100, 100, 100, device='meta')
>>> func = torch.nn.functional.relu
>>> output, (fwd_flop, bwd_flop), (fwd_tmp, fwd_out, bwd_tmp, bwd_out) = profile_function(func)(input, inplace=False)
"""
def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
assert meta_profiler_function.has(target) or meta_profiler_function.has(
target.__name__), CALL_FUNCTION_MSG.format(target)
fwd_tmp = 0
fwd_out = 0
out = func(*args, **kwargs)
if target not in INPLACE_OPS and not kwargs.get('inplace', False):
fwd_out = activation_size(out)
if meta_profiler_function.has(target):
profiler = meta_profiler_function.get(target)
else:
profiler = meta_profiler_function.get(target.__name__)
fwd_flop, _ = profiler(*args, **kwargs)
return out, (fwd_flop, fwd_flop * 2), (fwd_tmp, fwd_out, fwd_tmp + fwd_out, 0)
f.__name__ = target.__name__
func = target
return f
def profile_method(target: 'Target') -> Callable:
"""
Wrap a `call_method` node
record the memory cost and FLOPs of the execution.
Warnings:
This is not fully implemented and you may follow the error message to debug.
"""
def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
# args[0] is the `self` object for this method call
self_obj, *args_tail = args
# execute the method and return the result
assert isinstance(target, str), f'{target} instance is not str.'
out = getattr(self_obj, target)(*args_tail, **kwargs)
assert target in INPLACE_METHOD + NON_INPLACE_METHOD, CALL_METHOD_MSG.format(
target, INPLACE_METHOD, NON_INPLACE_METHOD)
# call_method has no parameters and are MOSTLY(?) inplace, and has no FLOPs or MACs.
fwd_tmp = 0 if target in INPLACE_METHOD else activation_size(out)
fwd_out = 0 if target not in INPLACE_METHOD else activation_size(out)
return out, (0, 0), (fwd_tmp, fwd_out, fwd_tmp + fwd_out, 0)
return f
def profile_module(module: torch.nn.Module) -> Callable:
"""
Wrap a `call_module` node or `torch.nn` in order to
record the memory cost and FLOPs of the execution.
Warnings:
You may only use tensors with `device=meta` for this wrapped function.
Only original `torch.nn` are available.
Example:
>>> input = torch.rand(4, 3, 224, 224, device='meta')
>>> mod = torch.nn.Conv2d(3, 128, 3)
>>> output, (fwd_flop, bwd_flop), (fwd_tmp, fwd_out, bwd_tmp, bwd_out) = profile_module(mod)(input)
"""
def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
assert meta_profiler_module.has(type(module)), CALL_MODULE_MSG.format(type(module))
fwd_tmp = 0
fwd_out = 0
out = func(*args, **kwargs)
if getattr(module, 'inplace', False):
fwd_out = activation_size(out)
profiler = meta_profiler_module.get(type(module))
fwd_flop, _ = profiler(module, *args, **kwargs)
return out, (fwd_flop, fwd_flop * 2), (fwd_tmp, fwd_out, fwd_tmp + fwd_out, 0)
f.__name__ = module.__class__.__name__
func = module.forward
return f

View File

@ -0,0 +1,110 @@
import torch
from typing import Union, Dict, List, Tuple
from operator import add, floordiv, getitem, mul, neg, setitem, sub, pos
from . import META_COMPATIBILITY
__all__ = ['activation_size', 'parameter_size']
if META_COMPATIBILITY:
aten = torch.ops.aten
WEIRD_OPS = [
torch.where,
]
INPLACE_ATEN = [
aten.add_.Tensor,
aten.add.Tensor,
aten.sub_.Tensor,
aten.div_.Tensor,
aten.div_.Scalar,
aten.mul_.Tensor,
aten.mul.Tensor,
aten.bernoulli_.float,
# inplace reshaping
aten.detach.default,
aten.t.default,
aten.transpose.int,
aten.view.default,
aten._unsafe_view.default,
]
__all__ += ['INPLACE_ATEN', 'WEIRD_OPS']
else:
# TODO fill out the inplace ops
INPLACE_OPS = [
add,
sub,
mul,
floordiv,
neg,
pos,
getitem,
setitem,
getattr,
torch.Tensor.cpu,
]
# TODO: list all call_methods that are inplace here
INPLACE_METHOD = [
'transpose',
'permute',
# TODO: reshape may return a copy of the data if the data is not contiguous
'reshape',
'dim',
'flatten',
'size',
'view',
'unsqueeze',
'to',
'type',
'flatten',
]
# TODO: list all call_methods that are not inplace here
NON_INPLACE_METHOD = [
'chunk',
'contiguous',
'expand',
'mean',
'split',
]
__all__ += ['INPLACE_OPS', 'INPLACE_METHOD', 'NON_INPLACE_METHOD']
def activation_size(out: Union[torch.Tensor, Dict, List, Tuple, int]) -> int:
"""Calculate activation size of a node.
Args:
activation (Union[torch.Tensor, Dict, List, Tuple, int]): The activation of a `torch.nn.Module` or `torch.nn.functional`
Returns:
int: The activation size
"""
act_size = 0
if isinstance(out, torch.Tensor):
act_size += out.numel() * torch.tensor([], dtype=out.dtype).element_size()
elif isinstance(out, dict):
value_list = [v for _, v in out.items()]
act_size += activation_size(value_list)
elif isinstance(out, tuple) or isinstance(out, list):
for element in out:
act_size += activation_size(element)
return act_size
def parameter_size(mod: torch.nn.Module) -> int:
"""Calculate param size of a node.
Args:
mod (torch.nn.Module): The target `torch.nn.Module`
Returns:
int: The param size
"""
param_size = 0
for param in mod.parameters():
param_size += param.numel() * torch.tensor([], dtype=param.dtype).element_size()
return param_size

View File

@ -0,0 +1,304 @@
# adopted from https://github.com/facebookresearch/fvcore/blob/main/fvcore/nn/jit_handles.py
# ideas from https://pastebin.com/AkvAyJBw
from functools import reduce
import operator
from typing import Callable, List, Any
from numbers import Number
import torch
aten = torch.ops.aten
def matmul_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number:
"""
Count flops for matmul.
"""
# Inputs should be a list of length 2.
# Inputs contains the shapes of two matrices.
input_shapes = [v.shape for v in inputs]
assert len(input_shapes) == 2, input_shapes
assert input_shapes[0][-1] == input_shapes[1][-2], input_shapes
flops = reduce(operator.mul, input_shapes[0]) * input_shapes[-1][-1]
return flops
def addmm_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number:
"""
Count flops for fully connected layers.
"""
# Count flop for nn.Linear
# inputs is a list of length 3.
input_shapes = [v.shape for v in inputs[1:3]]
# input_shapes[0]: [batch size, input feature dimension]
# input_shapes[1]: [batch size, output feature dimension]
assert len(input_shapes[0]) == 2, input_shapes[0]
assert len(input_shapes[1]) == 2, input_shapes[1]
batch_size, input_dim = input_shapes[0]
output_dim = input_shapes[1][1]
flops = batch_size * input_dim * output_dim
return flops
def linear_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number:
"""
Count flops for the aten::linear operator.
"""
# Inputs is a list of length 3; unlike aten::addmm, it is the first
# two elements that are relevant.
input_shapes = [v.shape for v in inputs[0:2]]
# input_shapes[0]: [dim0, dim1, ..., input_feature_dim]
# input_shapes[1]: [output_feature_dim, input_feature_dim]
assert input_shapes[0][-1] == input_shapes[1][-1]
flops = reduce(operator.mul, input_shapes[0]) * input_shapes[1][0]
return flops
def bmm_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number:
"""
Count flops for the bmm operation.
"""
# Inputs should be a list of length 2.
# Inputs contains the shapes of two tensor.
assert len(inputs) == 2, len(inputs)
input_shapes = [v.shape for v in inputs]
n, c, t = input_shapes[0]
d = input_shapes[-1][-1]
flops = n * c * t * d
return flops
def conv_flop_count(
x_shape: List[int],
w_shape: List[int],
out_shape: List[int],
transposed: bool = False,
) -> Number:
"""
Count flops for convolution. Note only multiplication is
counted. Computation for addition and bias is ignored.
Flops for a transposed convolution are calculated as
flops = (x_shape[2:] * prod(w_shape) * batch_size).
Args:
x_shape (list(int)): The input shape before convolution.
w_shape (list(int)): The filter shape.
out_shape (list(int)): The output shape after convolution.
transposed (bool): is the convolution transposed
Returns:
int: the number of flops
"""
batch_size = x_shape[0]
conv_shape = (x_shape if transposed else out_shape)[2:]
flops = batch_size * reduce(operator.mul, w_shape) * reduce(operator.mul, conv_shape)
return flops
def conv_flop_jit(inputs: List[Any], outputs: List[Any]):
"""
Count flops for convolution.
"""
x, w = inputs[:2]
x_shape, w_shape, out_shape = (x.shape, w.shape, outputs[0].shape)
transposed = inputs[6]
return conv_flop_count(x_shape, w_shape, out_shape, transposed=transposed)
def transpose_shape(shape):
return [shape[1], shape[0]] + list(shape[2:])
def conv_backward_flop_jit(inputs: List[Any], outputs: List[Any]):
grad_out_shape, x_shape, w_shape = [i.shape for i in inputs[:3]]
output_mask = inputs[-1]
fwd_transposed = inputs[7]
flop_count = 0
if output_mask[0]:
grad_input_shape = outputs[0].shape
flop_count += conv_flop_count(grad_out_shape, w_shape, grad_input_shape, not fwd_transposed)
if output_mask[1]:
grad_weight_shape = outputs[1].shape
flop_count += conv_flop_count(transpose_shape(x_shape), grad_out_shape, grad_weight_shape, fwd_transposed)
return flop_count
def norm_flop_counter(affine_arg_index: int, input_arg_index: int) -> Callable:
"""
Args:
affine_arg_index: index of the affine argument in inputs
"""
def norm_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number:
"""
Count flops for norm layers.
"""
# Inputs[0] contains the shape of the input.
input_shape = inputs[input_arg_index].shape
has_affine = inputs[affine_arg_index].shape is not None if hasattr(inputs[affine_arg_index],
'shape') else inputs[affine_arg_index]
assert 2 <= len(input_shape) <= 5, input_shape
# 5 is just a rough estimate
flop = reduce(operator.mul, input_shape) * (5 if has_affine else 4)
return flop
return norm_flop_jit
def batchnorm_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number:
training = inputs[-3]
assert isinstance(training, bool), "Signature of aten::batch_norm has changed!"
if training:
return norm_flop_counter(1, 0)(inputs, outputs) # pyre-ignore
has_affine = inputs[1].shape is not None
input_shape = reduce(operator.mul, inputs[0].shape)
return input_shape * (2 if has_affine else 1)
def elementwise_flop_counter(input_scale: float = 1, output_scale: float = 0) -> Callable:
"""
Count flops by
input_tensor.numel() * input_scale + output_tensor.numel() * output_scale
Args:
input_scale: scale of the input tensor (first argument)
output_scale: scale of the output tensor (first element in outputs)
"""
def elementwise_flop(inputs: List[Any], outputs: List[Any]) -> Number:
ret = 0
if input_scale != 0:
shape = inputs[0].shape
ret += input_scale * reduce(operator.mul, shape) if shape else 0
if output_scale != 0:
shape = outputs[0].shape
ret += output_scale * reduce(operator.mul, shape) if shape else 0
return ret
return elementwise_flop
def zero_flop_jit(*args):
"""
Count flops for zero flop layers.
"""
return 0
flop_mapping = {
# gemm
aten.mm.default: matmul_flop_jit,
aten.matmul.default: matmul_flop_jit,
aten.addmm.default: addmm_flop_jit,
aten.bmm.default: bmm_flop_jit,
# convolution
aten.convolution.default: conv_flop_jit,
aten._convolution.default: conv_flop_jit,
aten.convolution_backward.default: conv_backward_flop_jit,
# normalization
aten.native_batch_norm.default: batchnorm_flop_jit,
aten.native_batch_norm_backward.default: batchnorm_flop_jit,
aten.native_layer_norm.default: norm_flop_counter(2, 0),
aten.native_layer_norm_backward.default: norm_flop_counter(2, 0),
# pooling
aten.avg_pool1d.default: elementwise_flop_counter(1, 0),
aten.avg_pool2d.default: elementwise_flop_counter(1, 0),
aten.avg_pool2d_backward.default: elementwise_flop_counter(0, 1),
aten.avg_pool3d.default: elementwise_flop_counter(1, 0),
aten.avg_pool3d_backward.default: elementwise_flop_counter(0, 1),
aten.max_pool1d.default: elementwise_flop_counter(1, 0),
aten.max_pool2d.default: elementwise_flop_counter(1, 0),
aten.max_pool3d.default: elementwise_flop_counter(1, 0),
aten.max_pool1d_with_indices.default: elementwise_flop_counter(1, 0),
aten.max_pool2d_with_indices.default: elementwise_flop_counter(1, 0),
aten.max_pool2d_with_indices_backward.default: elementwise_flop_counter(0, 1),
aten.max_pool3d_with_indices.default: elementwise_flop_counter(1, 0),
aten.max_pool3d_with_indices_backward.default: elementwise_flop_counter(0, 1),
aten._adaptive_avg_pool2d.default: elementwise_flop_counter(1, 0),
aten._adaptive_avg_pool2d_backward.default: elementwise_flop_counter(0, 1),
aten._adaptive_avg_pool3d.default: elementwise_flop_counter(1, 0),
aten._adaptive_avg_pool3d_backward.default: elementwise_flop_counter(0, 1),
}
elementwise_flop_aten = [
# basic op
aten.add.Tensor,
aten.add_.Tensor,
aten.div.Tensor,
aten.div_.Tensor,
aten.div.Scalar,
aten.div_.Scalar,
aten.mul.Tensor,
aten.mul.Scalar,
aten.mul_.Tensor,
aten.neg.default,
aten.pow.Tensor_Scalar,
aten.rsub.Scalar,
aten.sum.default,
aten.sum.dim_IntList,
aten.mean.dim,
# activation op
aten.hardswish.default,
aten.hardswish_.default,
aten.hardswish_backward.default,
aten.hardtanh_.default,
aten.hardtanh_backward.default,
aten.hardsigmoid_backward.default,
aten.hardsigmoid.default,
aten.gelu.default,
aten.gelu_backward.default,
aten.silu_.default,
aten.silu_backward.default,
aten.sigmoid.default,
aten.sigmoid_backward.default,
aten._softmax.default,
aten._softmax_backward_data.default,
aten.relu_.default,
aten.relu.default,
aten.tanh.default,
aten.tanh_backward.default,
aten.threshold_backward.default,
]
for op in elementwise_flop_aten:
flop_mapping[op] = elementwise_flop_counter(1, 0)
# TODO: this will be removed in future
zero_flop_aten = [
aten.as_strided.default,
aten.as_strided_.default,
aten.bernoulli_.float,
aten.cat.default,
aten.clone.default,
aten.copy_.default,
aten.detach.default,
aten.expand.default,
aten.empty_like.default,
aten.new_empty.default,
aten.new_empty_strided.default,
aten.ones_like.default,
aten._reshape_alias.default,
aten.select.int,
aten.select_backward.default,
aten.squeeze.dim,
aten.slice.Tensor,
aten.slice_backward.default,
aten.split.Tensor,
aten.permute.default,
aten.t.default,
aten.transpose.int,
aten._to_copy.default,
aten.unsqueeze.default,
aten._unsafe_view.default,
aten.view.default,
aten.where.self,
aten.zero_.default,
]
for op in zero_flop_aten:
flop_mapping[op] = zero_flop_jit

View File

@ -1,120 +1,121 @@
from operator import add, floordiv, getitem, mul, neg, setitem, sub, pos
from typing import Callable, List, NamedTuple, Any, Dict, Tuple, Union
from typing import Callable, Any, Dict, Tuple
import torch
from torch.fx import Graph
from torch.fx.node import Argument, Target
from torch.fx._compatibility import compatibility
from . import meta_profiler_function, meta_profiler_module
from torch.utils._pytree import tree_map
from .memory import activation_size, INPLACE_ATEN, WEIRD_OPS
from .tensor import MetaTensor
from .opcount import flop_mapping
__all__ = [
'MetaProfile', 'profile_function', 'profile_module', 'profile_method', 'calculate_activation_size',
'calculate_param_size'
]
CALL_FUNCTION_MSG = \
"""
Colossal-AI hasn't supported profiling for {}, you might manually patch it with the following code.\n
from colossalai.fx.profiler import meta_profiler_function
@meta_profiler_function.register(YOUR_FUNCTION)
def profile_YOUR_FUNCTION(input: torch.Tensor, *args) -> Tuple[int, int]:
flops = ...
macs = ...
return flops, macs
"""
CALL_METHOD_MSG = 'Please check if {} is an inplace method. If so, add target to INPLACE_METHOD={}. Otherwise, add target to NON_INPLACE_METHOD={}'
CALL_MODULE_MSG = \
"""
Colossal-AI hasn't supported profiling for {}, you might manually patch it with the following code.\n
from colossalai.fx.profiler import meta_profiler_module
@meta_profiler_module.register(YOUR_MODULE)
def profile_YOUR_MODULE(self: torch.nn.Module, input: torch.Tensor) -> Tuple[int, int]:
flops = ...
macs = ...
return flops, macs
"""
# TODO fill out the inplace ops
INPLACE_OPS = [
add,
sub,
mul,
floordiv,
neg,
pos,
getitem,
setitem,
getattr,
torch.Tensor.cpu,
]
# TODO: list all call_methods that are inplace here
INPLACE_METHOD = [
'transpose',
'permute',
# TODO: reshape may return a copy of the data if the data is not contiguous
'reshape',
'dim',
'flatten',
'size',
'view',
'unsqueeze',
'to',
]
# TODO: list all call_methods that are not inplace here
NON_INPLACE_METHOD = [
'expand',
'mean',
]
__all__ = ['profile_function', 'profile_module', 'profile_method', '_profile']
@compatibility(is_backward_compatible=True)
class MetaProfile(NamedTuple):
# MetaProfile is a structure containing pertinent information
# about a node within a torch.fx GraphModule.
param: int
activation: int
flops: int
macs: int
def normalize_tuple(x):
if not isinstance(x, tuple):
return (x,)
return x
def calculate_activation_size(activation: Union[torch.Tensor, Dict, List, Tuple, int]) -> int:
"""Calculate activation size of a node.
def is_autogradable(x):
return isinstance(x, torch.Tensor) and x.is_floating_point()
def _profile(target: Callable, *args, **kwargs) -> Tuple[Any, ...]:
"""Profile a Callable function with args and kwargs.
Args:
activation (Union[torch.Tensor, Dict, List, Tuple, int]): The activation of a `torch.nn.Module` or `torch.nn.functional`
target (Callable): A Callable function
args (Any): Argument
kwargs (Any): Argument
Returns:
int: The activation size
out (Tuple[Any, ...]): The argument value that was retrieved
flop_count (Tuple[int, ...]): The flop count for (fwd_flop, bwd_flop).
mem_stat (Tuple[int, ...]): The memory statistics for (fwd_tmp, fwd_out, bwd_tmp, bwd_out)
"""
activation_size = 0
if isinstance(activation, torch.Tensor):
activation_size += activation.numel() * torch.tensor([], dtype=activation.dtype).element_size()
elif isinstance(activation, dict):
value_list = [v for _, v in activation.items()]
activation_size += calculate_activation_size(value_list)
elif isinstance(activation, tuple) or isinstance(activation, list):
for element in activation:
activation_size += calculate_activation_size(element)
return activation_size
flop_count = {
'f': 0,
'l': 0,
'b': 0,
}
temp = {
'f': [],
'l': [],
'b': [],
}
stage = 'f'
def calculate_param_size(mod: torch.nn.Module) -> int:
"""Calculate param size of a node.
class FlopTensor(MetaTensor):
Args:
mod (torch.nn.Module): The target `torch.nn.Module`
def __repr__(self):
if self.grad_fn:
return f"FlopTensor(..., device={self._tensor.device}, size={tuple(self.shape)}, grad_fn={self.grad_fn})"
return f"FlopTensor(..., device={self._tensor.device}, size={tuple(self.shape)})"
Returns:
int: The param size
"""
param_size = 0
for param in mod.parameters():
param_size += param.numel() * torch.tensor([], dtype=param.dtype).element_size()
return param_size
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
def unwrap(x):
if isinstance(x, torch.Tensor) and not hasattr(x, '_tensor'):
x = FlopTensor(x.to('meta'))
return x._tensor.to('meta') if isinstance(x, FlopTensor) else x
def to_meta(x):
return x.to('meta') if isinstance(x, torch.Tensor) else x
args = tree_map(unwrap, args)
kwargs = tree_map(unwrap, kwargs)
# run aten for backend=CPU but actually on backend=Meta
out = func(*args, **kwargs)
flop_count[stage] += flop_mapping[func](args, normalize_tuple(out))
if func not in INPLACE_ATEN:
temp[stage].append(tree_map(to_meta, normalize_tuple(out)))
def wrap(x):
return FlopTensor(x.to('meta')) if isinstance(x, torch.Tensor) else x
return tree_map(wrap, out)
if target not in WEIRD_OPS:
def wrap(x):
return FlopTensor(
x.detach().requires_grad_(True)) if is_autogradable(x) and not hasattr(x, '_tensor') else x
else:
def wrap(x):
return FlopTensor(
x.detach().requires_grad_(False)) if is_autogradable(x) and not hasattr(x, '_tensor') else x
args = tree_map(wrap, args)
kwargs = tree_map(wrap, kwargs)
if isinstance(target, str):
# args[0] is the `self` object for this method call
self_obj, *args_tail = args
out = getattr(self_obj, target)(*args_tail, **kwargs)
else:
out = target(*args, **kwargs)
if is_autogradable(out) and out.requires_grad:
stage = 'l'
loss = out.sum()
stage = 'b'
loss.backward()
fwd_flop = flop_count['f']
bwd_flop = flop_count['b']
fwd_tmp = max(map(activation_size, temp['f'][:-1])) if len(temp['f'][:-1]) else 0
fwd_out = activation_size(temp['f'][-1]) if len(temp['f']) else 0
bwd_tmp = max(map(activation_size, temp['b'])) if len(temp['b']) else 0
def unwrap(x):
return x._tensor.to('meta') if isinstance(x, FlopTensor) else x
return tree_map(unwrap, out), (fwd_flop, bwd_flop), (fwd_tmp, fwd_out, bwd_tmp, 0)
def profile_function(target: 'Target') -> Callable:
@ -127,31 +128,19 @@ def profile_function(target: 'Target') -> Callable:
Only original `torch.nn.functional` are available.
Examples:
>> input = torch.rand(100, 100, 100, 100, device='meta')
>> func = torch.nn.functional.relu
>> output, profile = profile_function(func)(input, inplace=False)
>> print(f"Profiling function {func},")
>> print(f"Param size: {profile.param / 1024**2:.3f} MB, Activation size: {profile.activation / 1024**2:.3f} MB, {profile.flops} FLOPs, {profile.macs} MACs")
Profiling function <function relu at 0x7fcdd0258d30>,
Param size: 0.000 MB, Activation size: 381.470 MB, 100000000 FLOPs, 0 MACs
>>> input = torch.rand(100, 100, 100, 100, device='meta')
>>> func = torch.nn.functional.relu
>>> output, (fwd_flop, bwd_flop), (fwd_tmp, fwd_out, bwd_tmp, bwd_out) = profile_function(func)(input, inplace=False)
"""
def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
assert meta_profiler_function.has(target) or meta_profiler_function.has(
target.__name__), CALL_FUNCTION_MSG.format(target)
# call_function has no parameters
param_size = 0
activation_size = 0
result = func(*args, **kwargs)
if target not in INPLACE_OPS and not kwargs.get('inplace', False):
activation_size += calculate_activation_size(result)
if meta_profiler_function.has(target):
profiler = meta_profiler_function.get(target)
else:
profiler = meta_profiler_function.get(target.__name__)
flops, macs = profiler(*args, **kwargs)
return result, MetaProfile(param_size, activation_size, flops, macs)
if kwargs.get('inplace', False):
args = tree_map(lambda x: x.to('meta') if isinstance(x, torch.Tensor) else x, args)
kwargs = tree_map(lambda x: x.to('meta') if isinstance(x, torch.Tensor) else x, kwargs)
out = func(*args, **kwargs)
return out, (0, 0), (0, 0, 0, 0)
out, flop_count, mem_stat = _profile(func, *args, **kwargs)
return out, flop_count, mem_stat
f.__name__ = target.__name__
func = target
@ -162,27 +151,13 @@ def profile_method(target: 'Target') -> Callable:
"""
Wrap a `call_method` node
record the memory cost and FLOPs of the execution.
Warnings:
This is not fully implemented and you may follow the error message to debug.
"""
def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
# args[0] is the `self` object for this method call
self_obj, *args_tail = args
# execute the method and return the result
assert isinstance(target, str), f'{target} instance is not str.'
result = getattr(self_obj, target)(*args_tail, **kwargs)
assert target in INPLACE_METHOD + NON_INPLACE_METHOD, CALL_METHOD_MSG.format(
target, INPLACE_METHOD, NON_INPLACE_METHOD)
# call_method has no parameters and are MOSTLY(?) inplace, and has no FLOPs or MACs.
param_size = 0
activation_size = 0 if target in INPLACE_METHOD else calculate_activation_size(result)
flops = 0
macs = 0
return result, MetaProfile(param_size, activation_size, flops, macs)
out, flop_count, mem_stat = _profile(target, *args, **kwargs)
return out, flop_count, mem_stat
return f
@ -197,27 +172,19 @@ def profile_module(module: torch.nn.Module) -> Callable:
Only original `torch.nn` are available.
Example:
>> input = torch.rand(4, 3, 224, 224, device='meta')
>> mod = torch.nn.Conv2d(3, 128, 3)
>> output, profile = profile_module(mod)(input)
>> print(f"Profiling function {mod},")
>> print(f"Param size: {profile.param / 1024**2:.3f} MB, Activation size: {profile.activation / 1024**2:.3f} MB, {profile.flops} FLOPs, {profile.macs} MACs")
Profiling function Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1)),
Param size: 0.014 MB, Activation size: 96.258 MB, 1387837440 FLOPs, 681302016 MACs
>>> input = torch.rand(4, 3, 224, 224, device='meta')
>>> mod = torch.nn.Conv2d(3, 128, 3)
>>> output, (fwd_flop, bwd_flop), (fwd_tmp, fwd_out, bwd_tmp, bwd_out) = profile_module(mod)(input)
"""
def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
assert meta_profiler_module.has(type(module)), CALL_MODULE_MSG.format(type(module))
# only `nn.Module` has parameters
param_size = calculate_param_size(module)
activation_size = 0
result = func(*args, **kwargs)
if not getattr(module, 'inplace', False):
activation_size += calculate_activation_size(result)
profiler = meta_profiler_module.get(type(module))
flops, macs = profiler(module, *args, **kwargs)
return result, MetaProfile(param_size, activation_size, flops, macs)
if getattr(module, 'inplace', False):
args = tree_map(lambda x: x.to('meta'), args)
kwargs = tree_map(lambda x: x.to('meta'), kwargs)
out = func(*args, **kwargs)
return out, (out.numel(), out.numel()), (0, 0, 0, 0)
out, flop_count, mem_stat = _profile(func, *args, **kwargs)
return out, flop_count, mem_stat
f.__name__ = module.__class__.__name__
func = module.forward

View File

@ -1,7 +1,6 @@
import torch
from torch.utils._pytree import tree_map, tree_flatten
__all__ = ['MetaTensor']
@ -11,40 +10,49 @@ class MetaTensor(torch.Tensor):
"""
_tensor: torch.Tensor
__slots__ = ['_tensor']
@staticmethod
def __new__(cls, elem):
# The wrapping tensor (MetaTensor) shouldn't hold any
# memory for the class in question, but it should still
# advertise the same device as before
r = torch.Tensor._make_wrapper_subclass(
cls, elem.size(),
strides=elem.stride(), storage_offset=elem.storage_offset(),
dtype=elem.dtype, layout=elem.layout,
device='cpu', requires_grad=elem.requires_grad
) # deceive the frontend for aten selections
cls,
elem.size(),
strides=elem.stride(),
storage_offset=elem.storage_offset(),
dtype=elem.dtype,
layout=elem.layout,
device='cpu',
requires_grad=elem.requires_grad) # deceive the frontend for aten selections
r._tensor = elem
# ...the real tensor is held as an element on the tensor.
return r
@ classmethod
def __repr__(self):
if self.grad_fn:
return f"MetaTensor({self._tensor}, grad_fn={self.grad_fn})"
return f"MetaTensor({self._tensor})"
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
def unwrap(x):
if isinstance(x, torch.Tensor) and not hasattr(x, '_tensor'):
x = MetaTensor(x)
return x._tensor.to('meta') if isinstance(x, MetaTensor) else x
args = tree_map(unwrap, args)
kwargs = tree_map(unwrap, kwargs)
# run aten for backend=CPU but actually on backend=Meta
out = func(*args, **kwargs)
# Now, we want to continue propagating this tensor, so we rewrap Tensors in
# our custom tensor subclass
def wrap(x):
return MetaTensor(x) if isinstance(x, torch.Tensor) else x
return tree_map(wrap, out)

View File

@ -89,6 +89,7 @@ def _run_ckpt_solver(rank):
@pytest.mark.skipif(not with_codegen, reason='torch version is lower than 1.12.0')
@pytest.mark.skip('TODO: refactor ckpt solvers')
def test_ckpt_solver():
mp.spawn(_run_ckpt_solver, nprocs=1)

View File

@ -15,6 +15,7 @@ except:
with_codegen = False
@pytest.mark.skip(reason='TODO: modify calculations in rotor')
@pytest.mark.skipif(not with_codegen, reason="torch version is lower than 1.12.0")
def test_linearize():
MODEL_DICT = {tm.resnet18: [2100, 3000], tm.densenet121: [8100, 17000]}

View File

@ -6,6 +6,7 @@ from torch.fx import symbolic_trace
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.fx.passes.adding_split_node_pass import split_with_split_nodes_pass, uniform_split_pass
from colossalai.fx.passes.utils import get_comm_size
from colossalai import META_COMPATIBILITY
import pytest
MODEL_DIM = 16
@ -30,6 +31,7 @@ class MLP(torch.nn.Module):
return x
@pytest.mark.skipif(not META_COMPATIBILITY, reason='torch version is lower than 1.12.0')
def test_comm_size_compute():
model = MLP(MODEL_DIM)
input_sample = torch.rand(BATCH_SIZE, MODEL_DIM, device='meta')

View File

@ -2,15 +2,12 @@ from typing import Any, Callable, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from colossalai.fx.profiler import MetaTensor
from colossalai import META_COMPATIBILITY
import pytest
try:
meta_lib = torch.library.Library("aten", "IMPL", "Meta")
INCOMPATIBLE = False # version > 1.12.0
except:
INCOMPATIBLE = True
if META_COMPATIBILITY:
from colossalai.fx.profiler import MetaTensor
aten = torch.ops.aten
@ -56,7 +53,7 @@ registered_meta = {
}
def compare_all(tensor: torch.Tensor, meta_tensor: MetaTensor) -> Any:
def compare_all(tensor: torch.Tensor, meta_tensor: torch.Tensor) -> Any:
assert tensor.shape == meta_tensor.shape, f'the shape of tensor ({tensor.shape}) and meta tensor ({meta_tensor.shape}) does not match.'
assert tensor.dtype == meta_tensor.dtype, f'the dtype of tensor ({tensor.dtype}) and meta tensor ({meta_tensor.dtype}) does not match.'
assert tensor.stride() == meta_tensor.stride(
@ -77,7 +74,7 @@ def run_and_compare(f: Union[nn.Module, Callable], x: torch.Tensor, requires_bac
compare_all(x.grad, meta_x.grad)
@pytest.mark.skipif(INCOMPATIBLE, reason='torch version is lower than 1.12.0')
@pytest.mark.skipif(not META_COMPATIBILITY, reason='torch version is lower than 1.12.0')
def test_meta_aten():
for (aten_op, requires_backward), v in registered_meta.items():
for f, x in v:

View File

@ -1,48 +1,33 @@
import torchvision.models as tm
import timm.models as tmm
import torch
from colossalai.fx.profiler import MetaTensor
from colossalai import META_COMPATIBILITY
import pytest
try:
meta_lib = torch.library.Library("aten", "IMPL", "Meta")
incompatible = False # version > 1.12.0
except:
incompatible = True
if META_COMPATIBILITY:
from colossalai.fx.profiler import MetaTensor
tm_models = [
tm.vgg11,
tm.resnet18,
tm.densenet121,
tm.mobilenet_v3_small,
tm.resnext50_32x4d,
tm.vgg11,
tm.resnet18,
tm.densenet121,
tm.mobilenet_v3_small,
tm.resnext50_32x4d,
tm.wide_resnet50_2,
tm.regnet_x_16gf,
tm.mnasnet0_5,
tm.regnet_x_16gf,
tm.mnasnet0_5,
tm.efficientnet_b0,
]
tmm_models = [
tmm.resnest.resnest50d,
tmm.beit.beit_base_patch16_224,
tmm.cait.cait_s24_224,
tmm.efficientnet.efficientnetv2_m,
tmm.resmlp_12_224,
tmm.vision_transformer.vit_base_patch16_224,
tmm.deit_base_distilled_patch16_224,
tmm.convnext.convnext_base,
tmm.vgg.vgg11,
tmm.dpn.dpn68,
tmm.densenet.densenet121,
tmm.rexnet.rexnet_100,
tmm.resnest.resnest50d, tmm.beit.beit_base_patch16_224, tmm.cait.cait_s24_224, tmm.efficientnet.efficientnetv2_m,
tmm.resmlp_12_224, tmm.vision_transformer.vit_base_patch16_224, tmm.deit_base_distilled_patch16_224,
tmm.convnext.convnext_base, tmm.vgg.vgg11, tmm.dpn.dpn68, tmm.densenet.densenet121, tmm.rexnet.rexnet_100,
tmm.swin_transformer.swin_base_patch4_window7_224
]
@pytest.mark.skipif(incompatible, reason='torch version is lower than 1.12.0')
@pytest.mark.skipif(not META_COMPATIBILITY, reason='torch version is lower than 1.12.0')
def test_torchvision_models():
for m in tm_models:
model = m().to('meta')
@ -50,7 +35,7 @@ def test_torchvision_models():
model(MetaTensor(data)).sum().backward()
@pytest.mark.skipif(incompatible, reason='torch version is lower than 1.12.0')
@pytest.mark.skipif(not META_COMPATIBILITY, reason='torch version is lower than 1.12.0')
def test_timm_models():
for m in tmm_models:
model = m().to('meta')

View File

@ -5,6 +5,8 @@ import colossalai.nn as col_nn
from torch.fx import symbolic_trace
from colossalai.fx.passes.meta_info_prop import MetaInfoProp, TensorMetadata
import pytest
BATCH_SIZE = 2
DIM_IN = 4
DIM_OUT = 16
@ -13,7 +15,6 @@ DIM_OUT = 16
def meta_check(meta_info_spec: TensorMetadata, orig_tensor: torch.Tensor):
assert meta_info_spec.shape == orig_tensor.shape
assert meta_info_spec.dtype == orig_tensor.dtype
assert meta_info_spec.requires_grad == orig_tensor.requires_grad
assert meta_info_spec.stride == orig_tensor.stride()
assert meta_info_spec.numel == orig_tensor.numel()
@ -23,29 +24,12 @@ def test_meta_info_prop():
input_sample = torch.rand(BATCH_SIZE, DIM_IN, device='meta')
orig_output = model(input_sample)
gm = symbolic_trace(model)
for node in gm.graph.nodes:
assert not hasattr(node,
'node_size'), 'The attribute Node.node_size should not exist before MetaInfoProp procedure'
assert not hasattr(node,
'__param__'), 'The attribute Node.__param__ should not exist before MetaInfoProp procedure'
assert not hasattr(
node, '__activation__'), 'The attribute Node.__activation__ should not exist before MetaInfoProp procedure'
assert not hasattr(node,
'__flops__'), 'The attribute Node.__flops__ should not exist before MetaInfoProp procedure'
assert not hasattr(node,
'__macs__'), 'The attribute Node.__macs__ should not exist before MetaInfoProp procedure'
MetaInfoProp(gm).run(input_sample)
for node in gm.graph.nodes:
if node.op == 'placeholder':
meta_check(node.meta['tensor_meta'], input_sample)
if node.op == 'output':
meta_check(node.meta['tensor_meta'], orig_output)
assert hasattr(node, 'node_size'), 'The attribute Node.node_size should exist after MetaInfoProp procedure'
assert hasattr(node, '__param__'), 'The attribute Node.__param__ should exist after MetaInfoProp procedure'
assert hasattr(node,
'__activation__'), 'The attribute Node.__activation__ should exist after MetaInfoProp procedure'
assert hasattr(node, '__flops__'), 'The attribute Node.__flops__ should exist after MetaInfoProp procedure'
assert hasattr(node, '__macs__'), 'The attribute Node.__macs__ should exist after MetaInfoProp procedure'
if __name__ == '__main__':