[fx] add more op patches for the profiler and error messages for unsupported ops. (#1495)

* [fx] modify the calculation of node_size in MetaInfoProp for activation checkpointing usages

* [fx] merge development into main (#1)

* [fx] activation checkpointing using Chen strategies.

* [fx] add test for ckpt_solver_chen

* [fx] add vanilla activation checkpoint search with test on resnet and densenet

* [fx] add a namespace code for solver_chen.

* [fx] fix the incorrect interpretation of Algorithm 3 in https://arxiv.org/abs/1604.06174.

* [fx] fix lowercase naming conventions.

* [fx] simplify test for ckpt.

* [fx] add rules to linearize computation graphs for searching. (#2)

* [fx] fix test and algorithm bugs in activation checkpointing.

* [fx] polish ckpt_test.

* [fx] remove chen_sqrt for the sake of simplicity

* [fx] fix inconsistencies.

* [fx] fix MetaInfoProp.

* [fx] consider MetaInfoProp for inplace operands.

* [fx] add profiler for fx nodes.

* [fx] fix error in tests.

* [fx] unfix bug.

* [fx] patch more modules and functions.

* [fx] change name of utils.py to profiler.py

* [fx] add profiler for rnn.

* [fx] polish and add more patches for the profiler.
Super Daniel 2022-08-25 23:11:13 +08:00 committed by GitHub
parent 413c053453
commit 09c023bee2
14 changed files with 310 additions and 77 deletions

@@ -1,4 +1,4 @@
from .registry import *
from .profiler_function import *
from .profiler_module import *
from .utils import *
from .profiler import *

@@ -1,8 +1,8 @@
from functools import partial
from operator import add, floordiv, getitem, mul, neg, setitem, sub, pos
from typing import Callable, NamedTuple, Any, Dict, Tuple
from typing import Callable, List, NamedTuple, Any, Dict, Tuple, Union
import torch
from torch.fx.node import Argument, Target
from torch.fx.node import Argument, Target, map_aggregate
from torch.fx._compatibility import compatibility
from colossalai.fx.tracer.meta_patch import meta_patched_function, meta_patched_module
from . import meta_profiler_function, meta_profiler_module
@@ -12,6 +12,30 @@ __all__ = [
'calculate_param_size'
]
CALL_FUNCTION_MSG = \
"""
Colossal-AI does not support profiling for {} yet. You can patch it manually with the following code.\n
from colossalai.fx.profiler import meta_profiler_function

@meta_profiler_function.register(YOUR_FUNCTION)
def profile_YOUR_FUNCTION(input: torch.Tensor, *args) -> Tuple[int, int]:
    flops = ...
    macs = ...
    return flops, macs
"""
CALL_METHOD_MSG = 'Please check whether {} is an inplace method. If so, add the target to INPLACE_METHOD={}; otherwise, add it to NON_INPLACE_METHOD={}.'
CALL_MODULE_MSG = \
"""
Colossal-AI does not support profiling for {} yet. You can patch it manually with the following code.\n
from colossalai.fx.profiler import meta_profiler_module

@meta_profiler_module.register(YOUR_MODULE)
def profile_YOUR_MODULE(self: torch.nn.Module, input: torch.Tensor) -> Tuple[int, int]:
    flops = ...
    macs = ...
    return flops, macs
"""
# TODO fill out the inplace ops
INPLACE_OPS = [
add,
@@ -22,18 +46,30 @@ INPLACE_OPS = [
pos,
getitem,
setitem,
getattr,
torch.Tensor.cpu,
]
# TODO check that call_methods are indeed inplace
# TODO: list all call_methods that are inplace here
INPLACE_METHOD = [
'transpose',
'permute',
# TODO: reshape may return a copy of the data if the data is not contiguous
'reshape',
'dim',
'flatten',
]
# TODO: list all call_methods that are not inplace here
NON_INPLACE_METHOD = [
'expand',
'mean',
]
@compatibility(is_backward_compatible=True)
class MetaProfile(NamedTuple):
# MetaProfile is a structure containing pertinent information
# about a node within a torch.fx GraphModule.
@@ -43,9 +79,14 @@ class MetaProfile(NamedTuple):
macs: int
def calculate_activation_size(activation: any) -> int:
"""
Calculate activation size of a node.
def calculate_activation_size(activation: Union[torch.Tensor, Dict, List, Tuple, int]) -> int:
"""Calculate activation size of a node.
Args:
activation (Union[torch.Tensor, Dict, List, Tuple, int]): The activation of a `torch.nn.Module` or `torch.nn.functional`
Returns:
int: The activation size
"""
activation_size = 0
if isinstance(activation, torch.Tensor):
@@ -53,15 +94,20 @@ def calculate_activation_size(activation: any) -> int:
elif isinstance(activation, dict):
value_list = [v for _, v in activation.items()]
activation_size += calculate_activation_size(value_list)
else:
elif isinstance(activation, tuple) or isinstance(activation, list):
for element in activation:
activation_size += calculate_activation_size(element)
return activation_size
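For a single tensor the recursion bottoms out at numel() * element_size(); dicts and sequences are summed recursively. A quick sanity check on a meta tensor (illustrative only, not part of the diff):

    import torch

    t = torch.randn(4, 3, 224, 224, device='meta')
    # 4 * 3 * 224 * 224 elements * 4 bytes (fp32) = 2,408,448 bytes
    assert t.numel() * t.element_size() == 2408448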
def calculate_param_size(mod: torch.nn.Module) -> int:
"""
Calculate param size of a node.
"""Calculate param size of a node.
Args:
mod (torch.nn.Module): The target `torch.nn.Module`
Returns:
int: The param size
"""
param_size = 0
for param in mod.parameters():
@@ -78,17 +124,21 @@ def profile_function(target: 'Target') -> Callable:
You may only use tensors with `device=meta` for this wrapped function.
Only original `torch.nn.functional` functions are available.
Usage:
input = torch.rand(100, 100, 100, 100, device='meta')
func = torch.nn.functional.relu
output, profile = profile_function(func)(input, inplace=False)
print(f"Profiling function {func},")
print(f"Param size: {profile.param / 1024**2:.3f} MB, Activation size: {profile.activation / 1024**2:.3f} MB, {profile.flops} FLOPs, {profile.macs} MACs")
Examples:
>>> input = torch.rand(100, 100, 100, 100, device='meta')
>>> func = torch.nn.functional.relu
>>> output, profile = profile_function(func)(input, inplace=False)
>>> print(f"Profiling function {func},")
>>> print(f"Param size: {profile.param / 1024**2:.3f} MB, Activation size: {profile.activation / 1024**2:.3f} MB, {profile.flops} FLOPs, {profile.macs} MACs")
Profiling function <function relu at 0x7fcdd0258d30>,
Param size: 0.000 MB, Activation size: 381.470 MB, 100000000 FLOPs, 0 MACs
"""
def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
assert meta_profiler_function.has(target) or meta_profiler_function.has(
target.__name__), f"Colossal-AI hasn't supported profiling for {target}, you might manually patch it."
target.__name__), CALL_FUNCTION_MSG.format(target)
# ensure all arguments satisfy `device='meta'`
args, kwargs = map_aggregate([args, kwargs], lambda a: a.to('meta') if isinstance(a, torch.Tensor) else a)
# call_function has no parameters
param_size = 0
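`map_aggregate` walks arbitrarily nested tuples, lists, and dicts of arguments and applies a function to every leaf, which is what moves each tensor argument onto the meta device here. A minimal sketch of that behavior (illustrative, not part of the diff):

    import torch
    from torch.fx.node import map_aggregate

    args = (torch.randn(2, 2), [torch.randn(3), 'not a tensor'])
    # tensors become meta tensors; non-tensor leaves pass through unchanged
    meta_args = map_aggregate(args, lambda a: a.to('meta') if isinstance(a, torch.Tensor) else a)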
@@ -127,14 +177,17 @@ def profile_method(target: 'Target') -> Callable:
# args[0] is the `self` object for this method call
self_obj, *args_tail = args
# Execute the method and return the result
# execute the method and return the result
assert isinstance(target, str), f'target {target} is not a string.'
result = getattr(self_obj, target)(*args_tail, **kwargs)
assert target in INPLACE_METHOD, f'Please check {target} is an inplace method. If so, add target to INPLACE_METHOD={INPLACE_METHOD}.'
# ensure all arguments satisfy `device='meta'`
map_aggregate([args, kwargs], lambda a: a.to('meta') if isinstance(a, torch.Tensor) else a)
result = getattr(self_obj, target)(*args_tail, **kwargs)
assert target in INPLACE_METHOD + NON_INPLACE_METHOD, CALL_METHOD_MSG.format(
target, INPLACE_METHOD, NON_INPLACE_METHOD)
# call_method has no parameters and no FLOPs or MACs; activation is counted only for non-inplace methods
param_size = 0
activation_size = 0
activation_size = 0 if target in INPLACE_METHOD else calculate_activation_size(result)
flops = 0
macs = 0
return result, MetaProfile(param_size, activation_size, flops, macs)
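Under the new rule, only methods listed in NON_INPLACE_METHOD are charged for their output. A hedged usage sketch (assuming the signatures in this diff):

    import torch

    x = torch.rand(32, 128, device='meta')
    # 'mean' is in NON_INPLACE_METHOD, so the (32,) result counts as activation
    out, profile = profile_method('mean')(x, dim=-1)
    # 'flatten' is in INPLACE_METHOD, so its activation size would be counted as 0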
@@ -151,17 +204,20 @@ def profile_module(module: torch.nn.Module) -> Callable:
You may only use tensors with `device=meta` for this wrapped function.
Only original `torch.nn` modules are available.
Usage:
input = torch.rand(4, 3, 224, 224, device='meta')
mod = torch.nn.Conv2d(3, 128, 3)
output, profile = profile_module(mod)(input)
print(f"Profiling function {mod},")
print(f"Param size: {profile.param / 1024**2:.3f} MB, Activation size: {profile.activation / 1024**2:.3f} MB, {profile.flops} FLOPs, {profile.macs} MACs")
Example:
>>> input = torch.rand(4, 3, 224, 224, device='meta')
>>> mod = torch.nn.Conv2d(3, 128, 3)
>>> output, profile = profile_module(mod)(input)
>>> print(f"Profiling function {mod},")
>>> print(f"Param size: {profile.param / 1024**2:.3f} MB, Activation size: {profile.activation / 1024**2:.3f} MB, {profile.flops} FLOPs, {profile.macs} MACs")
Profiling function Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1)),
Param size: 0.014 MB, Activation size: 96.258 MB, 1387837440 FLOPs, 681302016 MACs
"""
def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
assert meta_profiler_module.has(
type(module)), f"Colossal-AI hasn't supported profiling for {module}, you might manually patch it."
assert meta_profiler_module.has(type(module)), CALL_MODULE_MSG.format(type(module))
# ensure all arguments satisfy `device='meta'`
map_aggregate([args, kwargs], lambda a: a.to('meta') if isinstance(a, torch.Tensor) else a)
param_size = calculate_param_size(module)
activation_size = 0
result = func(*args, **kwargs)

@@ -12,6 +12,8 @@ _multiplier = {
torch.nn.functional.elu: 4,
torch.nn.functional.relu6: 2,
torch.nn.functional.gelu: 9,
torch.nn.functional.hardswish: 5,
torch.nn.functional.hardsigmoid: 4,
}
@@ -23,6 +25,8 @@ _multiplier = {
@meta_profiler_function.register(torch.nn.functional.relu)
@meta_profiler_function.register(torch.nn.functional.sigmoid)
@meta_profiler_function.register(torch.nn.functional.tanh)
@meta_profiler_function.register(torch.nn.functional.hardswish)
@meta_profiler_function.register(torch.nn.functional.hardsigmoid)
def torch_nn_func_non_linear_act(input: torch.Tensor, inplace: bool = False) -> Tuple[int, int]:
flops = input.numel()
macs = 0

@@ -1,24 +1,19 @@
import operator
from functools import reduce
from typing import Any, Optional, Tuple, Union
import torch
from ..registry import meta_profiler_function
def _prod(dims):
p = 1
for v in dims:
p *= v
return p
def _elementwise_flops_compute(input, other):
# copied from https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/profiling/flops_profiler/profiler.py#L763
if not torch.is_tensor(input):
if torch.is_tensor(other):
return _prod(other.shape), 0
return reduce(operator.mul, other.shape), 0
else:
return 1, 0
elif not torch.is_tensor(other):
return _prod(input.shape), 0
return reduce(operator.mul, input.shape), 0
else:
dim_input = len(input.shape)
dim_other = len(other.shape)
@@ -32,17 +27,24 @@ def _elementwise_flops_compute(input, other):
final_shape.append(in_i)
else:
final_shape.append(ot_i)
flops = _prod(final_shape)
flops = reduce(operator.mul, final_shape)
return flops, 0
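The broadcast logic mirrors NumPy-style shape alignment: the final shape takes the larger extent along each dimension, and the FLOP count is the product of that shape. For example (illustrative arithmetic):

    # broadcasting (8, 1, 4) against (8, 5, 1) yields final_shape = (8, 5, 4),
    # so one elementwise op costs 8 * 5 * 4 = 160 FLOPs and 0 MACs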
@meta_profiler_function.register(torch.add)
@meta_profiler_function.register(torch.eq)
@meta_profiler_function.register(torch.sub)
@meta_profiler_function.register(torch.mul)
@meta_profiler_function.register(torch.floor_divide)
@meta_profiler_function.register('add') # for built-in op +
@meta_profiler_function.register('iadd') # for built-in op +=
@meta_profiler_function.register('eq') # for built-in op ==
@meta_profiler_function.register('sub') # for built-in op -
@meta_profiler_function.register('isub') # for built-in op -=
@meta_profiler_function.register('mul') # for built-in op *
@meta_profiler_function.register('imul') # for built-in op *=
@meta_profiler_function.register('floordiv') # for built-in op //
@meta_profiler_function.register('ifloordiv') # for built-in op //=
def torch_add_like_ops(input: Any, other: Any, *, out: Optional[torch.Tensor] = None) -> Tuple[int, int]:
return _elementwise_flops_compute(input, other)
@@ -58,14 +60,14 @@ def torch_elementwise_op(input: torch.Tensor, *, out: Optional[torch.Tensor] = None) -> Tuple[int, int]:
@meta_profiler_function.register('matmul') # for built-in op @
@meta_profiler_function.register(torch.Tensor.matmul)
def torch_matmul(input: torch.Tensor, other: torch.Tensor, *, out: Optional[torch.Tensor] = None) -> Tuple[int, int]:
macs = _prod(input.shape) * other.shape[-1]
macs = reduce(operator.mul, input.shape) * other.shape[-1]
flops = 2 * macs
return flops, macs
@meta_profiler_function.register(torch.bmm)
def torch_bmm(input: torch.Tensor, other: torch.Tensor, *, out: Optional[torch.Tensor] = None) -> Tuple[int, int]:
macs = _prod(input.shape) * other.shape[-1]
macs = reduce(operator.mul, input.shape) * other.shape[-1]
flops = 2 * macs
return flops, macs

@@ -10,3 +10,10 @@ def operator_getitem(a: Any, b: Any) -> Tuple[int, int]:
flops = 0
macs = 0
return flops, macs
@meta_profiler_function.register(getattr)
def python_getattr(a: Any, b: Any) -> Tuple[int, int]:
flops = 0
macs = 0
return flops, macs

@@ -1,15 +1,10 @@
from functools import reduce
import operator
from typing import Any, Optional, Tuple
import torch
from ..registry import meta_profiler_function
def _prod(dims):
p = 1
for v in dims:
p *= v
return p
@meta_profiler_function.register(torch.arange)
@meta_profiler_function.register(torch.finfo)
@meta_profiler_function.register(torch.permute)
@@ -31,6 +26,7 @@ def _prod(dims):
@meta_profiler_function.register(torch.full)
@meta_profiler_function.register(torch.Tensor.cpu)
@meta_profiler_function.register(torch.Tensor.cuda)
@meta_profiler_function.register(torch._assert)
def torch_zero_flops_op(*args, **kwargs) -> Tuple[int, int]:
flops = 0
macs = 0
@@ -57,7 +53,7 @@ def torch_max(input: torch.Tensor,
if dim is not None:
shape = list(input.shape)
shape.pop(int(dim))
flops = _prod(shape), macs
flops = reduce(operator.mul, shape)
return flops, macs
else:
flops = input.numel()

@@ -1,7 +1,10 @@
from .activation_function import *
from .attention import *
from .convolution import *
from .dropout import *
from .embedding import *
from .linear import *
from .normalization import *
from .pooling import *
from .rnn import *
from .torch_op import *

@@ -12,6 +12,8 @@ _multiplier = {
torch.nn.ELU: 4,
torch.nn.ReLU6: 2,
torch.nn.GELU: 9,
torch.nn.Hardswish: 5,
torch.nn.Hardsigmoid: 4,
}
@@ -23,6 +25,8 @@ _multiplier = {
@meta_profiler_module.register(torch.nn.Tanh)
@meta_profiler_module.register(torch.nn.ReLU6)
@meta_profiler_module.register(torch.nn.PReLU)
@meta_profiler_module.register(torch.nn.Hardswish)
@meta_profiler_module.register(torch.nn.Hardsigmoid)
def torch_nn_non_linear_act(self: torch.nn.Module, input: torch.Tensor) -> Tuple[int, int]:
flops = input.numel()
macs = 0

@@ -0,0 +1,81 @@
from typing import Optional, Tuple
import torch
from ..registry import meta_profiler_module
# TODO: the memory cost of this module is hard to estimate
@meta_profiler_module.register(torch.nn.MultiheadAttention)
def torch_nn_msa(self: torch.nn.MultiheadAttention,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
key_padding_mask: Optional[torch.Tensor] = None,
need_weights: bool = True,
attn_mask: Optional[torch.Tensor] = None,
average_attn_weights: bool = True) -> Tuple[int, int]:
if getattr(self, 'batch_first', False):
batch_size = query.shape[0]
len_idx = 1
else:
batch_size = query.shape[1]
len_idx = 0
dim_idx = 2
qdim = query.shape[dim_idx]
kdim = key.shape[dim_idx]
vdim = value.shape[dim_idx]
qlen = query.shape[len_idx]
klen = key.shape[len_idx]
vlen = value.shape[len_idx]
num_heads = self.num_heads
assert qdim == self.embed_dim
if self.kdim is None:
assert kdim == qdim
if self.vdim is None:
assert vdim == qdim
flops = 0
macs = 0
# Q scaling
flops += qlen * qdim
# Initial projections
flops += 2 * ((qlen * qdim * qdim) # QW
+ (klen * kdim * kdim) # KW
+ (vlen * vdim * vdim) # VW
)
macs += ((qlen * qdim * qdim) # QW
+ (klen * kdim * kdim) # KW
+ (vlen * vdim * vdim) # VW
)
if self.in_proj_bias is not None:
flops += (qlen + klen + vlen) * qdim
# attention heads: scale, matmul, softmax, matmul
qk_head_dim = qdim // num_heads
v_head_dim = vdim // num_heads
head_flops = (
2 * (qlen * klen * qk_head_dim) # QK^T
+ (qlen * klen) # softmax
+ 2 * (qlen * klen * v_head_dim) # AV
)
head_macs = ((qlen * klen * qk_head_dim) # QK^T
+ 2 * (qlen * klen * v_head_dim) # AV
)
flops += num_heads * head_flops
macs += num_heads * head_macs
# final projection, bias is always enabled
flops += qlen * vdim * (vdim + 1)
flops *= batch_size
macs *= batch_size
return flops, macs
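A hedged usage sketch for the patch above (the shapes are arbitrary; self-attention, so query, key, and value coincide):

    import torch

    mha = torch.nn.MultiheadAttention(embed_dim=512, num_heads=8)
    q = torch.rand(16, 2, 512, device='meta')    # (seq_len, batch, embed_dim), batch_first=False
    flops, macs = torch_nn_msa(mha, q, q, q)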

@@ -1,16 +1,11 @@
import operator
from functools import reduce
import math
from typing import Tuple
import torch
from ..registry import meta_profiler_module
def _prod(dims):
p = 1
for v in dims:
p *= v
return p
@meta_profiler_module.register(torch.nn.Conv1d)
def torch_nn_conv1d(self: torch.nn.Conv1d, input: torch.Tensor) -> Tuple[int, int]:
# the output shape is calculated using the formula stated
@@ -23,8 +18,8 @@ def torch_nn_conv1d(self: torch.nn.Conv1d, input: torch.Tensor) -> Tuple[int, int]:
c_out,
l_out,
)
macs_per_elem = _prod(self.kernel_size) * c_in // self.groups
num_elem = _prod(result_shape)
macs_per_elem = reduce(operator.mul, self.kernel_size) * c_in // self.groups
num_elem = reduce(operator.mul, result_shape)
macs = macs_per_elem * num_elem
flops = 2 * macs
if self.bias is not None:
@@ -47,8 +42,8 @@ def torch_nn_conv2d(self: torch.nn.Conv2d, input: torch.Tensor) -> Tuple[int, int]:
h_out,
w_out,
)
macs_per_elem = _prod(self.kernel_size) * c_in // self.groups
num_elem = _prod(result_shape)
macs_per_elem = reduce(operator.mul, self.kernel_size) * c_in // self.groups
num_elem = reduce(operator.mul, result_shape)
macs = macs_per_elem * num_elem
flops = 2 * macs
if self.bias is not None:
@@ -74,8 +69,8 @@ def torch_nn_conv3d(self: torch.nn.Conv3d, input: torch.Tensor) -> Tuple[int, int]:
h_out,
w_out,
)
macs_per_elem = _prod(self.kernel_size) * c_in // self.groups
num_elem = _prod(result_shape)
macs_per_elem = reduce(operator.mul, self.kernel_size) * c_in // self.groups
num_elem = reduce(operator.mul, result_shape)
macs = macs_per_elem * num_elem
flops = 2 * macs
if self.bias is not None:
@@ -95,14 +90,14 @@ def torch_nn_convtranspose1d(self: torch.nn.ConvTranspose1d, input: torch.Tensor) -> Tuple[int, int]:
c_out,
l_out,
)
macs_per_elem = _prod(self.kernel_size) * c_in // self.groups
num_elem = _prod(input.shape)
macs_per_elem = reduce(operator.mul, self.kernel_size) * c_in // self.groups
num_elem = reduce(operator.mul, input.shape)    # see https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/profiling/flops_profiler/profiler.py#L604
macs = macs_per_elem * num_elem
flops = 2 * macs
if self.bias is not None:
flops += _prod(result_shape)
flops += reduce(operator.mul, result_shape)
return flops, macs
@@ -121,12 +116,12 @@ def torch_nn_convtranspose2d(self: torch.nn.ConvTranspose2d, input: torch.Tensor) -> Tuple[int, int]:
h_out,
w_out,
)
macs_per_elem = _prod(self.kernel_size) * c_in // self.groups
num_elem = _prod(input.shape)
macs_per_elem = reduce(operator.mul, self.kernel_size) * c_in // self.groups
num_elem = reduce(operator.mul, input.shape)
macs = macs_per_elem * num_elem
flops = 2 * macs
if self.bias is not None:
flops += _prod(result_shape)
flops += reduce(operator.mul, result_shape)
return flops, macs
@@ -148,10 +143,10 @@ def torch_nn_convtranspose3d(self: torch.nn.ConvTranspose3d, input: torch.Tensor) -> Tuple[int, int]:
h_out,
w_out,
)
macs_per_elem = _prod(self.kernel_size) * c_in // self.groups
num_elem = _prod(input.shape)
macs_per_elem = reduce(operator.mul, self.kernel_size) * c_in // self.groups
num_elem = reduce(operator.mul, input.shape)
macs = macs_per_elem * num_elem
flops = 2 * macs
if self.bias is not None:
flops += _prod(result_shape)
flops += reduce(operator.mul, result_shape)
return flops, macs
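As a worked check of the formula (matching the Conv2d example quoted in the profile_module docstring earlier in this diff):

    # Conv2d(3, 128, 3) on input (4, 3, 224, 224), stride 1, no padding:
    #   macs_per_elem = 3 * 3 * 3 = 27
    #   output shape = (4, 128, 222, 222) -> num_elem = 25,233,408
    #   macs  = 27 * 25,233,408 = 681,302,016
    #   flops = 2 * macs + bias (= num_elem) = 1,362,604,032 + 25,233,408 = 1,387,837,440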

@@ -0,0 +1,11 @@
from typing import Tuple
import torch
from ..registry import meta_profiler_module
@meta_profiler_module.register(torch.nn.Dropout)
def torch_nn_dropout(self: torch.nn.Module, input: torch.Tensor) -> Tuple[int, int]:
# dropout merely masks (and rescales) its input, so it contributes 0 FLOPs and 0 MACs
flops = 0
macs = 0
return flops, macs

@@ -4,9 +4,10 @@ from ..registry import meta_profiler_module
@meta_profiler_module.register(torch.nn.Linear)
@meta_profiler_module.register(torch.nn.modules.linear.NonDynamicallyQuantizableLinear)
def torch_nn_linear(self: torch.nn.Linear, input: torch.Tensor) -> Tuple[int, int]:
out_features = self.weight.shape[0]
macs = torch.numel(input) * out_features
macs = input.numel() * out_features
flops = 2 * macs
if self.bias is not None:
flops += self.bias.numel()

@@ -1,13 +1,75 @@
from functools import reduce
import operator
import torch
from ..registry import meta_profiler_module
from typing import Optional, Tuple
from typing import Optional, Tuple, Union
# TODO: calculate rnn FLOPs
def _rnn_flops(flops: int, macs: int, module: torch.nn.RNNBase, w_ih: torch.Tensor,
w_hh: torch.Tensor) -> Tuple[int, int]:
# copied from https://github.com/sovrasov/flops-counter.pytorch/blob/master/ptflops/pytorch_ops.py
# matrix matrix mult ih state and internal state
macs += reduce(operator.mul, w_ih.shape)
flops += 2 * reduce(operator.mul, w_ih.shape)
# matrix matrix mult hh state and internal state
macs += reduce(operator.mul, w_hh.shape)
flops += 2 * reduce(operator.mul, w_hh.shape)
if isinstance(module, (torch.nn.RNN, torch.nn.RNNCell)):
# add both operations
flops += module.hidden_size
elif isinstance(module, (torch.nn.GRU, torch.nn.GRUCell)):
# hadamard of r
flops += module.hidden_size
# adding operations from both states
flops += module.hidden_size * 3
# last two hadamard product and add
flops += module.hidden_size * 3
elif isinstance(module, (torch.nn.LSTM, torch.nn.LSTMCell)):
# adding operations from both states
flops += module.hidden_size * 4
# two hadamard product and add for C state
flops += module.hidden_size * 3
# final hadamard
flops += module.hidden_size * 3
return flops, macs
@meta_profiler_module.register(torch.nn.LSTM)
@meta_profiler_module.register(torch.nn.GRU)
@meta_profiler_module.register(torch.nn.RNN)
def torch_nn_rnn(self: torch.nn.Module, input: torch.Tensor, hx: torch.Tensor) -> Tuple[int, int]:
raise NotImplementedError
def torch_nn_rnn(self: torch.nn.RNNBase, input: torch.Tensor, hx: Optional[torch.Tensor] = None) -> Tuple[int, int]:
flops = 0
macs = 0
for i in range(self.num_layers):
w_ih = self.__getattr__('weight_ih_l' + str(i))
w_hh = self.__getattr__('weight_hh_l' + str(i))
flops, macs = _rnn_flops(flops, macs, self, w_ih, w_hh)
if self.bias:
b_ih = self.__getattr__('bias_ih_l' + str(i))
b_hh = self.__getattr__('bias_hh_l' + str(i))
flops += reduce(operator.mul, b_ih.shape) + reduce(operator.mul, b_hh.shape)
flops *= reduce(operator.mul, input.shape[:2])
macs *= reduce(operator.mul, input.shape[:2])
if self.bidirectional:
flops *= 2
macs *= 2
return flops, macs
@meta_profiler_module.register(torch.nn.LSTMCell)
@meta_profiler_module.register(torch.nn.GRUCell)
@meta_profiler_module.register(torch.nn.RNNCell)
def torch_nn_rnn_cell(self: torch.nn.RNNCellBase, input: torch.Tensor, hx: Optional[torch.Tensor] = None) -> Tuple[int, int]:
flops = 0
macs = 0
w_ih = self.__getattr__('weight_ih')
w_hh = self.__getattr__('weight_hh')
flops, macs = _rnn_flops(flops, macs, self, w_ih, w_hh)
if self.bias:
b_ih = self.__getattr__('bias_ih')
b_hh = self.__getattr__('bias_hh')
flops += reduce(operator.mul, b_ih.shape) + reduce(operator.mul, b_hh.shape)
flops *= input.shape[0]
macs *= input.shape[0]
return flops, macs
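A hedged usage sketch for the multi-layer RNN patch (arbitrary shapes; sequence-first layout since batch_first defaults to False):

    import torch

    lstm = torch.nn.LSTM(input_size=64, hidden_size=128, num_layers=2, bias=True)
    x = torch.rand(10, 4, 64, device='meta')    # (seq_len, batch, input_size)
    flops, macs = torch_nn_rnn(lstm, x)          # per-cell cost scaled by seq_len * batch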

@@ -0,0 +1,11 @@
import operator
import torch
from ..registry import meta_profiler_module
from typing import Optional, Tuple, Union
@meta_profiler_module.register(torch.nn.Flatten)
def torch_nn_flatten(self: torch.nn.Flatten, input: torch.Tensor) -> Tuple[int, int]:
flops = 0
macs = 0
return flops, macs