register aten._convolution.default (#2137)

pull/2147/head
Zihao 2 years ago committed by GitHub
parent ee287620f0
commit a128eec9d5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -163,6 +163,23 @@ def meta_conv(
return out
@register_meta(aten._convolution.default)
def meta_conv_1(
input_tensor: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor,
stride: List[int],
padding: List[int],
dilation: List[int],
is_transposed: bool,
output_padding: List[int],
groups: int,
*extra_args
):
out = meta_conv(input_tensor, weight, bias, stride, padding, dilation, is_transposed, output_padding, groups)
return out
@register_meta(aten.convolution_backward.default)
def meta_conv_backward(grad_output: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, bias_sizes, stride,
padding, dilation, transposed, output_padding, groups, output_mask):

Loading…
Cancel
Save