|
|
@ -163,6 +163,23 @@ def meta_conv(
|
|
|
|
return out
|
|
|
|
return out
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_meta(aten._convolution.default)
|
|
|
|
|
|
|
|
def meta_conv_1(
|
|
|
|
|
|
|
|
input_tensor: torch.Tensor,
|
|
|
|
|
|
|
|
weight: torch.Tensor,
|
|
|
|
|
|
|
|
bias: torch.Tensor,
|
|
|
|
|
|
|
|
stride: List[int],
|
|
|
|
|
|
|
|
padding: List[int],
|
|
|
|
|
|
|
|
dilation: List[int],
|
|
|
|
|
|
|
|
is_transposed: bool,
|
|
|
|
|
|
|
|
output_padding: List[int],
|
|
|
|
|
|
|
|
groups: int,
|
|
|
|
|
|
|
|
*extra_args
|
|
|
|
|
|
|
|
):
|
|
|
|
|
|
|
|
out = meta_conv(input_tensor, weight, bias, stride, padding, dilation, is_transposed, output_padding, groups)
|
|
|
|
|
|
|
|
return out
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_meta(aten.convolution_backward.default)
|
|
|
|
@register_meta(aten.convolution_backward.default)
|
|
|
|
def meta_conv_backward(grad_output: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, bias_sizes, stride,
|
|
|
|
def meta_conv_backward(grad_output: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, bias_sizes, stride,
|
|
|
|
padding, dilation, transposed, output_padding, groups, output_mask):
|
|
|
|
padding, dilation, transposed, output_padding, groups, output_mask):
|
|
|
|