|
|
|
@ -70,6 +70,19 @@ def bmm_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number:
|
|
|
|
|
return flops
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def baddbmm_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number:
|
|
|
|
|
"""
|
|
|
|
|
Count flops for the baddbmm(batch add and batch matmul) operation.
|
|
|
|
|
"""
|
|
|
|
|
# Inputs = [input, batch1, batch2]
|
|
|
|
|
# out = input + batch1 x batch2
|
|
|
|
|
assert len(inputs) == 3, len(inputs)
|
|
|
|
|
n, c, t = inputs[1].shape
|
|
|
|
|
d = inputs[2].shape[-1]
|
|
|
|
|
flops = n * c * t * d
|
|
|
|
|
return flops
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def conv_flop_count(
|
|
|
|
|
x_shape: List[int],
|
|
|
|
|
w_shape: List[int],
|
|
|
|
@ -196,6 +209,7 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
|
|
|
|
|
aten.matmul.default: matmul_flop_jit,
|
|
|
|
|
aten.addmm.default: addmm_flop_jit,
|
|
|
|
|
aten.bmm.default: bmm_flop_jit,
|
|
|
|
|
aten.baddbmm.default: baddbmm_flop_jit,
|
|
|
|
|
|
|
|
|
|
# convolution
|
|
|
|
|
aten.convolution.default: conv_flop_jit,
|
|
|
|
@ -209,6 +223,8 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
|
|
|
|
|
aten.cudnn_batch_norm_backward.default: partial(batchnorm_flop_jit, training=True),
|
|
|
|
|
aten.native_layer_norm.default: norm_flop_counter(2, 0),
|
|
|
|
|
aten.native_layer_norm_backward.default: norm_flop_counter(2, 0),
|
|
|
|
|
aten.native_group_norm.default: norm_flop_counter(2, 0),
|
|
|
|
|
aten.native_group_norm_backward.default: norm_flop_counter(2, 0),
|
|
|
|
|
|
|
|
|
|
# pooling
|
|
|
|
|
aten.avg_pool1d.default: elementwise_flop_counter(1, 0),
|
|
|
|
@ -230,6 +246,8 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
|
|
|
|
|
aten._adaptive_avg_pool3d_backward.default: elementwise_flop_counter(0, 1),
|
|
|
|
|
aten.embedding_dense_backward.default: elementwise_flop_counter(0, 1),
|
|
|
|
|
aten.embedding.default: elementwise_flop_counter(1, 0),
|
|
|
|
|
aten.upsample_nearest2d.vec: elementwise_flop_counter(0, 1),
|
|
|
|
|
aten.upsample_nearest2d_backward.vec: elementwise_flop_counter(0, 1),
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
elementwise_flop_aten = [
|
|
|
|
@ -251,6 +269,9 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
|
|
|
|
|
aten.mean.dim,
|
|
|
|
|
aten.sub.Tensor,
|
|
|
|
|
aten.sub_.Tensor,
|
|
|
|
|
aten.exp.default,
|
|
|
|
|
aten.sin.default,
|
|
|
|
|
aten.cos.default,
|
|
|
|
|
|
|
|
|
|
# activation op
|
|
|
|
|
aten.hardswish.default,
|
|
|
|
|