From fa3d66feb9793a0e0003d827066a70fabe924a50 Mon Sep 17 00:00:00 2001
From: oahzxl <43881818+oahzxl@users.noreply.github.com>
Date: Thu, 2 Feb 2023 16:19:26 +0800
Subject: [PATCH] support unet metainfo prop (#2544)

---
 colossalai/fx/_meta_registrations.py | 31 +++++++++++++---------------
 colossalai/fx/profiler/opcount.py    | 21 +++++++++++++++++++
 2 files changed, 35 insertions(+), 17 deletions(-)

diff --git a/colossalai/fx/_meta_registrations.py b/colossalai/fx/_meta_registrations.py
index 8c0201c71..153214447 100644
--- a/colossalai/fx/_meta_registrations.py
+++ b/colossalai/fx/_meta_registrations.py
@@ -164,18 +164,9 @@ def meta_conv(
 
 
 @register_meta(aten._convolution.default)
-def meta_conv_1(
-    input_tensor: torch.Tensor,
-    weight: torch.Tensor,
-    bias: torch.Tensor,
-    stride: List[int],
-    padding: List[int],
-    dilation: List[int],
-    is_transposed: bool,
-    output_padding: List[int],
-    groups: int,
-    *extra_args
-):
+def meta_conv_1(input_tensor: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, stride: List[int],
+                padding: List[int], dilation: List[int], is_transposed: bool, output_padding: List[int], groups: int,
+                *extra_args):
     out = meta_conv(input_tensor, weight, bias, stride, padding, dilation, is_transposed, output_padding, groups)
     return out
 
@@ -233,11 +224,8 @@ def meta_cuda_rnn(
     if is_input_packed:
         out_shape = [batch_sizes_sum, out_size * num_directions]
     else:
-        out_shape = (
-            [mini_batch, seq_length, out_size * num_directions]
-            if batch_first
-            else [seq_length, mini_batch, out_size * num_directions]
-        )
+        out_shape = ([mini_batch, seq_length, out_size *
+                      num_directions] if batch_first else [seq_length, mini_batch, out_size * num_directions])
     output = input.new_empty(out_shape)
 
     cell_shape = [num_layers * num_directions, mini_batch, hidden_size]
@@ -372,6 +360,15 @@ def meta_ln_backward(dY: torch.Tensor, input: torch.Tensor, normalized_shape, me
     return dX, dgamma, dbeta
 
 
+# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/group_norm.cpp
+@register_meta(aten.native_group_norm_backward.default)
+def meta_gn_backward(dY: torch.Tensor, input: torch.Tensor, mean, rstd, gamma, N, C, HxW, group, grad_input_mask):
+    dX = torch.empty_like(input)
+    dgamma = torch.empty_like(gamma)
+    dbeta = torch.empty_like(gamma)
+    return dX, dgamma, dbeta
+
+
 # ================================== Misc ==========================================
 # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml
 @register_meta(aten.roll.default)
diff --git a/colossalai/fx/profiler/opcount.py b/colossalai/fx/profiler/opcount.py
index 6bd612ad2..d780ef6d4 100644
--- a/colossalai/fx/profiler/opcount.py
+++ b/colossalai/fx/profiler/opcount.py
@@ -70,6 +70,19 @@ def bmm_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number:
     return flops
 
 
+def baddbmm_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number:
+    """
+    Count flops for the baddbmm(batch add and batch matmul) operation.
+    """
+    # Inputs = [input, batch1, batch2]
+    # out = input + batch1 x batch2
+    assert len(inputs) == 3, len(inputs)
+    n, c, t = inputs[1].shape
+    d = inputs[2].shape[-1]
+    flops = n * c * t * d
+    return flops
+
+
 def conv_flop_count(
     x_shape: List[int],
     w_shape: List[int],
@@ -196,6 +209,7 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
         aten.matmul.default: matmul_flop_jit,
         aten.addmm.default: addmm_flop_jit,
         aten.bmm.default: bmm_flop_jit,
+        aten.baddbmm.default: baddbmm_flop_jit,
 
         # convolution
         aten.convolution.default: conv_flop_jit,
@@ -209,6 +223,8 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
         aten.cudnn_batch_norm_backward.default: partial(batchnorm_flop_jit, training=True),
         aten.native_layer_norm.default: norm_flop_counter(2, 0),
         aten.native_layer_norm_backward.default: norm_flop_counter(2, 0),
+        aten.native_group_norm.default: norm_flop_counter(2, 0),
+        aten.native_group_norm_backward.default: norm_flop_counter(2, 0),
 
         # pooling
         aten.avg_pool1d.default: elementwise_flop_counter(1, 0),
@@ -230,6 +246,8 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
         aten._adaptive_avg_pool3d_backward.default: elementwise_flop_counter(0, 1),
         aten.embedding_dense_backward.default: elementwise_flop_counter(0, 1),
         aten.embedding.default: elementwise_flop_counter(1, 0),
+        aten.upsample_nearest2d.vec: elementwise_flop_counter(0, 1),
+        aten.upsample_nearest2d_backward.vec: elementwise_flop_counter(0, 1),
     }
 
     elementwise_flop_aten = [
@@ -251,6 +269,9 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
         aten.mean.dim,
         aten.sub.Tensor,
         aten.sub_.Tensor,
+        aten.exp.default,
+        aten.sin.default,
+        aten.cos.default,
 
         # activation op
         aten.hardswish.default,