From 09c023bee268268adb4b04896222c54dcd3e7a2c Mon Sep 17 00:00:00 2001
From: Super Daniel <78588128+super-dainiu@users.noreply.github.com>
Date: Thu, 25 Aug 2022 23:11:13 +0800
Subject: [PATCH] [fx] add more op patches for profiler and error message for unsupported ops. (#1495)

* [fx] modify the calculation of node_size in MetaInfoProp for activation checkpointing usages
* [fx] merge development into main (#1)
* [fx] activation checkpointing using Chen strategies.
* [fx] add test for ckpt_solver_chen
* [fx] add vanilla activation checkpoint search with test on resnet and densenet
* [fx] add a namespace code for solver_chen.
* [fx] fix the false interpretation of algorithm 3 in https://arxiv.org/abs/1604.06174.
* [fx] fix lowercase naming conventions.
* [fx] simplify test for ckpt.
* [fx] add rules to linearize computation graphs for searching. (#2)
* [fx] fix test and algorithm bugs in activation checkpointing.
* [fx] polish ckpt_test.
* [fx] remove chen_sqrt for sake of simplicity
* [fx] fix inconsistencies.
* [fx] fix MetaInfoProp.
* [fx] consider MetaInfoProp for inplace operands.
* [fx] add profiler for fx nodes.
* [fx] fix error in tests.
* [fx] unfix bug.
* [fx] patch more modules and functions.
* [fx] change name of utils.py to profiler.py
* [fx] add profiler for rnn.
* [fx] polish and add more patch for profiler.
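For reference, the new error messages point users at the registry-based patching
workflow. A minimal sketch of such a patch, following the template embedded in
CALL_FUNCTION_MSG below (torch.nn.functional.silu is only a hypothetical example
here, not an op touched by this PR):

    # sketch: registering a profiler for an op the profiler does not know about;
    # the FLOP count per element is an assumption, not a measured value
    from typing import Tuple
    import torch
    from colossalai.fx.profiler import meta_profiler_function

    @meta_profiler_function.register(torch.nn.functional.silu)
    def profile_silu(input: torch.Tensor, inplace: bool = False) -> Tuple[int, int]:
        # SiLU is elementwise: assume one sigmoid plus one multiply per element
        flops = input.numel() * 4
        macs = 0
        return flops, macs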
---
 colossalai/fx/profiler/__init__.py            |   2 +-
 .../fx/profiler/{utils.py => profiler.py}     | 112 +++++++++++++-----
 .../profiler_function/activation_function.py  |   4 +
 .../profiler/profiler_function/arithmetic.py  |  26 ++--
 .../profiler/profiler_function/python_ops.py  |   7 ++
 .../profiler/profiler_function/torch_ops.py   |  12 +-
 .../fx/profiler/profiler_module/__init__.py   |   3 +
 .../profiler_module/activation_function.py    |   4 +
 .../fx/profiler/profiler_module/attention.py  |  81 +++++++++++++
 .../profiler/profiler_module/convolution.py   |  41 +++----
 .../fx/profiler/profiler_module/dropout.py    |  11 ++
 .../fx/profiler/profiler_module/linear.py     |   3 +-
 colossalai/fx/profiler/profiler_module/rnn.py |  70 ++++++++++-
 .../fx/profiler/profiler_module/torch_op.py   |  11 ++
 14 files changed, 310 insertions(+), 77 deletions(-)
 rename colossalai/fx/profiler/{utils.py => profiler.py} (55%)
 create mode 100644 colossalai/fx/profiler/profiler_module/attention.py
 create mode 100644 colossalai/fx/profiler/profiler_module/dropout.py
 create mode 100644 colossalai/fx/profiler/profiler_module/torch_op.py

diff --git a/colossalai/fx/profiler/__init__.py b/colossalai/fx/profiler/__init__.py
index 353c600c7..a56b0dc69 100644
--- a/colossalai/fx/profiler/__init__.py
+++ b/colossalai/fx/profiler/__init__.py
@@ -1,4 +1,4 @@
 from .registry import *
 from .profiler_function import *
 from .profiler_module import *
-from .utils import *
+from .profiler import *
diff --git a/colossalai/fx/profiler/utils.py b/colossalai/fx/profiler/profiler.py
similarity index 55%
rename from colossalai/fx/profiler/utils.py
rename to colossalai/fx/profiler/profiler.py
index 5024acb50..e8e641412 100644
--- a/colossalai/fx/profiler/utils.py
+++ b/colossalai/fx/profiler/profiler.py
@@ -1,8 +1,8 @@
 from functools import partial
 from operator import add, floordiv, getitem, mul, neg, setitem, sub, pos
-from typing import Callable, NamedTuple, Any, Dict, Tuple
+from typing import Callable, List, NamedTuple, Any, Dict, Tuple, Union
 import torch
-from torch.fx.node import Argument, Target
+from torch.fx.node import Argument, Target, map_aggregate
 from torch.fx._compatibility import compatibility
 from colossalai.fx.tracer.meta_patch import meta_patched_function, meta_patched_module
 from . import meta_profiler_function, meta_profiler_module
@@ -12,6 +12,30 @@ __all__ = [
     'calculate_param_size'
 ]
 
+CALL_FUNCTION_MSG = \
+"""
+Colossal-AI does not support profiling for {} yet. You can manually patch it with the following code.\n
+from colossalai.fx.profiler import meta_profiler_function
+
+@meta_profiler_function.register(YOUR_FUNCTION)
+def profile_YOUR_FUNCTION(input: torch.Tensor, *args) -> Tuple[int, int]:
+    flops = ...
+    macs = ...
+    return flops, macs
+"""
+CALL_METHOD_MSG = 'Please check if {} is an in-place method. If so, add the target to INPLACE_METHOD={}. Otherwise, add the target to NON_INPLACE_METHOD={}'
+CALL_MODULE_MSG = \
+"""
+Colossal-AI does not support profiling for {} yet. You can manually patch it with the following code.\n
+from colossalai.fx.profiler import meta_profiler_module
+
+@meta_profiler_module.register(YOUR_MODULE)
+def profile_YOUR_MODULE(self: torch.nn.Module, input: torch.Tensor) -> Tuple[int, int]:
+    flops = ...
+    macs = ...
+    return flops, macs
+"""
+
 # TODO fill out the inplace ops
 INPLACE_OPS = [
     add,
@@ -22,18 +46,30 @@ INPLACE_OPS = [
     pos,
     getitem,
     setitem,
+    getattr,
     torch.Tensor.cpu,
 ]
 
-# TODO check that call_methods are indeed inplace
+# TODO: list all call_methods that are inplace here
 INPLACE_METHOD = [
     'transpose',
     'permute',
+    # TODO: reshape may return a copy of the data if the data is not contiguous
+    'reshape',
+    'dim',
+    'flatten',
+]
+
+# TODO: list all call_methods that are not inplace here
+NON_INPLACE_METHOD = [
+    'expand',
+    'mean',
 ]
 
 
 @compatibility(is_backward_compatible=True)
 class MetaProfile(NamedTuple):
+
     # MetaProfile is a structure containing pertinent information
     # about a node within a torch.fx GraphModule.
 
@@ -43,9 +79,14 @@ class MetaProfile(NamedTuple):
     macs: int
 
 
-def calculate_activation_size(activation: any) -> int:
-    """
-    Calculate activation size of a node.
+def calculate_activation_size(activation: Union[torch.Tensor, Dict, List, Tuple, int]) -> int:
+    """Calculate the activation size of a node.
+
+    Args:
+        activation (Union[torch.Tensor, Dict, List, Tuple, int]): The activation of a `torch.nn.Module` or `torch.nn.functional` call.
+
+    Returns:
+        int: The activation size.
     """
     activation_size = 0
     if isinstance(activation, torch.Tensor):
@@ -53,15 +94,20 @@ def calculate_activation_size(activation: any) -> int:
     elif isinstance(activation, dict):
         value_list = [v for _, v in activation.items()]
         activation_size += calculate_activation_size(value_list)
-    else:
+    elif isinstance(activation, (tuple, list)):
         for element in activation:
             activation_size += calculate_activation_size(element)
     return activation_size
 
 
 def calculate_param_size(mod: torch.nn.Module) -> int:
-    """
-    Calculate param size of a node.
+    """Calculate the param size of a node.
+
+    Args:
+        mod (torch.nn.Module): The target `torch.nn.Module`.
+
+    Returns:
+        int: The param size.
     """
     param_size = 0
     for param in mod.parameters():
@@ -78,17 +124,21 @@ def profile_function(target: 'Target') -> Callable:
     You may only use tensors with `device=meta` for this wrapped function.
     Only original `torch.nn.functional` are available.
 
-    Usage:
-        input = torch.rand(100, 100, 100, 100, device='meta')
-        func = torch.nn.functional.relu
-        output, profile = profile_function(func)(input, inplace=False)
-        print(f"Profiling function {func},")
-        print(f"Param size: {profile.param / 1024**2:.3f} MB, Activation size: {profile.activation / 1024**2:.3f} MB, {profile.flops} FLOPs, {profile.macs} MACs")
+    Examples:
+        >>> input = torch.rand(100, 100, 100, 100, device='meta')
+        >>> func = torch.nn.functional.relu
+        >>> output, profile = profile_function(func)(input, inplace=False)
+        >>> print(f"Profiling function {func},")
+        >>> print(f"Param size: {profile.param / 1024**2:.3f} MB, Activation size: {profile.activation / 1024**2:.3f} MB, {profile.flops} FLOPs, {profile.macs} MACs")
+        Profiling function <function relu at 0x...>,
+        Param size: 0.000 MB, Activation size: 381.470 MB, 100000000 FLOPs, 0 MACs
     """

     def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
         assert meta_profiler_function.has(target) or meta_profiler_function.has(
-            target.__name__), f"Colossal-AI hasn't supported profiling for {target}, you might manually patch it."
+            target.__name__), CALL_FUNCTION_MSG.format(target)
+        # ensure all arguments satisfy `device='meta'`
+        args, kwargs = map_aggregate([args, kwargs], lambda a: a.to('meta') if isinstance(a, torch.Tensor) else a)
 
         # call_function has no parameters
         param_size = 0
@@ -127,14 +177,17 @@ def profile_method(target: 'Target') -> Callable:
 
         # args[0] is the `self` object for this method call
         self_obj, *args_tail = args
 
-        # Execute the method and return the result
+        # execute the method and return the result
         assert isinstance(target, str), f'{target} instance is not str.'
-        result = getattr(self_obj, target)(*args_tail, **kwargs)
-        assert target in INPLACE_METHOD, f'Please check {target} is an inplace method. If so, add target to INPLACE_METHOD={INPLACE_METHOD}.'
+        # ensure all arguments satisfy `device='meta'`
+        self_obj, args_tail, kwargs = map_aggregate((self_obj, tuple(args_tail), kwargs),
+                                                    lambda a: a.to('meta') if isinstance(a, torch.Tensor) else a)
+        result = getattr(self_obj, target)(*args_tail, **kwargs)
+        assert target in INPLACE_METHOD + NON_INPLACE_METHOD, CALL_METHOD_MSG.format(
+            target, INPLACE_METHOD, NON_INPLACE_METHOD)
 
         # call_method has no parameters and are MOSTLY(?) inplace, and has no FLOPs or MACs.
         param_size = 0
-        activation_size = 0
+        activation_size = 0 if target in INPLACE_METHOD else calculate_activation_size(result)
         flops = 0
         macs = 0
         return result, MetaProfile(param_size, activation_size, flops, macs)
@@ -151,17 +204,20 @@ def profile_module(module: torch.nn.Module) -> Callable:
     You may only use tensors with `device=meta` for this wrapped function.
     Only original `torch.nn` are available.
 
-    Usage:
-        input = torch.rand(4, 3, 224, 224, device='meta')
-        mod = torch.nn.Conv2d(3, 128, 3)
-        output, profile = profile_module(mod)(input)
-        print(f"Profiling function {mod},")
-        print(f"Param size: {profile.param / 1024**2:.3f} MB, Activation size: {profile.activation / 1024**2:.3f} MB, {profile.flops} FLOPs, {profile.macs} MACs")
+    Example:
+        >>> input = torch.rand(4, 3, 224, 224, device='meta')
+        >>> mod = torch.nn.Conv2d(3, 128, 3)
+        >>> output, profile = profile_module(mod)(input)
+        >>> print(f"Profiling function {mod},")
+        >>> print(f"Param size: {profile.param / 1024**2:.3f} MB, Activation size: {profile.activation / 1024**2:.3f} MB, {profile.flops} FLOPs, {profile.macs} MACs")
+        Profiling function Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1)),
+        Param size: 0.014 MB, Activation size: 96.258 MB, 1387837440 FLOPs, 681302016 MACs
     """

     def f(*args: Tuple[Argument, ...], **kwargs: Dict[str, Any]) -> Any:
-        assert meta_profiler_module.has(
-            type(module)), f"Colossal-AI hasn't supported profiling for {module}, you might manually patch it."
+        assert meta_profiler_module.has(type(module)), CALL_MODULE_MSG.format(type(module))
+        # ensure all arguments satisfy `device='meta'`
+        args, kwargs = map_aggregate([args, kwargs], lambda a: a.to('meta') if isinstance(a, torch.Tensor) else a)
         param_size = calculate_param_size(module)
         activation_size = 0
         result = func(*args, **kwargs)
diff --git a/colossalai/fx/profiler/profiler_function/activation_function.py b/colossalai/fx/profiler/profiler_function/activation_function.py
index 0bf5d8af9..a43aef063 100644
--- a/colossalai/fx/profiler/profiler_function/activation_function.py
+++ b/colossalai/fx/profiler/profiler_function/activation_function.py
@@ -12,6 +12,8 @@ _multiplier = {
     torch.nn.functional.elu: 4,
     torch.nn.functional.relu6: 2,
     torch.nn.functional.gelu: 9,
+    torch.nn.functional.hardswish: 5,
+    torch.nn.functional.hardsigmoid: 4,
 }
 
 
@@ -23,6 +25,8 @@ _multiplier = {
 @meta_profiler_function.register(torch.nn.functional.relu)
 @meta_profiler_function.register(torch.nn.functional.sigmoid)
 @meta_profiler_function.register(torch.nn.functional.tanh)
+@meta_profiler_function.register(torch.nn.functional.hardswish)
+@meta_profiler_function.register(torch.nn.functional.hardsigmoid)
 def torch_nn_func_non_linear_act(input: torch.Tensor, inplace: bool = False) -> Tuple[int, int]:
     flops = input.numel()
     macs = 0
diff --git a/colossalai/fx/profiler/profiler_function/arithmetic.py b/colossalai/fx/profiler/profiler_function/arithmetic.py
index b52c56557..2cf50133d 100644
--- a/colossalai/fx/profiler/profiler_function/arithmetic.py
+++ b/colossalai/fx/profiler/profiler_function/arithmetic.py
@@ -1,24 +1,19 @@
+import operator
+from functools import reduce
 from typing import Any, Optional, Tuple, Union
 import torch
 from ..registry import meta_profiler_function
 
 
-def _prod(dims):
-    p = 1
-    for v in dims:
-        p *= v
-    return p
-
-
 def _elementwise_flops_compute(input, other):
     # copied from https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/profiling/flops_profiler/profiler.py#L763
     if not torch.is_tensor(input):
         if torch.is_tensor(other):
-            return _prod(other.shape), 0
+            return reduce(operator.mul, other.shape), 0
         else:
             return 1, 0
     elif not torch.is_tensor(other):
-        return _prod(input.shape), 0
+        return reduce(operator.mul, input.shape), 0
     else:
         dim_input = len(input.shape)
         dim_other = len(other.shape)
@@ -32,17 +27,24 @@ def _elementwise_flops_compute(input, other):
                 final_shape.append(in_i)
             else:
                 final_shape.append(ot_i)
-        flops = _prod(final_shape)
+        flops = reduce(operator.mul, final_shape)
         return flops, 0
 
 
 @meta_profiler_function.register(torch.add)
+@meta_profiler_function.register(torch.eq)
+@meta_profiler_function.register(torch.sub)
+@meta_profiler_function.register(torch.mul)
+@meta_profiler_function.register(torch.floor_divide)
 @meta_profiler_function.register('add')    # for built-in op +
 @meta_profiler_function.register('iadd')    # for built-in op +=
+@meta_profiler_function.register('eq')    # for built-in op ==
 @meta_profiler_function.register('sub')    # for built-in op -
 @meta_profiler_function.register('isub')    # for built-in op -=
 @meta_profiler_function.register('mul')    # for built-in op *
 @meta_profiler_function.register('imul')    # for built-in op *=
+@meta_profiler_function.register('floordiv')    # for built-in op //
+@meta_profiler_function.register('ifloordiv')    # for built-in op //=
 def torch_add_like_ops(input: Any, other: Any, *, out: Optional[torch.Tensor] = None) -> Tuple[int, int]:
     return _elementwise_flops_compute(input, other)
 
@@ -58,14 +60,14 @@ def torch_elementwise_op(input: torch.Tensor, *, out: Optional[torch.Tensor] = N
 @meta_profiler_function.register('matmul')    # for built-in op @
 @meta_profiler_function.register(torch.Tensor.matmul)
 def torch_matmul(input: torch.Tensor, other: torch.Tensor, *, out: Optional[torch.Tensor] = None) -> Tuple[int, int]:
-    macs = _prod(input.shape) * other.shape[-1]
+    macs = reduce(operator.mul, input.shape) * other.shape[-1]
     flops = 2 * macs
     return flops, macs
 
 
 @meta_profiler_function.register(torch.bmm)
 def torch_bmm(input: torch.Tensor, other: torch.Tensor, *, out: Optional[torch.Tensor] = None) -> Tuple[int, int]:
-    macs = _prod(input.shape) * other.shape[-1]
+    macs = reduce(operator.mul, input.shape) * other.shape[-1]
     flops = 2 * macs
     return flops, macs
diff --git a/colossalai/fx/profiler/profiler_function/python_ops.py b/colossalai/fx/profiler/profiler_function/python_ops.py
index 95c67c47e..15e8aa675 100644
--- a/colossalai/fx/profiler/profiler_function/python_ops.py
+++ b/colossalai/fx/profiler/profiler_function/python_ops.py
@@ -10,3 +10,10 @@ def operator_getitem(a: Any, b: Any) -> Tuple[int, int]:
     flops = 0
     macs = 0
     return flops, macs
+
+
+@meta_profiler_function.register(getattr)
+def python_getattr(a: Any, b: Any) -> Tuple[int, int]:
+    flops = 0
+    macs = 0
+    return flops, macs
diff --git a/colossalai/fx/profiler/profiler_function/torch_ops.py b/colossalai/fx/profiler/profiler_function/torch_ops.py
index f67090b23..abdd7ad56 100644
--- a/colossalai/fx/profiler/profiler_function/torch_ops.py
+++ b/colossalai/fx/profiler/profiler_function/torch_ops.py
@@ -1,15 +1,10 @@
+from functools import reduce
+import operator
 from typing import Any, Optional, Tuple
 import torch
 from ..registry import meta_profiler_function
 
 
-def _prod(dims):
-    p = 1
-    for v in dims:
-        p *= v
-    return p
-
-
 @meta_profiler_function.register(torch.arange)
 @meta_profiler_function.register(torch.finfo)
 @meta_profiler_function.register(torch.permute)
@@ -31,6 +26,7 @@ def _prod(dims):
 @meta_profiler_function.register(torch.full)
 @meta_profiler_function.register(torch.Tensor.cpu)
 @meta_profiler_function.register(torch.Tensor.cuda)
+@meta_profiler_function.register(torch._assert)
 def torch_zero_flops_op(*args, **kwargs) -> Tuple[int, int]:
     flops = 0
     macs = 0
@@ -57,7 +53,7 @@ def torch_max(input: torch.Tensor,
     if dim is not None:
         shape = list(input.shape)
         shape.pop(int(dim))
-        flops = _prod(shape), macs
+        flops = reduce(operator.mul, shape)
         return flops, macs
     else:
         flops = input.numel()
diff --git a/colossalai/fx/profiler/profiler_module/__init__.py b/colossalai/fx/profiler/profiler_module/__init__.py
index 3f40ec2a6..e4fe646f3 100644
--- a/colossalai/fx/profiler/profiler_module/__init__.py
+++ b/colossalai/fx/profiler/profiler_module/__init__.py
@@ -1,7 +1,10 @@
 from .activation_function import *
+from .attention import *
 from .convolution import *
+from .dropout import *
 from .embedding import *
 from .linear import *
 from .normalization import *
 from .pooling import *
 from .rnn import *
+from .torch_op import *
diff --git a/colossalai/fx/profiler/profiler_module/activation_function.py b/colossalai/fx/profiler/profiler_module/activation_function.py
index 1008eef0a..2ebf514ad 100644
--- a/colossalai/fx/profiler/profiler_module/activation_function.py
+++ b/colossalai/fx/profiler/profiler_module/activation_function.py
@@ -12,6 +12,8 @@ _multiplier = {
     torch.nn.ELU: 4,
     torch.nn.ReLU6: 2,
     torch.nn.GELU: 9,
+    torch.nn.Hardswish: 5,
+    torch.nn.Hardsigmoid: 4,
 }
 
 
@@ -23,6 +25,8 @@ _multiplier = {
 @meta_profiler_module.register(torch.nn.Tanh)
 @meta_profiler_module.register(torch.nn.ReLU6)
 @meta_profiler_module.register(torch.nn.PReLU)
+@meta_profiler_module.register(torch.nn.Hardswish)
+@meta_profiler_module.register(torch.nn.Hardsigmoid)
 def torch_nn_non_linear_act(self: torch.nn.Module, input: torch.Tensor) -> Tuple[int, int]:
     flops = input.numel()
     macs = 0
diff --git a/colossalai/fx/profiler/profiler_module/attention.py b/colossalai/fx/profiler/profiler_module/attention.py
new file mode 100644
index 000000000..8daf74b23
--- /dev/null
+++ b/colossalai/fx/profiler/profiler_module/attention.py
@@ -0,0 +1,81 @@
+from typing import Optional, Tuple
+import torch
+from ..registry import meta_profiler_module
+
+
+# TODO: the memory cost of this module is hard to compute
+@meta_profiler_module.register(torch.nn.MultiheadAttention)
+def torch_nn_msa(self: torch.nn.MultiheadAttention,
+                 query: torch.Tensor,
+                 key: torch.Tensor,
+                 value: torch.Tensor,
+                 key_padding_mask: Optional[torch.Tensor] = None,
+                 need_weights: bool = True,
+                 attn_mask: Optional[torch.Tensor] = None,
+                 average_attn_weights: bool = True) -> Tuple[int, int]:
+    if getattr(self, 'batch_first', False):
+        batch_size = query.shape[0]
+        len_idx = 1
+    else:
+        batch_size = query.shape[1]
+        len_idx = 0
+    dim_idx = 2
+
+    qdim = query.shape[dim_idx]
+    kdim = key.shape[dim_idx]
+    vdim = value.shape[dim_idx]
+
+    qlen = query.shape[len_idx]
+    klen = key.shape[len_idx]
+    vlen = value.shape[len_idx]
+
+    num_heads = self.num_heads
+    assert qdim == self.embed_dim
+
+    if self.kdim is None:
+        assert kdim == qdim
+    if self.vdim is None:
+        assert vdim == qdim
+
+    flops = 0
+    macs = 0
+
+    # Q scaling
+    flops += qlen * qdim
+
+    # initial projections
+    flops += 2 * ((qlen * qdim * qdim)    # QW
+                  + (klen * kdim * kdim)    # KW
+                  + (vlen * vdim * vdim)    # VW
+                 )
+    macs += ((qlen * qdim * qdim)    # QW
+             + (klen * kdim * kdim)    # KW
+             + (vlen * vdim * vdim)    # VW
+            )
+
+    if self.in_proj_bias is not None:
+        flops += (qlen + klen + vlen) * qdim
+
+    # attention heads: scale, matmul, softmax, matmul
+    qk_head_dim = qdim // num_heads
+    v_head_dim = vdim // num_heads
+
+    head_flops = (
+        2 * (qlen * klen * qk_head_dim)    # QK^T
+        + (qlen * klen)    # softmax
+        + 2 * (qlen * klen * v_head_dim)    # AV
+    )
+    head_macs = ((qlen * klen * qk_head_dim)    # QK^T
+                 + (qlen * klen * v_head_dim)    # AV
+                )
+
+    flops += num_heads * head_flops
+    macs += num_heads * head_macs
+
+    # final projection, bias is always enabled
+    flops += qlen * vdim * (vdim + 1)
+
+    flops *= batch_size
+    macs *= batch_size
+    return flops, macs
diff --git a/colossalai/fx/profiler/profiler_module/convolution.py b/colossalai/fx/profiler/profiler_module/convolution.py
index 12e3d7e2f..3193489fe 100644
--- a/colossalai/fx/profiler/profiler_module/convolution.py
+++ b/colossalai/fx/profiler/profiler_module/convolution.py
@@ -1,16 +1,11 @@
+import operator
+from functools import reduce
 import math
 from typing import Tuple
 import torch
 from ..registry import meta_profiler_module
 
 
-def _prod(dims):
-    p = 1
-    for v in dims:
-        p *= v
-    return p
-
-
 @meta_profiler_module.register(torch.nn.Conv1d)
 def torch_nn_conv1d(self: torch.nn.Conv1d, input: torch.Tensor) -> Tuple[int, int]:
     # the output shape is calculated using the formula stated
@@ -23,8 +18,8 @@ def torch_nn_conv1d(self: torch.nn.Conv1d, input: torch.Tensor) -> Tuple[int, in
         c_out,
         l_out,
     )
-    macs_per_elem = _prod(self.kernel_size) * c_in // self.groups
-    num_elem = _prod(result_shape)
+    macs_per_elem = reduce(operator.mul, self.kernel_size) * c_in // self.groups
+    num_elem = reduce(operator.mul, result_shape)
     macs = macs_per_elem * num_elem
     flops = 2 * macs
     if self.bias is not None:
@@ -47,8 +42,8 @@ def torch_nn_conv2d(self: torch.nn.Conv2d, input: torch.Tensor) -> Tuple[int, in
         h_out,
         w_out,
     )
-    macs_per_elem = _prod(self.kernel_size) * c_in // self.groups
-    num_elem = _prod(result_shape)
+    macs_per_elem = reduce(operator.mul, self.kernel_size) * c_in // self.groups
+    num_elem = reduce(operator.mul, result_shape)
     macs = macs_per_elem * num_elem
     flops = 2 * macs
     if self.bias is not None:
@@ -74,8 +69,8 @@ def torch_nn_conv3d(self: torch.nn.Conv3d, input: torch.Tensor) -> Tuple[int, in
         h_out,
         w_out,
     )
-    macs_per_elem = _prod(self.kernel_size) * c_in // self.groups
-    num_elem = _prod(result_shape)
+    macs_per_elem = reduce(operator.mul, self.kernel_size) * c_in // self.groups
+    num_elem = reduce(operator.mul, result_shape)
     macs = macs_per_elem * num_elem
     flops = 2 * macs
     if self.bias is not None:
@@ -95,14 +90,14 @@ def torch_nn_convtranspose1d(self: torch.nn.ConvTranspose1d, input: torch.Tensor
         c_out,
         l_out,
     )
-    macs_per_elem = _prod(self.kernel_size) * c_in // self.groups
-    num_elem = _prod(
-        input.shape
+    macs_per_elem = reduce(operator.mul, self.kernel_size) * c_in // self.groups
+    num_elem = reduce(
+        operator.mul, input.shape
     )    # see https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/profiling/flops_profiler/profiler.py#L604
     macs = macs_per_elem * num_elem
     flops = 2 * macs
     if self.bias is not None:
-        flops += _prod(result_shape)
+        flops += reduce(operator.mul, result_shape)
     return flops, macs
 
 
@@ -121,12 +116,12 @@ def torch_nn_convtranspose2d(self: torch.nn.ConvTranspose2d, input: torch.Tensor
         h_out,
         w_out,
     )
-    macs_per_elem = _prod(self.kernel_size) * c_in // self.groups
-    num_elem = _prod(input.shape)
+    macs_per_elem = reduce(operator.mul, self.kernel_size) * c_in // self.groups
+    num_elem = reduce(operator.mul, input.shape)
     macs = macs_per_elem * num_elem
     flops = 2 * macs
     if self.bias is not None:
-        flops += _prod(result_shape)
+        flops += reduce(operator.mul, result_shape)
     return flops, macs
 
 
@@ -148,10 +143,10 @@ def torch_nn_convtranspose3d(self: torch.nn.ConvTranspose3d, input: torch.Tensor
         h_out,
         w_out,
     )
-    macs_per_elem = _prod(self.kernel_size) * c_in // self.groups
-    num_elem = _prod(input.shape)
+    macs_per_elem = reduce(operator.mul, self.kernel_size) * c_in // self.groups
+    num_elem = reduce(operator.mul, input.shape)
     macs = macs_per_elem * num_elem
     flops = 2 * macs
     if self.bias is not None:
-        flops += _prod(result_shape)
+        flops += reduce(operator.mul, result_shape)
     return flops, macs
diff --git a/colossalai/fx/profiler/profiler_module/dropout.py b/colossalai/fx/profiler/profiler_module/dropout.py
new file mode 100644
index 000000000..417e0ed46
--- /dev/null
+++ b/colossalai/fx/profiler/profiler_module/dropout.py
@@ -0,0 +1,11 @@
+from typing import Tuple
+import torch
+from ..registry import meta_profiler_module
+
+
+@meta_profiler_module.register(torch.nn.Dropout)
+def torch_nn_dropout(self: torch.nn.Module, input: torch.Tensor) -> Tuple[int, int]:
+    # nn.Dropout only masks its input, so technically it has 0 FLOPs
+    # (https://discuss.pytorch.org/t/correct-way-to-calculate-flops-in-model/67198/6)
+    flops = 0
+    macs = 0
+    return flops, macs
diff --git a/colossalai/fx/profiler/profiler_module/linear.py b/colossalai/fx/profiler/profiler_module/linear.py
index f657f9ac7..e1ffb6f24 100644
--- a/colossalai/fx/profiler/profiler_module/linear.py
+++ b/colossalai/fx/profiler/profiler_module/linear.py
@@ -4,9 +4,10 @@ from ..registry import meta_profiler_module
 
 
 @meta_profiler_module.register(torch.nn.Linear)
+@meta_profiler_module.register(torch.nn.modules.linear.NonDynamicallyQuantizableLinear)
 def torch_nn_linear(self: torch.nn.Linear, input: torch.Tensor) -> Tuple[int, int]:
     out_features = self.weight.shape[0]
-    macs = torch.numel(input) * out_features
+    macs = input.numel() * out_features
     flops = 2 * macs
     if self.bias is not None:
         flops += self.bias.numel()
diff --git a/colossalai/fx/profiler/profiler_module/rnn.py b/colossalai/fx/profiler/profiler_module/rnn.py
index c042458b3..6e733d6da 100644
--- a/colossalai/fx/profiler/profiler_module/rnn.py
+++ b/colossalai/fx/profiler/profiler_module/rnn.py
@@ -1,13 +1,75 @@
+from functools import reduce
+import operator
 import torch
 from ..registry import meta_profiler_module
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Union
 
 
-# TODO: calculate rnn FLOPs
+def _rnn_flops(flops: int, macs: int, module: Union[torch.nn.RNNBase, torch.nn.RNNCellBase], w_ih: torch.Tensor,
+               w_hh: torch.Tensor) -> Tuple[int, int]:
+    # copied from https://github.com/sovrasov/flops-counter.pytorch/blob/master/ptflops/pytorch_ops.py
+
+    # matrix-matrix mult of the input state and the internal state
+    macs += reduce(operator.mul, w_ih.shape)
+    flops += 2 * reduce(operator.mul, w_ih.shape)
+    # matrix-matrix mult of the hidden state and the internal state
+    macs += reduce(operator.mul, w_hh.shape)
+    flops += 2 * reduce(operator.mul, w_hh.shape)
+    if isinstance(module, (torch.nn.RNN, torch.nn.RNNCell)):
+        # add both operations
+        flops += module.hidden_size
+    elif isinstance(module, (torch.nn.GRU, torch.nn.GRUCell)):
+        # hadamard product of r
+        flops += module.hidden_size
+        # adding operations from both states
+        flops += module.hidden_size * 3
+        # last two hadamard products and add
+        flops += module.hidden_size * 3
+    elif isinstance(module, (torch.nn.LSTM, torch.nn.LSTMCell)):
+        # adding operations from both states
+        flops += module.hidden_size * 4
+        # two hadamard products and add for the C state
+        flops += module.hidden_size * 3
+        # final hadamard product
+        flops += module.hidden_size * 3
+    return flops, macs
+
+
+@meta_profiler_module.register(torch.nn.LSTM)
 @meta_profiler_module.register(torch.nn.GRU)
 @meta_profiler_module.register(torch.nn.RNN)
-def torch_nn_rnn(self: torch.nn.Module, input: torch.Tensor, hx: torch.Tensor) -> Tuple[int, int]:
-    raise NotImplementedError
+def torch_nn_rnn(self: torch.nn.RNNBase, input: torch.Tensor, hx: Optional[torch.Tensor] = None) -> Tuple[int, int]:
     flops = 0
     macs = 0
+    for i in range(self.num_layers):
+        w_ih = self.__getattr__('weight_ih_l' + str(i))
+        w_hh = self.__getattr__('weight_hh_l' + str(i))
+        flops, macs = _rnn_flops(flops, macs, self, w_ih, w_hh)
+        if self.bias:
+            b_ih = self.__getattr__('bias_ih_l' + str(i))
+            b_hh = self.__getattr__('bias_hh_l' + str(i))
+            flops += reduce(operator.mul, b_ih.shape) + reduce(operator.mul, b_hh.shape)
+    flops *= reduce(operator.mul, input.shape[:2])
+    macs *= reduce(operator.mul, input.shape[:2])
+    if self.bidirectional:
+        flops *= 2
+        macs *= 2
+    return flops, macs
+
+
+@meta_profiler_module.register(torch.nn.LSTMCell)
+@meta_profiler_module.register(torch.nn.GRUCell)
+@meta_profiler_module.register(torch.nn.RNNCell)
+def torch_nn_rnn_cell(self: torch.nn.RNNCellBase, input: torch.Tensor, hx: Optional[torch.Tensor] = None) -> Tuple[int, int]:
+    flops = 0
+    macs = 0
+    w_ih = self.__getattr__('weight_ih')
+    w_hh = self.__getattr__('weight_hh')
+    flops, macs = _rnn_flops(flops, macs, self, w_ih, w_hh)
+    if self.bias:
+        b_ih = self.__getattr__('bias_ih')
+        b_hh = self.__getattr__('bias_hh')
+        flops += reduce(operator.mul, b_ih.shape) + reduce(operator.mul, b_hh.shape)
+    flops *= input.shape[0]
+    macs *= input.shape[0]
     return flops, macs
diff --git a/colossalai/fx/profiler/profiler_module/torch_op.py b/colossalai/fx/profiler/profiler_module/torch_op.py
new file mode 100644
index 000000000..d3aed874e
--- /dev/null
+++ b/colossalai/fx/profiler/profiler_module/torch_op.py
@@ -0,0 +1,11 @@
+import operator
+import torch
+from ..registry import meta_profiler_module
+from typing import Optional, Tuple, Union
+
+
+@meta_profiler_module.register(torch.nn.Flatten)
+def torch_nn_flatten(self: torch.nn.Flatten, input: torch.Tensor) -> Tuple[int, int]:
+    flops = 0
+    macs = 0
+    return flops, macs
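
Note: a quick way to exercise the new patches is to drive the wrappers directly
with `device='meta'` tensors, mirroring the docstring examples in profiler.py. A
sketch follows; whether a given module's forward runs end-to-end on meta tensors
depends on the tracer's meta patches, so treat this as an assumption rather than
a guaranteed workflow:

    import torch
    from colossalai.fx.profiler import profile_module

    # the newly registered MultiheadAttention profiler; sizes are arbitrary
    mha = torch.nn.MultiheadAttention(embed_dim=64, num_heads=8)
    query = torch.rand(20, 4, 64, device='meta')    # (seq_len, batch, embed_dim)
    output, profile = profile_module(mha)(query, query, query)
    print(f"{profile.flops} FLOPs, {profile.macs} MACs")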