mirror of https://github.com/hpcaitech/ColossalAI
[fx] meta registration compatibility (#3253)
* [fx] meta registration compatibility * fix errorpull/3257/head
parent
73d3e4d309
commit
02b058032d
|
@ -2,11 +2,21 @@ from typing import Callable
|
|||
|
||||
import torch
|
||||
|
||||
try:
|
||||
from . import _meta_registrations
|
||||
META_COMPATIBILITY = True
|
||||
except:
|
||||
TORCH_MAJOR = int(torch.__version__.split('.')[0])
|
||||
TORCH_MINOR = int(torch.__version__.split('.')[1])
|
||||
|
||||
if TORCH_MAJOR == 1 and TORCH_MINOR < 12:
|
||||
META_COMPATIBILITY = False
|
||||
elif TORCH_MAJOR == 1 and TORCH_MINOR == 12:
|
||||
from . import _meta_regist_12
|
||||
META_COMPATIBILITY = True
|
||||
elif TORCH_MAJOR == 1 and TORCH_MINOR == 13:
|
||||
from . import _meta_regist_13
|
||||
META_COMPATIBILITY = True
|
||||
elif TORCH_MAJOR == 2:
|
||||
from . import _meta_regist_13
|
||||
META_COMPATIBILITY = True
|
||||
raise UserWarning("Colossalai is not tested with torch2.0 yet!!!")
|
||||
|
||||
|
||||
def compatibility(is_backward_compatible: bool = False) -> Callable:
|
||||
|
|
|
@ -0,0 +1,57 @@
|
|||
import torch
|
||||
from torch._meta_registrations import register_meta
|
||||
from torch._prims_common import check
|
||||
|
||||
aten = torch.ops.aten
|
||||
|
||||
|
||||
# since we fix the torch version to 1.13.1, we have to add unimplemented meta ops
|
||||
# all these functions are from here https://github.com/pytorch/pytorch/blob/master/torch/_meta_registrations.py
|
||||
@register_meta([aten.convolution_backward.default])
|
||||
def meta_convolution_backward(
|
||||
grad_output_,
|
||||
input_,
|
||||
weight_,
|
||||
bias_sizes_opt,
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
transposed,
|
||||
output_padding,
|
||||
groups,
|
||||
output_mask,
|
||||
):
|
||||
# High level logic taken from slow_conv3d_backward_cpu which should
|
||||
# be representative of all convolution_backward impls
|
||||
backend_grad_input = None
|
||||
backend_grad_weight = None
|
||||
backend_grad_bias = None
|
||||
|
||||
if output_mask[0]:
|
||||
backend_grad_input = grad_output_.new_empty(input_.size())
|
||||
if output_mask[1]:
|
||||
backend_grad_weight = grad_output_.new_empty(weight_.size())
|
||||
if output_mask[2]:
|
||||
backend_grad_bias = grad_output_.new_empty(bias_sizes_opt)
|
||||
|
||||
return (backend_grad_input, backend_grad_weight, backend_grad_bias)
|
||||
|
||||
|
||||
@register_meta(aten._adaptive_avg_pool2d_backward.default)
|
||||
def meta__adaptive_avg_pool2d_backward(grad_out, self):
|
||||
ndim = grad_out.ndim
|
||||
for i in range(1, ndim):
|
||||
check(
|
||||
grad_out.size(i) > 0,
|
||||
lambda: f"adaptive_avg_pool2d_backward(): Expected grad_output to have non-zero \
|
||||
size for non-batch dimensions, {grad_out.shape} with dimension {i} being empty",
|
||||
)
|
||||
check(
|
||||
ndim == 3 or ndim == 4,
|
||||
lambda: f"adaptive_avg_pool2d_backward(): Expected 3D or 4D tensor, but got {self.shape}",
|
||||
)
|
||||
check(
|
||||
self.dtype == grad_out.dtype,
|
||||
lambda: f"expected dtype {self.dtype} for `grad_output` but got dtype {grad_out.dtype}",
|
||||
)
|
||||
return self.new_empty(self.shape)
|
Loading…
Reference in New Issue