support unet metainfo prop (#2544)

pull/2492/head
oahzxl 2023-02-02 16:19:26 +08:00 committed by GitHub
parent c4b15661d7
commit fa3d66feb9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 35 additions and 17 deletions

View File

@ -164,18 +164,9 @@ def meta_conv(
@register_meta(aten._convolution.default)
def meta_conv_1(
input_tensor: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor,
stride: List[int],
padding: List[int],
dilation: List[int],
is_transposed: bool,
output_padding: List[int],
groups: int,
*extra_args
):
def meta_conv_1(input_tensor: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, stride: List[int],
padding: List[int], dilation: List[int], is_transposed: bool, output_padding: List[int], groups: int,
*extra_args):
out = meta_conv(input_tensor, weight, bias, stride, padding, dilation, is_transposed, output_padding, groups)
return out
@ -233,11 +224,8 @@ def meta_cuda_rnn(
if is_input_packed:
out_shape = [batch_sizes_sum, out_size * num_directions]
else:
out_shape = (
[mini_batch, seq_length, out_size * num_directions]
if batch_first
else [seq_length, mini_batch, out_size * num_directions]
)
out_shape = ([mini_batch, seq_length, out_size *
num_directions] if batch_first else [seq_length, mini_batch, out_size * num_directions])
output = input.new_empty(out_shape)
cell_shape = [num_layers * num_directions, mini_batch, hidden_size]
@ -372,6 +360,15 @@ def meta_ln_backward(dY: torch.Tensor, input: torch.Tensor, normalized_shape, me
return dX, dgamma, dbeta
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/group_norm.cpp
@register_meta(aten.native_group_norm_backward.default)
def meta_gn_backward(dY: torch.Tensor, input: torch.Tensor, mean, rstd, gamma, N, C, HxW, group, grad_input_mask):
dX = torch.empty_like(input)
dgamma = torch.empty_like(gamma)
dbeta = torch.empty_like(gamma)
return dX, dgamma, dbeta
# ================================== Misc ==========================================
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml
@register_meta(aten.roll.default)

View File

@ -70,6 +70,19 @@ def bmm_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number:
return flops
def baddbmm_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number:
"""
Count flops for the baddbmm(batch add and batch matmul) operation.
"""
# Inputs = [input, batch1, batch2]
# out = input + batch1 x batch2
assert len(inputs) == 3, len(inputs)
n, c, t = inputs[1].shape
d = inputs[2].shape[-1]
flops = n * c * t * d
return flops
def conv_flop_count(
x_shape: List[int],
w_shape: List[int],
@ -196,6 +209,7 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
aten.matmul.default: matmul_flop_jit,
aten.addmm.default: addmm_flop_jit,
aten.bmm.default: bmm_flop_jit,
aten.baddbmm.default: baddbmm_flop_jit,
# convolution
aten.convolution.default: conv_flop_jit,
@ -209,6 +223,8 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
aten.cudnn_batch_norm_backward.default: partial(batchnorm_flop_jit, training=True),
aten.native_layer_norm.default: norm_flop_counter(2, 0),
aten.native_layer_norm_backward.default: norm_flop_counter(2, 0),
aten.native_group_norm.default: norm_flop_counter(2, 0),
aten.native_group_norm_backward.default: norm_flop_counter(2, 0),
# pooling
aten.avg_pool1d.default: elementwise_flop_counter(1, 0),
@ -230,6 +246,8 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
aten._adaptive_avg_pool3d_backward.default: elementwise_flop_counter(0, 1),
aten.embedding_dense_backward.default: elementwise_flop_counter(0, 1),
aten.embedding.default: elementwise_flop_counter(1, 0),
aten.upsample_nearest2d.vec: elementwise_flop_counter(0, 1),
aten.upsample_nearest2d_backward.vec: elementwise_flop_counter(0, 1),
}
elementwise_flop_aten = [
@ -251,6 +269,9 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
aten.mean.dim,
aten.sub.Tensor,
aten.sub_.Tensor,
aten.exp.default,
aten.sin.default,
aten.cos.default,
# activation op
aten.hardswish.default,