# adopted from https://github.com/facebookresearch/fvcore/blob/main/fvcore/nn/jit_handles.py
# ideas from https://pastebin.com/AkvAyJBw
# and https://dev-discuss.pytorch.org/t/the-ideal-pytorch-flop-counter-with-torch-dispatch/505
import operator
from collections import defaultdict
from contextlib import contextmanager
from enum import Enum, auto
from functools import partial, reduce
from numbers import Number
from typing import Any, Callable, List, Optional, Tuple, Union

import torch
from packaging import version
from torch.utils._pytree import tree_map

from .meta_tensor import MetaTensor

aten = torch.ops.aten


class Phase(Enum):
    FWD = auto()
    BWD = auto()


def normalize_tuple(x):
    if not isinstance(x, tuple):
        return (x,)
    return x


def _format_flops(flop):
    K = 1e3
    M = 1e6
    B = 1e9
    T = 1e12
    if flop < K:
        return f'{flop:.2f}'
    elif flop < M:
        return f'{flop / K:.2f}K'
    elif flop < B:
        return f'{flop / M:.2f}M'
    elif flop < T:
        return f'{flop / B:.2f}B'
    else:
        return f'{flop / T:.2f}T'


def flop_count(module: Optional[Union[torch.nn.Module, Callable]] = None,
               *args,
               verbose: bool = False,
               **kwargs) -> Tuple[Number, Number]:
    """
    Count the number of floating point operations in a model.
    Ideas from https://pastebin.com/AkvAyJBw.

    Args:
        module (torch.nn.Module or Callable): A PyTorch model or a callable.
        *args: Input arguments to the model.
        verbose (bool): If True, print the number of flops for each module.
        **kwargs: Input keyword arguments to the model.

    Returns:
        Tuple[Number, Number]: The number of floating point operations in the
            forward and backward passes, in that order.
    """
    maybe_inplace = (getattr(module, 'inplace', False) or kwargs.get('inplace', False)
                     or getattr(module, '__name__', None) in ('add_', 'mul_', 'div_', 'sub_'))

    class DummyModule(torch.nn.Module):

        def __init__(self, func):
            super().__init__()
            self.func = func
            self.__name__ = func.__name__

        def forward(self, *args, **kwargs):
            return self.func(*args, **kwargs)

    total_flop_count = {Phase.FWD: 0, Phase.BWD: 0}
    flop_counts = defaultdict(lambda: defaultdict(int))
    parents = ['Global']
    module = module if isinstance(module, torch.nn.Module) else DummyModule(module)

    class FlopTensor(MetaTensor):
        _tensor: torch.Tensor

        def __repr__(self):
            name = 'FlopParameter' if getattr(self, '_is_param', False) else 'FlopTensor'
            if self.grad_fn:
                return f"{name}(..., size={tuple(self.shape)}, device='{self.device}', dtype={self.dtype}, grad_fn={self.grad_fn})"
            return f"{name}(..., size={tuple(self.shape)}, device='{self.device}', dtype={self.dtype})"

        @classmethod
        def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
            # no_dispatch is only needed if you use enable_python_mode.
            # It prevents infinite recursion.
            rs = super().__torch_dispatch__(func, types, args, kwargs)

            outs = normalize_tuple(rs)

            if func in flop_mapping:
                nonlocal flop_counts, total_flop_count
                flop_count = flop_mapping[func](args, outs)
                for par in parents:
                    flop_counts[par][func.__name__] += flop_count
                total_flop_count[cur_phase] += flop_count

            def wrap(x):
                if isinstance(x, MetaTensor):
                    x = FlopTensor(x)
                return x

            rs = tree_map(wrap, rs)

            return rs

    def is_autogradable(x):
        return isinstance(x, torch.Tensor) and x.is_floating_point()

    def create_backwards_push(name):

        class PushState(torch.autograd.Function):

            @staticmethod
            def forward(ctx, *args):
                args = tree_map(lambda x: x.clone() if isinstance(x, torch.Tensor) else x, args)
                if len(args) == 1:
                    return args[0]
                return args

            @staticmethod
            def backward(ctx, *grad_outs):
                nonlocal parents
                parents.append(name)
                return grad_outs

        return PushState.apply

    def create_backwards_pop(name):

        class PopState(torch.autograd.Function):

            @staticmethod
            def forward(ctx, *args):
                args = tree_map(lambda x: x.clone() if isinstance(x, torch.Tensor) else x, args)
                if len(args) == 1:
                    return args[0]
                return args

            @staticmethod
            def backward(ctx, *grad_outs):
                nonlocal parents
                assert (parents[-1] == name)
                parents.pop()
                return grad_outs

        return PopState.apply

    def enter_module(name):

        def f(module, inputs):
            nonlocal parents
            parents.append(name)
            inputs = normalize_tuple(inputs)
            out = create_backwards_pop(name)(*inputs)
            return out

        return f

    def exit_module(name):

        def f(module, inputs, outputs):
            nonlocal parents
            assert (parents[-1] == name)
            parents.pop()
            outputs = normalize_tuple(outputs)
            return create_backwards_push(name)(*outputs)

        return f

    @contextmanager
    def instrument_module(mod):
        registered = []
        for name, module in dict(mod.named_children()).items():
            registered.append(module.register_forward_pre_hook(enter_module(name)))
            registered.append(module.register_forward_hook(exit_module(name)))
        yield
        for handle in registered:
            handle.remove()

    def display_flops():
        for mod in flop_counts.keys():
            print("Module: ", mod)
            for k, v in flop_counts[mod].items():
                print('\t', k, _format_flops(v))
            print()

    def detach_variables(r):
        if isinstance(r, torch.Tensor):
            requires_grad = r.requires_grad
            r = r.detach()
            r.requires_grad = requires_grad
        return r

    def wrap(r):
        if isinstance(r, torch.Tensor):
            data_ptr_fn = getattr(r, '_tensor', r).data_ptr
            r = FlopTensor(detach_variables(r))
            if maybe_inplace:
                r = r + 0
            r._tensor.data_ptr = data_ptr_fn
        return r

    with instrument_module(module):
        cur_phase = Phase.FWD
        rst = module(*tree_map(wrap, args), **tree_map(wrap, kwargs))
        rst = tuple(r for r in normalize_tuple(rst) if is_autogradable(r) and r.requires_grad)
        cur_phase = Phase.BWD

        if rst:
            grad = [torch.zeros_like(t) for t in rst]
            torch.autograd.backward(
                rst,
                grad,
            )

    if verbose:
        display_flops()

    return total_flop_count[Phase.FWD], total_flop_count[Phase.BWD]


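# Note (added for clarity): each handler below receives the dispatched op's
# (inputs, outputs) and returns a flop count. Following the fvcore convention
# this file adapts, the gemm/conv handlers count one flop per multiply-accumulate,
# e.g. an (m, k) @ (k, n) matmul is counted as m * k * n rather than 2 * m * k * n.

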
def matmul_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number:
    """
    Count flops for matmul.
    """
    # Inputs should be a list of length 2.
    # Inputs contains the shapes of the two matrices.
    input_shapes = [v.shape for v in inputs]
    assert len(input_shapes) == 2, input_shapes

    # There are three cases: 1) gemm, 2) gemv, 3) dot
    if all(len(shape) == 2 for shape in input_shapes):
        # gemm
        assert input_shapes[0][-1] == input_shapes[1][-2], input_shapes
    elif all(len(shape) == 1 for shape in input_shapes):
        # dot
        assert input_shapes[0][0] == input_shapes[1][0], input_shapes

        # expand shape
        input_shapes[0] = torch.Size([1, input_shapes[0][0]])
        input_shapes[1] = torch.Size([input_shapes[1][0], 1])
    else:
        # gemv
        if len(input_shapes[0]) == 1:
            assert input_shapes[0][0] == input_shapes[1][-2], input_shapes
            input_shapes.reverse()
        else:
            assert input_shapes[1][0] == input_shapes[0][-1], input_shapes

        # expand the shape of the vector to [batch size, 1]
        input_shapes[-1] = torch.Size([input_shapes[-1][-1], 1])
    flops = reduce(operator.mul, input_shapes[0]) * input_shapes[-1][-1]
    return flops


def addmm_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number:
    """
    Count flops for fully connected layers.
    """
    # Count flop for nn.Linear
    # inputs is a list of length 3.
    input_shapes = [v.shape for v in inputs[1:3]]
    # input_shapes[0]: [batch size, input feature dimension]
    # input_shapes[1]: [input feature dimension, output feature dimension]
    assert len(input_shapes[0]) == 2, input_shapes[0]
    assert len(input_shapes[1]) == 2, input_shapes[1]
    batch_size, input_dim = input_shapes[0]
    output_dim = input_shapes[1][1]
    flops = batch_size * input_dim * output_dim
    return flops


def linear_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number:
    """
    Count flops for the aten::linear operator.
    """
    # Inputs is a list of length 3; unlike aten::addmm, it is the first
    # two elements that are relevant.
    input_shapes = [v.shape for v in inputs[0:2]]
    # input_shapes[0]: [dim0, dim1, ..., input_feature_dim]
    # input_shapes[1]: [output_feature_dim, input_feature_dim]
    assert input_shapes[0][-1] == input_shapes[1][-1]
    flops = reduce(operator.mul, input_shapes[0]) * input_shapes[1][0]
    return flops


def bmm_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number:
    """
    Count flops for the bmm operation.
    """
    # Inputs should be a list of length 2.
    # Inputs contains the shapes of two tensors.
    assert len(inputs) == 2, len(inputs)
    input_shapes = [v.shape for v in inputs]
    n, c, t = input_shapes[0]
    d = input_shapes[-1][-1]
    flops = n * c * t * d
    return flops


def conv_flop_count(
    x_shape: List[int],
    w_shape: List[int],
    out_shape: List[int],
    transposed: bool = False,
) -> Number:
    """
    Count flops for convolution. Note that only multiplications are counted;
    additions and the bias term are ignored.
    Flops for a transposed convolution are calculated as
    flops = (x_shape[2:] * prod(w_shape) * batch_size).

    Args:
        x_shape (list(int)): The input shape before convolution.
        w_shape (list(int)): The filter shape.
        out_shape (list(int)): The output shape after convolution.
        transposed (bool): Whether the convolution is transposed.
    Returns:
        int: the number of flops
    """
    batch_size = x_shape[0]
    conv_shape = (x_shape if transposed else out_shape)[2:]
    flops = batch_size * reduce(operator.mul, w_shape) * reduce(operator.mul, conv_shape)
    return flops


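# Worked example (illustrative shapes, not taken from the original source): for a
# non-transposed convolution with input x of shape (8, 3, 32, 32), weight w of
# shape (16, 3, 3, 3) and output of shape (8, 16, 30, 30), conv_flop_count returns
#     batch_size * prod(w_shape) * prod(out_shape[2:])
#     = 8 * (16 * 3 * 3 * 3) * (30 * 30) = 3,110,400
# multiply-accumulates. The transposed case uses the input's spatial shape instead.

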
""" x, w = inputs[:2] x_shape, w_shape, out_shape = (x.shape, w.shape, outputs[0].shape) transposed = inputs[6] return conv_flop_count(x_shape, w_shape, out_shape, transposed=transposed) def transpose_shape(shape): return [shape[1], shape[0]] + list(shape[2:]) def conv_backward_flop_jit(inputs: List[Any], outputs: List[Any]): grad_out_shape, x_shape, w_shape = [i.shape for i in inputs[:3]] output_mask = inputs[-1] fwd_transposed = inputs[7] flop_count = 0 if output_mask[0]: grad_input_shape = outputs[0].shape flop_count += conv_flop_count(grad_out_shape, w_shape, grad_input_shape, not fwd_transposed) if output_mask[1]: grad_weight_shape = outputs[1].shape flop_count += conv_flop_count(transpose_shape(x_shape), grad_out_shape, grad_weight_shape, fwd_transposed) return flop_count def norm_flop_counter(affine_arg_index: int, input_arg_index: int) -> Callable: """ Args: affine_arg_index: index of the affine argument in inputs """ def norm_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number: """ Count flops for norm layers. """ # Inputs[0] contains the shape of the input. input_shape = inputs[input_arg_index].shape has_affine = inputs[affine_arg_index].shape is not None if hasattr(inputs[affine_arg_index], 'shape') else inputs[affine_arg_index] assert 2 <= len(input_shape) <= 5, input_shape # 5 is just a rough estimate flop = reduce(operator.mul, input_shape) * (5 if has_affine else 4) return flop return norm_flop_jit def batchnorm_flop_jit(inputs: List[Any], outputs: List[Any], training: bool = None) -> Number: if training is None: training = inputs[-3] assert isinstance(training, bool), "Signature of aten::batch_norm has changed!" if training: return norm_flop_counter(1, 0)(inputs, outputs) # pyre-ignore has_affine = inputs[1].shape is not None input_shape = reduce(operator.mul, inputs[0].shape) return input_shape * (2 if has_affine else 1) def ewise_flop_counter(input_scale: float = 1, output_scale: float = 0) -> Callable: """ Count flops by input_tensor.numel() * input_scale + output_tensor.numel() * output_scale Args: input_scale: scale of the input tensor (first argument) output_scale: scale of the output tensor (first element in outputs) """ def ewise_flop(inputs: List[Any], outputs: List[Any]) -> Number: ret = 0 if input_scale != 0: shape = inputs[0].shape ret += input_scale * reduce(operator.mul, shape) if shape else 0 if output_scale != 0: shape = outputs[0].shape ret += output_scale * reduce(operator.mul, shape) if shape else 0 return ret return ewise_flop def zero_flop_jit(*args): """ Count flops for zero flop layers. 
""" return 0 if version.parse(torch.__version__) >= version.parse('1.12.0'): flop_mapping = { # gemm aten.mm.default: matmul_flop_jit, aten.matmul.default: matmul_flop_jit, aten.addmm.default: addmm_flop_jit, aten.bmm.default: bmm_flop_jit, # convolution aten.convolution.default: conv_flop_jit, aten._convolution.default: conv_flop_jit, aten.convolution_backward.default: conv_backward_flop_jit, # normalization aten.native_batch_norm.default: batchnorm_flop_jit, aten.native_batch_norm_backward.default: batchnorm_flop_jit, aten.cudnn_batch_norm.default: batchnorm_flop_jit, aten.cudnn_batch_norm_backward.default: partial(batchnorm_flop_jit, training=True), aten.native_layer_norm.default: norm_flop_counter(2, 0), aten.native_layer_norm_backward.default: norm_flop_counter(2, 0), # pooling aten.avg_pool1d.default: ewise_flop_counter(1, 0), aten.avg_pool2d.default: ewise_flop_counter(1, 0), aten.avg_pool2d_backward.default: ewise_flop_counter(0, 1), aten.avg_pool3d.default: ewise_flop_counter(1, 0), aten.avg_pool3d_backward.default: ewise_flop_counter(0, 1), aten.max_pool1d.default: ewise_flop_counter(1, 0), aten.max_pool2d.default: ewise_flop_counter(1, 0), aten.max_pool3d.default: ewise_flop_counter(1, 0), aten.max_pool1d_with_indices.default: ewise_flop_counter(1, 0), aten.max_pool2d_with_indices.default: ewise_flop_counter(1, 0), aten.max_pool2d_with_indices_backward.default: ewise_flop_counter(0, 1), aten.max_pool3d_with_indices.default: ewise_flop_counter(1, 0), aten.max_pool3d_with_indices_backward.default: ewise_flop_counter(0, 1), aten._adaptive_avg_pool2d.default: ewise_flop_counter(1, 0), aten._adaptive_avg_pool2d_backward.default: ewise_flop_counter(0, 1), aten._adaptive_avg_pool3d.default: ewise_flop_counter(1, 0), aten._adaptive_avg_pool3d_backward.default: ewise_flop_counter(0, 1), aten.embedding_dense_backward.default: ewise_flop_counter(0, 1), aten.embedding.default: ewise_flop_counter(1, 0), } ewise_flop_aten = [ # basic op aten.add.Tensor, aten.add_.Tensor, aten.div.Tensor, aten.div_.Tensor, aten.div.Scalar, aten.div_.Scalar, aten.mul.Tensor, aten.mul.Scalar, aten.mul_.Tensor, aten.neg.default, aten.pow.Tensor_Scalar, aten.rsub.Scalar, aten.sum.default, aten.sum.dim_IntList, aten.mean.dim, # activation op aten.hardswish.default, aten.hardswish_.default, aten.hardswish_backward.default, aten.hardtanh.default, aten.hardtanh_.default, aten.hardtanh_backward.default, aten.hardsigmoid_backward.default, aten.hardsigmoid.default, aten.gelu.default, aten.gelu_backward.default, aten.silu.default, aten.silu_.default, aten.silu_backward.default, aten.sigmoid.default, aten.sigmoid_backward.default, aten._softmax.default, aten._softmax_backward_data.default, aten.relu_.default, aten.relu.default, aten.tanh.default, aten.tanh_backward.default, aten.threshold_backward.default, # dropout aten.native_dropout.default, aten.native_dropout_backward.default, # distribution aten.bernoulli_.float, # where aten.where.self, ] for op in ewise_flop_aten: flop_mapping[op] = ewise_flop_counter(1, 0) # fix-me: this will be removed in future zero_flop_aten = [ aten.as_strided.default, aten.as_strided_.default, aten.cat.default, aten.clone.default, aten.copy_.default, aten.detach.default, aten.expand.default, aten.empty_like.default, aten.new_empty.default, aten.new_empty_strided.default, aten.ones_like.default, aten._reshape_alias.default, aten.select.int, aten.select_backward.default, aten.squeeze.dim, aten.slice.Tensor, aten.slice_backward.default, aten.split.Tensor, aten.permute.default, aten.t.default, 
aten.transpose.int, aten._to_copy.default, aten.unsqueeze.default, aten.unbind.int, aten._unsafe_view.default, aten.view.default, aten.zero_.default, aten.zeros_like.default, ] for op in zero_flop_aten: flop_mapping[op] = zero_flop_jit else: flop_mapping = {} elementwise_flop_aten = {} zero_flop_aten = {}
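

# Minimal usage sketch (not part of the original file): counts forward and
# backward flops of a small model with flop_count. Whether this runs as-is
# depends on the MetaTensor implementation imported above; treat it as an
# illustration of the API rather than a guaranteed-working script.
if __name__ == '__main__':
    model = torch.nn.Sequential(
        torch.nn.Linear(64, 128),
        torch.nn.ReLU(),
        torch.nn.Linear(128, 10),
    )
    data = torch.rand(2, 64, requires_grad=True)
    fwd_flop, bwd_flop = flop_count(model, data, verbose=True)
    print(f'FWD: {_format_flops(fwd_flop)} flops, BWD: {_format_flops(bwd_flop)} flops')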