mirror of https://github.com/hpcaitech/ColossalAI
[fx] add more op patches for profiler and error message for unsupported ops. (#1495)
* [fx] modify the calculation of node_size in MetaInfoProp for activation checkpointing usages
* [fx] merge development into main (#1)
* [fx] activation checkpointing using Chen strategies.
* [fx] add test for ckpt_solver_chen
* [fx] add vanilla activation checkpoint search with test on resnet and densenet
* [fx] add a namespace code for solver_chen.
* [fx] fix the false interpretation of algorithm 3 in https://arxiv.org/abs/1604.06174.
* [fx] fix lowercase naming conventions.
* [fx] simplify test for ckpt.
* [fx] add rules to linearize computation graphs for searching. (#2)
* [fx] fix test and algorithm bugs in activation checkpointing.
* [fx] polish ckpt_test.
* [fx] remove chen_sqrt for sake of simplicity
* [fx] fix inconsistencies.
* [fx] fix MetaInfoProp.
* [fx] consider MetaInfoProp for inplace operands.
* [fx] add profiler for fx nodes.
* [fx] fix error in tests.
* [fx] unfix bug.
* [fx] patch more modules and functions.
* [fx] change name of utils.py to profiler.py
* [fx] add profiler for rnn.
* [fx] polish and add more patch for profiler.
parent 413c053453
commit 09c023bee2
@@ -1,4 +1,4 @@
 from .registry import *
 from .profiler_function import *
 from .profiler_module import *
-from .utils import *
+from .profiler import *
@@ -1,8 +1,8 @@
 from functools import partial
 from operator import add, floordiv, getitem, mul, neg, setitem, sub, pos
-from typing import Callable, NamedTuple, Any, Dict, Tuple
+from typing import Callable, List, NamedTuple, Any, Dict, Tuple, Union
 import torch
-from torch.fx.node import Argument, Target
+from torch.fx.node import Argument, Target, map_aggregate
 from torch.fx._compatibility import compatibility
 from colossalai.fx.tracer.meta_patch import meta_patched_function, meta_patched_module
 from . import meta_profiler_function, meta_profiler_module
@@ -12,6 +12,30 @@ __all__ = [
     'calculate_param_size'
 ]
 
+CALL_FUNCTION_MSG = \
+"""
+Colossal-AI hasn't supported profiling for {}, you might manually patch it with the following code.\n
+from colossalai.fx.profiler import meta_profiler_function
+
+@meta_profiler_function.register(YOUR_FUNCTION)
+def profile_YOUR_FUNCTION(input: torch.Tensor, *args) -> Tuple[int, int]:
+    flops = ...
+    macs = ...
+    return flops, macs
+"""
+CALL_METHOD_MSG = 'Please check if {} is an inplace method. If so, add target to INPLACE_METHOD={}. Otherwise, add target to NON_INPLACE_METHOD={}'
+CALL_MODULE_MSG = \
+"""
+Colossal-AI hasn't supported profiling for {}, you might manually patch it with the following code.\n
+from colossalai.fx.profiler import meta_profiler_module
+
+@meta_profiler_module.register(YOUR_MODULE)
+def profile_YOUR_MODULE(self: torch.nn.Module, input: torch.Tensor) -> Tuple[int, int]:
+    flops = ...
+    macs = ...
+    return flops, macs
+"""
+
 # TODO fill out the inplace ops
 INPLACE_OPS = [
     add,
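The two messages above double as documentation for the profiler's extension point. A minimal sketch of such a patch, assuming torch.nn.functional.softmax were still unsupported (the per-element cost here is an illustrative estimate, not an official figure):

    from typing import Tuple
    import torch
    from colossalai.fx.profiler import meta_profiler_function

    @meta_profiler_function.register(torch.nn.functional.softmax)
    def profile_softmax(input: torch.Tensor, *args, **kwargs) -> Tuple[int, int]:
        # rough estimate: one exp, one sum contribution and one division per element
        flops = 3 * input.numel()
        macs = 0    # no multiply-accumulate pairs in this approximation
        return flops, macs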
@@ -22,18 +46,30 @@ INPLACE_OPS = [
     pos,
     getitem,
     setitem,
+    getattr,
+    torch.Tensor.cpu,
 ]
 
-# TODO check that call_methods are indeed inplace
+# TODO: list all call_methods that are inplace here
 INPLACE_METHOD = [
     'transpose',
     'permute',
+    # TODO: reshape may return a copy of the data if the data is not contiguous
     'reshape',
+    'dim',
+    'flatten',
 ]
 
+# TODO: list all call_methods that are not inplace here
+NON_INPLACE_METHOD = [
+    'expand',
+    'mean',
+]
+
 
 @compatibility(is_backward_compatible=True)
 class MetaProfile(NamedTuple):
 
     # MetaProfile is a structure containing pertinent information
     # about a node within a torch.fx GraphModule.
@@ -43,9 +79,14 @@ class MetaProfile(NamedTuple):
     macs: int
 
 
-def calculate_activation_size(activation: any) -> int:
-    """
-    Calculate activation size of a node.
+def calculate_activation_size(activation: Union[torch.Tensor, Dict, List, Tuple, int]) -> int:
+    """Calculate activation size of a node.
+
+    Args:
+        activation (Union[torch.Tensor, Dict, List, Tuple, int]): The activation of a `torch.nn.Module` or `torch.nn.functional`
+
+    Returns:
+        int: The activation size
     """
     activation_size = 0
     if isinstance(activation, torch.Tensor):
@@ -53,15 +94,20 @@ def calculate_activation_size(activation: any) -> int:
     elif isinstance(activation, dict):
         value_list = [v for _, v in activation.items()]
         activation_size += calculate_activation_size(value_list)
-    else:
+    elif isinstance(activation, tuple) or isinstance(activation, list):
         for element in activation:
             activation_size += calculate_activation_size(element)
     return activation_size
 
 
 def calculate_param_size(mod: torch.nn.Module) -> int:
-    """
-    Calculate param size of a node.
+    """Calculate param size of a node.
+
+    Args:
+        mod (torch.nn.Module): The target `torch.nn.Module`
+
+    Returns:
+        int: The param size
     """
     param_size = 0
     for param in mod.parameters():
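For intuition, a quick check of the recursion above, assuming the torch.Tensor branch counts numel() * element_size() bytes and that calculate_activation_size is importable from the profiler package:

    import torch
    from colossalai.fx.profiler import calculate_activation_size    # assumed export

    t = torch.rand(4, 3, 224, 224, device='meta')    # float32: 4 bytes per element
    print(calculate_activation_size(t))                        # 4*3*224*224*4 = 2408448
    print(calculate_activation_size((t, [t], {'out': t})))     # nested containers: 3 * 2408448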
@@ -78,17 +124,21 @@ def profile_function(target: 'Target') -> Callable:
     You may only use tensors with `device=meta` for this wrapped function.
     Only original `torch.nn.functional` are available.
 
-    Usage:
-        input = torch.rand(100, 100, 100, 100, device='meta')
-        func = torch.nn.functional.relu
-        output, profile = profile_function(func)(input, inplace=False)
-        print(f"Profiling function {func},")
-        print(f"Param size: {profile.param / 1024**2:.3f} MB, Activation size: {profile.activation / 1024**2:.3f} MB, {profile.flops} FLOPs, {profile.macs} MACs")
+    Examples:
+        >> input = torch.rand(100, 100, 100, 100, device='meta')
+        >> func = torch.nn.functional.relu
+        >> output, profile = profile_function(func)(input, inplace=False)
+        >> print(f"Profiling function {func},")
+        >> print(f"Param size: {profile.param / 1024**2:.3f} MB, Activation size: {profile.activation / 1024**2:.3f} MB, {profile.flops} FLOPs, {profile.macs} MACs")
+        Profiling function <function relu at 0x7fcdd0258d30>,
+        Param size: 0.000 MB, Activation size: 381.470 MB, 100000000 FLOPs, 0 MACs
     """
 
     def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
         assert meta_profiler_function.has(target) or meta_profiler_function.has(
-            target.__name__), f"Colossal-AI hasn't supported profiling for {target}, you might manually patch it."
+            target.__name__), CALL_FUNCTION_MSG.format(target)
+        # ensure all arguments satisfy `device='meta'`
+        args, kwargs = map_aggregate([args, kwargs], lambda a: a.to('meta') if isinstance(a, torch.Tensor) else a)
 
         # call_function has no parameters
         param_size = 0
@@ -127,14 +177,17 @@ def profile_method(target: 'Target') -> Callable:
         # args[0] is the `self` object for this method call
         self_obj, *args_tail = args
 
-        # Execute the method and return the result
+        # execute the method and return the result
         assert isinstance(target, str), f'{target} instance is not str.'
-        result = getattr(self_obj, target)(*args_tail, **kwargs)
-        assert target in INPLACE_METHOD, f'Please check {target} is an inplace method. If so, add target to INPLACE_METHOD={INPLACE_METHOD}.'
 
+        # ensure all arguments satisfy `device='meta'`
+        map_aggregate([args, kwargs], lambda a: a.to('meta') if isinstance(a, torch.Tensor) else a)
+        result = getattr(self_obj, target)(*args_tail, **kwargs)
+        assert target in INPLACE_METHOD + NON_INPLACE_METHOD, CALL_METHOD_MSG.format(
+            target, INPLACE_METHOD, NON_INPLACE_METHOD)
         # call_method has no parameters and is MOSTLY(?) inplace, and has no FLOPs or MACs.
         param_size = 0
-        activation_size = 0
+        activation_size = 0 if target in INPLACE_METHOD else calculate_activation_size(result)
         flops = 0
         macs = 0
         return result, MetaProfile(param_size, activation_size, flops, macs)
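A sketch of what this classification means in practice: a method listed in INPLACE_METHOD is charged no activation memory, while one in NON_INPLACE_METHOD is charged the size of its result. The import path below is an assumption following the utils.py to profiler.py rename:

    import torch
    from colossalai.fx.profiler.profiler import profile_method    # assumed module path

    x = torch.rand(8, 128, device='meta')
    _, prof = profile_method('transpose')(x, 0, 1)    # in INPLACE_METHOD -> activation = 0
    _, prof2 = profile_method('mean')(x, 1)           # in NON_INPLACE_METHOD -> 8 floats = 32 bytes
    print(prof.activation, prof2.activation)          # 0 32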
@@ -151,17 +204,20 @@ def profile_module(module: torch.nn.Module) -> Callable:
     You may only use tensors with `device=meta` for this wrapped function.
     Only original `torch.nn` are available.
 
-    Usage:
-        input = torch.rand(4, 3, 224, 224, device='meta')
-        mod = torch.nn.Conv2d(3, 128, 3)
-        output, profile = profile_module(mod)(input)
-        print(f"Profiling function {mod},")
-        print(f"Param size: {profile.param / 1024**2:.3f} MB, Activation size: {profile.activation / 1024**2:.3f} MB, {profile.flops} FLOPs, {profile.macs} MACs")
+    Example:
+        >> input = torch.rand(4, 3, 224, 224, device='meta')
+        >> mod = torch.nn.Conv2d(3, 128, 3)
+        >> output, profile = profile_module(mod)(input)
+        >> print(f"Profiling function {mod},")
+        >> print(f"Param size: {profile.param / 1024**2:.3f} MB, Activation size: {profile.activation / 1024**2:.3f} MB, {profile.flops} FLOPs, {profile.macs} MACs")
+        Profiling function Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1)),
+        Param size: 0.014 MB, Activation size: 96.258 MB, 1387837440 FLOPs, 681302016 MACs
     """
 
     def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
-        assert meta_profiler_module.has(
-            type(module)), f"Colossal-AI hasn't supported profiling for {module}, you might manually patch it."
+        assert meta_profiler_module.has(type(module)), CALL_MODULE_MSG.format(type(module))
+        # ensure all arguments satisfy `device='meta'`
+        map_aggregate([args, kwargs], lambda a: a.to('meta') if isinstance(a, torch.Tensor) else a)
         param_size = calculate_param_size(module)
         activation_size = 0
         result = func(*args, **kwargs)
@@ -12,6 +12,8 @@ _multiplier = {
     torch.nn.functional.elu: 4,
     torch.nn.functional.relu6: 2,
     torch.nn.functional.gelu: 9,
+    torch.nn.functional.hardswish: 5,
+    torch.nn.functional.hardsigmoid: 4,
 }
 
 
@@ -23,6 +25,8 @@ _multiplier = {
 @meta_profiler_function.register(torch.nn.functional.relu)
 @meta_profiler_function.register(torch.nn.functional.sigmoid)
 @meta_profiler_function.register(torch.nn.functional.tanh)
+@meta_profiler_function.register(torch.nn.functional.hardswish)
+@meta_profiler_function.register(torch.nn.functional.hardsigmoid)
 def torch_nn_func_non_linear_act(input: torch.Tensor, inplace: bool = False) -> Tuple[int, int]:
     flops = input.numel()
     macs = 0
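Under the reading that the _multiplier table above records an assumed per-element flop cost for each activation, the new entries imply, for example:

    import torch

    input = torch.rand(32, 3, 224, 224)
    flops_hardswish = 5 * input.numel()      # 5 * 4816896 = 24084480 flops
    flops_hardsigmoid = 4 * input.numel()    # 4 * 4816896 = 19267584 flops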
@@ -1,24 +1,19 @@
+import operator
+from functools import reduce
 from typing import Any, Optional, Tuple, Union
 import torch
 from ..registry import meta_profiler_function
 
 
-def _prod(dims):
-    p = 1
-    for v in dims:
-        p *= v
-    return p
-
-
 def _elementwise_flops_compute(input, other):
     # copied from https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/profiling/flops_profiler/profiler.py#L763
     if not torch.is_tensor(input):
         if torch.is_tensor(other):
-            return _prod(other.shape), 0
+            return reduce(operator.mul, other.shape), 0
         else:
             return 1, 0
     elif not torch.is_tensor(other):
-        return _prod(input.shape), 0
+        return reduce(operator.mul, input.shape), 0
     else:
         dim_input = len(input.shape)
         dim_other = len(other.shape)
@@ -32,17 +27,24 @@ def _elementwise_flops_compute(input, other):
             final_shape.append(in_i)
         else:
             final_shape.append(ot_i)
-    flops = _prod(final_shape)
+    flops = reduce(operator.mul, final_shape)
     return flops, 0
 
 
 @meta_profiler_function.register(torch.add)
+@meta_profiler_function.register(torch.eq)
 @meta_profiler_function.register(torch.sub)
 @meta_profiler_function.register(torch.mul)
 @meta_profiler_function.register(torch.floor_divide)
+@meta_profiler_function.register('add')    # for built-in op +
+@meta_profiler_function.register('iadd')    # for built-in op +=
+@meta_profiler_function.register('eq')    # for built-in op ==
+@meta_profiler_function.register('sub')    # for built-in op -
+@meta_profiler_function.register('isub')    # for built-in op -=
+@meta_profiler_function.register('mul')    # for built-in op *
+@meta_profiler_function.register('imul')    # for built-in op *=
+@meta_profiler_function.register('floordiv')    # for built-in op //
+@meta_profiler_function.register('ifloordiv')    # for built-in op //=
 def torch_add_like_ops(input: Any, other: Any, *, out: Optional[torch.Tensor] = None) -> Tuple[int, int]:
     return _elementwise_flops_compute(input, other)
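_elementwise_flops_compute charges one flop per element of the broadcast result. A worked check of the final-shape logic under standard broadcasting rules:

    import torch

    a = torch.rand(4, 1, 5)
    b = torch.rand(3, 5)
    # shapes broadcast to (4, 3, 5), so the rule reports 4*3*5 = 60 flops and 0 macs
    print((a + b).shape)    # torch.Size([4, 3, 5])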
@@ -58,14 +60,14 @@ def torch_elementwise_op(input: torch.Tensor, *, out: Optional[torch.Tensor] = None) -> Tuple[int, int]:
 @meta_profiler_function.register('matmul')    # for built-in op @
 @meta_profiler_function.register(torch.Tensor.matmul)
 def torch_matmul(input: torch.Tensor, other: torch.Tensor, *, out: Optional[torch.Tensor] = None) -> Tuple[int, int]:
-    macs = _prod(input.shape) * other.shape[-1]
+    macs = reduce(operator.mul, input.shape) * other.shape[-1]
     flops = 2 * macs
     return flops, macs
 
 
 @meta_profiler_function.register(torch.bmm)
 def torch_bmm(input: torch.Tensor, other: torch.Tensor, *, out: Optional[torch.Tensor] = None) -> Tuple[int, int]:
-    macs = _prod(input.shape) * other.shape[-1]
+    macs = reduce(operator.mul, input.shape) * other.shape[-1]
     flops = 2 * macs
     return flops, macs
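The matmul rule counts one MAC per multiply-accumulate pair: for input of shape (B, n, m) against other of shape (..., m, p), macs = B*n*m*p and flops = 2*macs. A worked instance of the same formula:

    from functools import reduce
    import operator

    input_shape, other_shape = (8, 64, 128), (128, 256)
    macs = reduce(operator.mul, input_shape) * other_shape[-1]    # 8*64*128*256 = 16777216
    flops = 2 * macs                                              # 33554432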
|
|
@ -10,3 +10,10 @@ def operator_getitem(a: Any, b: Any) -> Tuple[int, int]:
|
|||
flops = 0
|
||||
macs = 0
|
||||
return flops, macs
|
||||
|
||||
|
||||
@meta_profiler_function.register(getattr)
|
||||
def python_getattr(a: Any, b: Any) -> Tuple[int, int]:
|
||||
flops = 0
|
||||
macs = 0
|
||||
return flops, macs
|
||||
|
|
|
@@ -1,15 +1,10 @@
+from functools import reduce
+import operator
 from typing import Any, Optional, Tuple
 import torch
 from ..registry import meta_profiler_function
 
 
-def _prod(dims):
-    p = 1
-    for v in dims:
-        p *= v
-    return p
-
-
 @meta_profiler_function.register(torch.arange)
 @meta_profiler_function.register(torch.finfo)
 @meta_profiler_function.register(torch.permute)
@@ -31,6 +26,7 @@ def _prod(dims):
 @meta_profiler_function.register(torch.full)
+@meta_profiler_function.register(torch.Tensor.cpu)
 @meta_profiler_function.register(torch.Tensor.cuda)
 @meta_profiler_function.register(torch._assert)
 def torch_zero_flops_op(*args, **kwargs) -> Tuple[int, int]:
     flops = 0
     macs = 0
@@ -57,7 +53,7 @@ def torch_max(input: torch.Tensor,
     if dim is not None:
         shape = list(input.shape)
         shape.pop(int(dim))
-        flops = _prod(shape), macs
+        flops = reduce(operator.mul, shape)
         return flops, macs
     else:
         flops = input.numel()
@@ -1,7 +1,10 @@
 from .activation_function import *
+from .attention import *
 from .convolution import *
+from .dropout import *
 from .embedding import *
 from .linear import *
 from .normalization import *
 from .pooling import *
 from .rnn import *
+from .torch_op import *
@@ -12,6 +12,8 @@ _multiplier = {
     torch.nn.ELU: 4,
     torch.nn.ReLU6: 2,
     torch.nn.GELU: 9,
+    torch.nn.Hardswish: 5,
+    torch.nn.Hardsigmoid: 4,
 }
 
 
@@ -23,6 +25,8 @@ _multiplier = {
 @meta_profiler_module.register(torch.nn.Tanh)
 @meta_profiler_module.register(torch.nn.ReLU6)
 @meta_profiler_module.register(torch.nn.PReLU)
+@meta_profiler_module.register(torch.nn.Hardswish)
+@meta_profiler_module.register(torch.nn.Hardsigmoid)
 def torch_nn_non_linear_act(self: torch.nn.Module, input: torch.Tensor) -> Tuple[int, int]:
     flops = input.numel()
     macs = 0
@@ -0,0 +1,81 @@
+from typing import Optional, Tuple
+import torch
+from ..registry import meta_profiler_module
+
+
+# TODO: This is hard to compute memory cost
+@meta_profiler_module.register(torch.nn.MultiheadAttention)
+def torch_nn_msa(self: torch.nn.MultiheadAttention,
+                 query: torch.Tensor,
+                 key: torch.Tensor,
+                 value: torch.Tensor,
+                 key_padding_mask: Optional[torch.Tensor] = None,
+                 need_weights: bool = True,
+                 attn_mask: Optional[torch.Tensor] = None,
+                 average_attn_weights: bool = True) -> Tuple[int, int]:
+    if getattr(self, 'batch_first', False):
+        batch_size = query.shape[0]
+        len_idx = 1
+    else:
+        batch_size = query.shape[1]
+        len_idx = 0
+    dim_idx = 2
+
+    qdim = query.shape[dim_idx]
+    kdim = key.shape[dim_idx]
+    vdim = value.shape[dim_idx]
+
+    qlen = query.shape[len_idx]
+    klen = key.shape[len_idx]
+    vlen = value.shape[len_idx]
+
+    num_heads = self.num_heads
+    assert qdim == self.embed_dim
+
+    if self.kdim is None:
+        assert kdim == qdim
+    if self.vdim is None:
+        assert vdim == qdim
+
+    flops = 0
+    macs = 0
+
+    # Q scaling
+    flops += qlen * qdim
+
+    # Initial projections
+    flops += 2 * ((qlen * qdim * qdim)    # QW
+                  + (klen * kdim * kdim)    # KW
+                  + (vlen * vdim * vdim)    # VW
+                 )
+
+    macs += ((qlen * qdim * qdim)    # QW
+             + (klen * kdim * kdim)    # KW
+             + (vlen * vdim * vdim)    # VW
+            )
+
+    if self.in_proj_bias is not None:
+        flops += (qlen + klen + vlen) * qdim
+
+    # attention heads: scale, matmul, softmax, matmul
+    qk_head_dim = qdim // num_heads
+    v_head_dim = vdim // num_heads
+
+    head_flops = (
+        2 * (qlen * klen * qk_head_dim)    # QK^T
+        + (qlen * klen)    # softmax
+        + 2 * (qlen * klen * v_head_dim)    # AV
+    )
+    head_macs = ((qlen * klen * qk_head_dim)    # QK^T
+                 + 2 * (qlen * klen * v_head_dim)    # AV
+                )
+
+    flops += num_heads * head_flops
+    macs += num_heads * head_macs
+
+    # final projection, bias is always enabled
+    flops += qlen * vdim * (vdim + 1)
+
+    flops *= batch_size
+    macs *= batch_size
+    return flops, macs
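To sanity-check the attention arithmetic above, the same counting can be reproduced standalone for hypothetical sizes (embed_dim 256, 8 heads, sequence length 196, self-attention so all three projections share one dimension):

    embed_dim, num_heads, seq_len = 256, 8, 196
    head_dim = embed_dim // num_heads

    proj_flops = 2 * 3 * (seq_len * embed_dim * embed_dim)    # Q, K and V input projections
    head_flops = (2 * (seq_len * seq_len * head_dim)          # QK^T
                  + (seq_len * seq_len)                       # softmax
                  + 2 * (seq_len * seq_len * head_dim))       # AV
    out_flops = seq_len * embed_dim * (embed_dim + 1)         # final projection with bias
    total_flops = proj_flops + num_heads * head_flops + out_flops    # per sample, before Q scaling and input biases
    print(total_flops)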
@@ -1,16 +1,11 @@
+import operator
+from functools import reduce
 import math
 from typing import Tuple
 import torch
 from ..registry import meta_profiler_module
 
 
-def _prod(dims):
-    p = 1
-    for v in dims:
-        p *= v
-    return p
-
-
 @meta_profiler_module.register(torch.nn.Conv1d)
 def torch_nn_conv1d(self: torch.nn.Conv1d, input: torch.Tensor) -> Tuple[int, int]:
     # the output shape is calculated using the formula stated
@@ -23,8 +18,8 @@ def torch_nn_conv1d(self: torch.nn.Conv1d, input: torch.Tensor) -> Tuple[int, int]:
         c_out,
         l_out,
     )
-    macs_per_elem = _prod(self.kernel_size) * c_in // self.groups
-    num_elem = _prod(result_shape)
+    macs_per_elem = reduce(operator.mul, self.kernel_size) * c_in // self.groups
+    num_elem = reduce(operator.mul, result_shape)
     macs = macs_per_elem * num_elem
     flops = 2 * macs
     if self.bias is not None:
@@ -47,8 +42,8 @@ def torch_nn_conv2d(self: torch.nn.Conv2d, input: torch.Tensor) -> Tuple[int, int]:
         h_out,
         w_out,
     )
-    macs_per_elem = _prod(self.kernel_size) * c_in // self.groups
-    num_elem = _prod(result_shape)
+    macs_per_elem = reduce(operator.mul, self.kernel_size) * c_in // self.groups
+    num_elem = reduce(operator.mul, result_shape)
     macs = macs_per_elem * num_elem
     flops = 2 * macs
     if self.bias is not None:
@@ -74,8 +69,8 @@ def torch_nn_conv3d(self: torch.nn.Conv3d, input: torch.Tensor) -> Tuple[int, int]:
         h_out,
         w_out,
     )
-    macs_per_elem = _prod(self.kernel_size) * c_in // self.groups
-    num_elem = _prod(result_shape)
+    macs_per_elem = reduce(operator.mul, self.kernel_size) * c_in // self.groups
+    num_elem = reduce(operator.mul, result_shape)
     macs = macs_per_elem * num_elem
     flops = 2 * macs
     if self.bias is not None:
@@ -95,14 +90,14 @@ def torch_nn_convtranspose1d(self: torch.nn.ConvTranspose1d, input: torch.Tensor) -> Tuple[int, int]:
         c_out,
         l_out,
     )
-    macs_per_elem = _prod(self.kernel_size) * c_in // self.groups
-    num_elem = _prod(
-        input.shape
+    macs_per_elem = reduce(operator.mul, self.kernel_size) * c_in // self.groups
+    num_elem = reduce(
+        operator.mul, input.shape
     )    # see https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/profiling/flops_profiler/profiler.py#L604
     macs = macs_per_elem * num_elem
     flops = 2 * macs
     if self.bias is not None:
-        flops += _prod(result_shape)
+        flops += reduce(operator.mul, result_shape)
     return flops, macs
 
 
@@ -121,12 +116,12 @@ def torch_nn_convtranspose2d(self: torch.nn.ConvTranspose2d, input: torch.Tensor) -> Tuple[int, int]:
         h_out,
         w_out,
     )
-    macs_per_elem = _prod(self.kernel_size) * c_in // self.groups
-    num_elem = _prod(input.shape)
+    macs_per_elem = reduce(operator.mul, self.kernel_size) * c_in // self.groups
+    num_elem = reduce(operator.mul, input.shape)
     macs = macs_per_elem * num_elem
     flops = 2 * macs
     if self.bias is not None:
-        flops += _prod(result_shape)
+        flops += reduce(operator.mul, result_shape)
     return flops, macs
 
 
@@ -148,10 +143,10 @@ def torch_nn_convtranspose3d(self: torch.nn.ConvTranspose3d, input: torch.Tensor) -> Tuple[int, int]:
         h_out,
         w_out,
    )
-    macs_per_elem = _prod(self.kernel_size) * c_in // self.groups
-    num_elem = _prod(input.shape)
+    macs_per_elem = reduce(operator.mul, self.kernel_size) * c_in // self.groups
+    num_elem = reduce(operator.mul, input.shape)
     macs = macs_per_elem * num_elem
     flops = 2 * macs
     if self.bias is not None:
-        flops += _prod(result_shape)
+        flops += reduce(operator.mul, result_shape)
     return flops, macs
@@ -0,0 +1,11 @@
+from typing import Tuple
+import torch
+from ..registry import meta_profiler_module
+
+
+@meta_profiler_module.register(torch.nn.Dropout)
+def torch_nn_dropout(self: torch.nn.Module, input: torch.Tensor) -> Tuple[int, int]:
+    # dropout only performs masking and rescaling, so it is counted as 0 FLOPs
+    flops = 0
+    macs = 0
+    return flops, macs
@@ -4,9 +4,10 @@ from ..registry import meta_profiler_module
 
 
 @meta_profiler_module.register(torch.nn.Linear)
+@meta_profiler_module.register(torch.nn.modules.linear.NonDynamicallyQuantizableLinear)
 def torch_nn_linear(self: torch.nn.Linear, input: torch.Tensor) -> Tuple[int, int]:
     out_features = self.weight.shape[0]
-    macs = torch.numel(input) * out_features
+    macs = input.numel() * out_features
     flops = 2 * macs
     if self.bias is not None:
         flops += self.bias.numel()
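A worked instance of the linear rule: macs = numel(input) * out_features, and flops = 2*macs plus one add per bias element.

    import torch

    mod = torch.nn.Linear(1024, 4096)
    x = torch.rand(8, 1024)
    macs = x.numel() * mod.weight.shape[0]    # 8*1024*4096 = 33554432
    flops = 2 * macs + mod.bias.numel()       # 67108864 + 4096 = 67112960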
@@ -1,13 +1,75 @@
 from functools import reduce
 import operator
 import torch
 from ..registry import meta_profiler_module
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Union
 
 
-# TODO: calculate rnn FLOPs
+def _rnn_flops(flops: int, macs: int, module: torch.nn.RNNBase, w_ih: torch.Tensor,
+               w_hh: torch.Tensor) -> Tuple[int, int]:
+    # copied from https://github.com/sovrasov/flops-counter.pytorch/blob/master/ptflops/pytorch_ops.py
+
+    # matrix matrix mult ih state and internal state
+    macs += reduce(operator.mul, w_ih.shape)
+    flops += 2 * reduce(operator.mul, w_ih.shape)
+    # matrix matrix mult hh state and internal state
+    macs += reduce(operator.mul, w_hh.shape)
+    flops += 2 * reduce(operator.mul, w_hh.shape)
+    if isinstance(module, (torch.nn.RNN, torch.nn.RNNCell)):
+        # add both operations
+        flops += module.hidden_size
+    elif isinstance(module, (torch.nn.GRU, torch.nn.GRUCell)):
+        # hadamard of r
+        flops += module.hidden_size
+        # adding operations from both states
+        flops += module.hidden_size * 3
+        # last two hadamard product and add
+        flops += module.hidden_size * 3
+    elif isinstance(module, (torch.nn.LSTM, torch.nn.LSTMCell)):
+        # adding operations from both states
+        flops += module.hidden_size * 4
+        # two hadamard product and add for C state
+        flops += module.hidden_size * 3
+        # final hadamard
+        flops += module.hidden_size * 3
+    return flops, macs
+
+
 @meta_profiler_module.register(torch.nn.LSTM)
 @meta_profiler_module.register(torch.nn.GRU)
 @meta_profiler_module.register(torch.nn.RNN)
-def torch_nn_rnn(self: torch.nn.Module, input: torch.Tensor, hx: torch.Tensor) -> Tuple[int, int]:
-    raise NotImplementedError
+def torch_nn_rnn(self: torch.nn.RNNBase, input: torch.Tensor, hx: Optional[torch.Tensor] = None) -> Tuple[int, int]:
+    flops = 0
+    macs = 0
+    for i in range(self.num_layers):
+        w_ih = self.__getattr__('weight_ih_l' + str(i))
+        w_hh = self.__getattr__('weight_hh_l' + str(i))
+        flops, macs = _rnn_flops(flops, macs, self, w_ih, w_hh)
+        if self.bias:
+            b_ih = self.__getattr__('bias_ih_l' + str(i))
+            b_hh = self.__getattr__('bias_hh_l' + str(i))
+            flops += reduce(operator.mul, b_ih.shape) + reduce(operator.mul, b_hh.shape)
+    flops *= reduce(operator.mul, input.shape[:2])
+    macs *= reduce(operator.mul, input.shape[:2])
+    if self.bidirectional:
+        flops *= 2
+        macs *= 2
+    return flops, macs
+
+
+@meta_profiler_module.register(torch.nn.LSTMCell)
+@meta_profiler_module.register(torch.nn.GRUCell)
+@meta_profiler_module.register(torch.nn.RNNCell)
+def torch_nn_rnn_cell(self: torch.nn.RNNCellBase, input: torch.Tensor, hx: Optional[torch.Tensor] = None) -> Tuple[int, int]:
+    flops = 0
+    macs = 0
+    w_ih = self.__getattr__('weight_ih')
+    w_hh = self.__getattr__('weight_hh')
+    flops, macs = _rnn_flops(flops, macs, self, w_ih, w_hh)
+    if self.bias:
+        b_ih = self.__getattr__('bias_ih')
+        b_hh = self.__getattr__('bias_hh')
+        flops += reduce(operator.mul, b_ih.shape) + reduce(operator.mul, b_hh.shape)
+    flops *= input.shape[0]
+    macs *= input.shape[0]
+    return flops, macs
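To make the per-gate accounting concrete, a rough standalone check for a single-layer unidirectional LSTM, following the same formula (weights only, bias disabled):

    import torch

    lstm = torch.nn.LSTM(input_size=64, hidden_size=128, bias=False)
    w_ih, w_hh = lstm.weight_ih_l0, lstm.weight_hh_l0     # shapes (512, 64) and (512, 128)
    macs_step = w_ih.numel() + w_hh.numel()               # 32768 + 65536 = 98304
    flops_step = 2 * macs_step + 10 * lstm.hidden_size    # matmuls plus the 4+3+3 per-gate elementwise terms
    seq_len, batch = 10, 4
    print(flops_step * seq_len * batch, macs_step * seq_len * batch)    # 7915520 3932160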
@@ -0,0 +1,11 @@
+import operator
+import torch
+from ..registry import meta_profiler_module
+from typing import Optional, Tuple, Union
+
+
+@meta_profiler_module.register(torch.nn.Flatten)
+def torch_nn_flatten(self: torch.nn.Flatten, input: torch.Tensor) -> Tuple[int, int]:
+    flops = 0
+    macs = 0
+    return flops, macs