# ColossalAI/colossalai/fx/_meta_regist_13.py
#
# 58 lines
# 1.9 KiB
# Python

import torch
from torch._meta_registrations import register_meta
from torch._prims_common import check
# Shorthand for the aten operator namespace, used when naming the overloads to register.
aten = torch.ops.aten
# This module targets torch 1.13.1, whose meta registrations do not yet provide
# meta implementations for some aten ops, so the missing ones are registered here.
# The implementations are taken from upstream PyTorch:
# https://github.com/pytorch/pytorch/blob/master/torch/_meta_registrations.py
# NOTE(review): that link points at `master`, which keeps changing; compare against
# the exact upstream revision when re-syncing these functions.
@register_meta([aten.convolution_backward.default])
def meta_convolution_backward(
    grad_output_,
    input_,
    weight_,
    bias_sizes_opt,
    stride,
    padding,
    dilation,
    transposed,
    output_padding,
    groups,
    output_mask,
):
    """Meta (allocation-only) kernel for ``aten.convolution_backward``.

    Each requested gradient is an uninitialised tensor created via
    ``grad_output_.new_empty``. No values are computed; only shapes are
    produced, and dtype/device are inherited from ``grad_output_``.

    Args:
        grad_output_: Gradient w.r.t. the convolution output. It is only used
            as the factory for the new tensors.
        input_: Forward input; its size becomes the size of the input gradient.
        weight_: Forward weight; its size becomes the size of the weight gradient.
        bias_sizes_opt: Size passed to ``new_empty`` for the bias gradient.
        stride, padding, dilation, transposed, output_padding, groups:
            Accepted to match the operator schema; not read here.
        output_mask: Indexable of three flags selecting which of
            (grad_input, grad_weight, grad_bias) to allocate.

    Returns:
        A 3-tuple ``(grad_input, grad_weight, grad_bias)``. An entry is ``None``
        when its ``output_mask`` flag is falsy.
    """
    # High-level logic taken from slow_conv3d_backward_cpu, which should be
    # representative of all convolution_backward implementations.
    grad_input = grad_output_.new_empty(input_.size()) if output_mask[0] else None
    grad_weight = grad_output_.new_empty(weight_.size()) if output_mask[1] else None
    grad_bias = grad_output_.new_empty(bias_sizes_opt) if output_mask[2] else None
    return grad_input, grad_weight, grad_bias
@register_meta(aten._adaptive_avg_pool2d_backward.default)
def meta__adaptive_avg_pool2d_backward(grad_out, self):
    """Meta kernel for ``aten._adaptive_avg_pool2d_backward``.

    Validates the arguments, then returns an uninitialised tensor shaped like
    ``self`` (created with ``self.new_empty``).

    Validation order:
      1. every dimension of ``grad_out`` after the first has size > 0;
      2. ``grad_out`` has 3 or 4 dimensions (the error message reports
         ``self.shape``);
      3. ``self`` and ``grad_out`` have the same dtype.

    Args:
        grad_out: Gradient flowing back from the pooled output.
        self: The forward input of the pooling op.

    Returns:
        ``self.new_empty(self.shape)``.

    Raises:
        Whatever ``check`` raises when a condition above is false.
    """
    rank = grad_out.ndim

    # Dimension 0 is skipped because it may be the batch dimension.
    for dim in range(1, rank):
        check(
            grad_out.size(dim) > 0,
            # Adjacent literals rather than a backslash continuation, so source
            # indentation never leaks into the error text.
            lambda: (
                "adaptive_avg_pool2d_backward(): Expected grad_output to have non-zero "
                f"size for non-batch dimensions, {grad_out.shape} "
                f"with dimension {dim} being empty"
            ),
        )

    check(
        rank in (3, 4),
        lambda: f"adaptive_avg_pool2d_backward(): Expected 3D or 4D tensor, but got {self.shape}",
    )

    check(
        grad_out.dtype == self.dtype,
        lambda: f"expected dtype {self.dtype} for `grad_output` but got dtype {grad_out.dtype}",
    )

    return self.new_empty(self.shape)