import torch from torch._meta_registrations import register_meta from torch._prims_common import check aten = torch.ops.aten # since we fix the torch version to 1.13.1, we have to add unimplemented meta ops # all these functions are from here https://github.com/pytorch/pytorch/blob/master/torch/_meta_registrations.py @register_meta([aten.convolution_backward.default]) def meta_convolution_backward( grad_output_, input_, weight_, bias_sizes_opt, stride, padding, dilation, transposed, output_padding, groups, output_mask, ): # High level logic taken from slow_conv3d_backward_cpu which should # be representative of all convolution_backward impls backend_grad_input = None backend_grad_weight = None backend_grad_bias = None if output_mask[0]: backend_grad_input = grad_output_.new_empty(input_.size()) if output_mask[1]: backend_grad_weight = grad_output_.new_empty(weight_.size()) if output_mask[2]: backend_grad_bias = grad_output_.new_empty(bias_sizes_opt) return (backend_grad_input, backend_grad_weight, backend_grad_bias) @register_meta(aten._adaptive_avg_pool2d_backward.default) def meta__adaptive_avg_pool2d_backward(grad_out, self): ndim = grad_out.ndim for i in range(1, ndim): check( grad_out.size(i) > 0, lambda: f"adaptive_avg_pool2d_backward(): Expected grad_output to have non-zero \ size for non-batch dimensions, {grad_out.shape} with dimension {i} being empty", ) check( ndim == 3 or ndim == 4, lambda: f"adaptive_avg_pool2d_backward(): Expected 3D or 4D tensor, but got {self.shape}", ) check( self.dtype == grad_out.dtype, lambda: f"expected dtype {self.dtype} for `grad_output` but got dtype {grad_out.dtype}", ) return self.new_empty(self.shape)