mirror of https://github.com/hpcaitech/ColossalAI
register aten._convolution.default (#2137)
parent
ee287620f0
commit
a128eec9d5
|
@ -163,6 +163,23 @@ def meta_conv(
|
|||
return out
|
||||
|
||||
|
||||
@register_meta(aten._convolution.default)
|
||||
def meta_conv_1(
|
||||
input_tensor: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
bias: torch.Tensor,
|
||||
stride: List[int],
|
||||
padding: List[int],
|
||||
dilation: List[int],
|
||||
is_transposed: bool,
|
||||
output_padding: List[int],
|
||||
groups: int,
|
||||
*extra_args
|
||||
):
|
||||
out = meta_conv(input_tensor, weight, bias, stride, padding, dilation, is_transposed, output_padding, groups)
|
||||
return out
|
||||
|
||||
|
||||
@register_meta(aten.convolution_backward.default)
|
||||
def meta_conv_backward(grad_output: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, bias_sizes, stride,
|
||||
padding, dilation, transposed, output_padding, groups, output_mask):
|
||||
|
|
Loading…
Reference in New Issue