mirror of https://github.com/hpcaitech/ColossalAI
support unet metainfo prop (#2544)
parent
c4b15661d7
commit
fa3d66feb9
|
@ -164,18 +164,9 @@ def meta_conv(
|
|||
|
||||
|
||||
@register_meta(aten._convolution.default)
|
||||
def meta_conv_1(
|
||||
input_tensor: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
bias: torch.Tensor,
|
||||
stride: List[int],
|
||||
padding: List[int],
|
||||
dilation: List[int],
|
||||
is_transposed: bool,
|
||||
output_padding: List[int],
|
||||
groups: int,
|
||||
*extra_args
|
||||
):
|
||||
def meta_conv_1(input_tensor: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, stride: List[int],
|
||||
padding: List[int], dilation: List[int], is_transposed: bool, output_padding: List[int], groups: int,
|
||||
*extra_args):
|
||||
out = meta_conv(input_tensor, weight, bias, stride, padding, dilation, is_transposed, output_padding, groups)
|
||||
return out
|
||||
|
||||
|
@ -233,11 +224,8 @@ def meta_cuda_rnn(
|
|||
if is_input_packed:
|
||||
out_shape = [batch_sizes_sum, out_size * num_directions]
|
||||
else:
|
||||
out_shape = (
|
||||
[mini_batch, seq_length, out_size * num_directions]
|
||||
if batch_first
|
||||
else [seq_length, mini_batch, out_size * num_directions]
|
||||
)
|
||||
out_shape = ([mini_batch, seq_length, out_size *
|
||||
num_directions] if batch_first else [seq_length, mini_batch, out_size * num_directions])
|
||||
output = input.new_empty(out_shape)
|
||||
|
||||
cell_shape = [num_layers * num_directions, mini_batch, hidden_size]
|
||||
|
@ -372,6 +360,15 @@ def meta_ln_backward(dY: torch.Tensor, input: torch.Tensor, normalized_shape, me
|
|||
return dX, dgamma, dbeta
|
||||
|
||||
|
||||
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/group_norm.cpp
|
||||
@register_meta(aten.native_group_norm_backward.default)
|
||||
def meta_gn_backward(dY: torch.Tensor, input: torch.Tensor, mean, rstd, gamma, N, C, HxW, group, grad_input_mask):
|
||||
dX = torch.empty_like(input)
|
||||
dgamma = torch.empty_like(gamma)
|
||||
dbeta = torch.empty_like(gamma)
|
||||
return dX, dgamma, dbeta
|
||||
|
||||
|
||||
# ================================== Misc ==========================================
|
||||
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml
|
||||
@register_meta(aten.roll.default)
|
||||
|
|
|
@ -70,6 +70,19 @@ def bmm_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number:
|
|||
return flops
|
||||
|
||||
|
||||
def baddbmm_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number:
|
||||
"""
|
||||
Count flops for the baddbmm(batch add and batch matmul) operation.
|
||||
"""
|
||||
# Inputs = [input, batch1, batch2]
|
||||
# out = input + batch1 x batch2
|
||||
assert len(inputs) == 3, len(inputs)
|
||||
n, c, t = inputs[1].shape
|
||||
d = inputs[2].shape[-1]
|
||||
flops = n * c * t * d
|
||||
return flops
|
||||
|
||||
|
||||
def conv_flop_count(
|
||||
x_shape: List[int],
|
||||
w_shape: List[int],
|
||||
|
@ -196,6 +209,7 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
|
|||
aten.matmul.default: matmul_flop_jit,
|
||||
aten.addmm.default: addmm_flop_jit,
|
||||
aten.bmm.default: bmm_flop_jit,
|
||||
aten.baddbmm.default: baddbmm_flop_jit,
|
||||
|
||||
# convolution
|
||||
aten.convolution.default: conv_flop_jit,
|
||||
|
@ -209,6 +223,8 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
|
|||
aten.cudnn_batch_norm_backward.default: partial(batchnorm_flop_jit, training=True),
|
||||
aten.native_layer_norm.default: norm_flop_counter(2, 0),
|
||||
aten.native_layer_norm_backward.default: norm_flop_counter(2, 0),
|
||||
aten.native_group_norm.default: norm_flop_counter(2, 0),
|
||||
aten.native_group_norm_backward.default: norm_flop_counter(2, 0),
|
||||
|
||||
# pooling
|
||||
aten.avg_pool1d.default: elementwise_flop_counter(1, 0),
|
||||
|
@ -230,6 +246,8 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
|
|||
aten._adaptive_avg_pool3d_backward.default: elementwise_flop_counter(0, 1),
|
||||
aten.embedding_dense_backward.default: elementwise_flop_counter(0, 1),
|
||||
aten.embedding.default: elementwise_flop_counter(1, 0),
|
||||
aten.upsample_nearest2d.vec: elementwise_flop_counter(0, 1),
|
||||
aten.upsample_nearest2d_backward.vec: elementwise_flop_counter(0, 1),
|
||||
}
|
||||
|
||||
elementwise_flop_aten = [
|
||||
|
@ -251,6 +269,9 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
|
|||
aten.mean.dim,
|
||||
aten.sub.Tensor,
|
||||
aten.sub_.Tensor,
|
||||
aten.exp.default,
|
||||
aten.sin.default,
|
||||
aten.cos.default,
|
||||
|
||||
# activation op
|
||||
aten.hardswish.default,
|
||||
|
|
Loading…
Reference in New Issue