mirror of https://github.com/hpcaitech/ColossalAI
register aten._convolution.default (#2137)
parent
ee287620f0
commit
a128eec9d5
|
@ -163,6 +163,23 @@ def meta_conv(
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
@register_meta(aten._convolution.default)
|
||||||
|
def meta_conv_1(
|
||||||
|
input_tensor: torch.Tensor,
|
||||||
|
weight: torch.Tensor,
|
||||||
|
bias: torch.Tensor,
|
||||||
|
stride: List[int],
|
||||||
|
padding: List[int],
|
||||||
|
dilation: List[int],
|
||||||
|
is_transposed: bool,
|
||||||
|
output_padding: List[int],
|
||||||
|
groups: int,
|
||||||
|
*extra_args
|
||||||
|
):
|
||||||
|
out = meta_conv(input_tensor, weight, bias, stride, padding, dilation, is_transposed, output_padding, groups)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
@register_meta(aten.convolution_backward.default)
|
@register_meta(aten.convolution_backward.default)
|
||||||
def meta_conv_backward(grad_output: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, bias_sizes, stride,
|
def meta_conv_backward(grad_output: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, bias_sizes, stride,
|
||||||
padding, dilation, transposed, output_padding, groups, output_mask):
|
padding, dilation, transposed, output_padding, groups, output_mask):
|
||||||
|
|
Loading…
Reference in New Issue