[fx] fix meta tensor registration (#3589)

* [meta] fix torch 1.13.1

* [meta] fix torch 2.0.0

* [meta] fix torch 1.13.0

* [meta] polish code
Hongxin Liu 2023-04-18 16:20:36 +08:00 committed by GitHub
parent 36a519b49f
commit dac127d0ee
1 changed file with 102 additions and 95 deletions


@@ -274,11 +274,15 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
         aten.prelu.default,
         aten.hardswish.default,
         aten.hardtanh.default,
-        aten.prelu_backward.default,
         aten.hardswish_backward.default,
         aten.hardtanh_backward.default,
     ]
 
+    if version.parse(torch.__version__) < version.parse('2.0.0'):
+        _unregistered_ewise += [
+            aten.prelu_backward.default,
+        ]
+
     @register_meta(_unregistered_ewise)
     def meta_unregistered_ewise(input: torch.Tensor, *args):
         return new_like(input)
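Note (not part of the diff): the new gate is needed because torch 2.0.0 dropped aten.prelu_backward from the dispatcher, so merely referencing that overload fails there with an AttributeError. A minimal probe of this, assuming only public torch.ops behavior:

    import torch
    from packaging import version

    aten = torch.ops.aten

    # aten.prelu_backward resolves only before torch 2.0.0; on 2.0+ the
    # attribute lookup itself fails, which is why the op list is built
    # conditionally instead of unconditionally naming the overload.
    if version.parse(torch.__version__) < version.parse('2.0.0'):
        prelu_bwd = aten.prelu_backward.default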
@@ -331,11 +335,6 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
     def meta_im2col(input: torch.Tensor, kernel_size, dilation, padding, stride):
         return new_like(input)
 
-    # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml
-    @register_meta(aten.eye.m_out)
-    def meta_eye(n: int, m: int, out: torch.Tensor):
-        return out
-
     # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml
     @register_meta(aten.roll.default)
     def meta_roll(input: torch.Tensor, shifts, dims):
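Note (not part of the diff): meta_roll can simply return new_like(input) because roll only permutes elements; it preserves shape, dtype, and layout. A quick CPU sanity check of that invariant:

    import torch

    x = torch.arange(6).reshape(2, 3)
    # roll rotates values along a dim; the output shape equals the input shape
    assert torch.roll(x, shifts=1, dims=1).shape == x.shape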
@@ -352,6 +351,32 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
         result_type = torch.result_type(self, other)
         return new_like(condition + self + other, dtype=result_type)
 
+    # ============================== Embedding =========================================
+    # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Embedding.cpp
+    @register_meta(aten.embedding_dense_backward.default)
+    def meta_embedding_dense_backward(grad_output: torch.Tensor, indices: torch.Tensor, num_weights, padding_idx,
+                                      scale_grad_by_freq):
+        return new((num_weights, grad_output.size(-1)), dtype=grad_output.dtype, layout=grad_output.layout)
+
+    # ============================== Dropout ===========================================
+    # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Dropout.cpp
+    @register_meta(aten.native_dropout.default)
+    def meta_native_dropout_default(input: torch.Tensor, p: float, train: bool = False):
+        # notice that mask is bool
+        return new_like(input), new_like(input, dtype=torch.bool)    # (output, mask)
+
+    # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Dropout.cpp
+    @register_meta(aten.native_dropout_backward.default)
+    def meta_native_dropout_backward_default(grad: torch.Tensor, mask: torch.Tensor, scale: float):
+        return new_like(grad)    # (grad_in)
+
+    if version.parse(torch.__version__) < version.parse('1.13.0'):
+        # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml
+        @register_meta(aten.eye.m_out)
+        def meta_eye(n: int, m: int, out: torch.Tensor):
+            return out
+
     @register_meta(aten.index.Tensor)
     def meta_index_Tensor(self, indices):
         assert indices, "at least one index must be provided"
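Note (not part of the diff): meta_index_Tensor has to reproduce the shape rules of advanced indexing, where an integer-tensor index replaces the indexed dimension with the index tensor's shape. The rule it must match, demonstrated on CPU:

    import torch

    x = torch.empty(4, 5)
    idx = torch.zeros(2, 3, dtype=torch.long)
    # indexing dim 0 of a (4, 5) tensor with a (2, 3) LongTensor yields (2, 3, 5)
    assert x[idx].shape == (2, 3, 5)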
@@ -376,7 +401,8 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
             else:
                 result.append(index)
         indices = result
-        assert len(indices) <= self.ndim, f"too many indices for tensor of dimension {self.ndim} (got {len(indices)})"
+        assert len(
+            indices) <= self.ndim, f"too many indices for tensor of dimension {self.ndim} (got {len(indices)})"
 
         # expand_outplace
         import torch._refs as refs
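Note (not part of the diff): the expand_outplace step broadcasts all advanced indices to one common shape before the result shape is computed. A hedged equivalent of that broadcast using only public API:

    import torch

    i = torch.zeros(2, 1, dtype=torch.long)
    j = torch.zeros(1, 3, dtype=torch.long)
    # both index tensors are expanded to the broadcast shape (2, 3), no copies made
    ib, jb = torch.broadcast_tensors(i, j)
    assert ib.shape == jb.shape == (2, 3)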
@@ -440,22 +466,3 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
             else:
                 replacement_shape = list(index.shape)
         return self.new_empty(before_shape + replacement_shape + after_shape)
-
-    # ============================== Embedding =========================================
-    # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Embedding.cpp
-    @register_meta(aten.embedding_dense_backward.default)
-    def meta_embedding_dense_backward(grad_output: torch.Tensor, indices: torch.Tensor, num_weights, padding_idx,
-                                      scale_grad_by_freq):
-        return new((num_weights, grad_output.size(-1)), dtype=grad_output.dtype, layout=grad_output.layout)
-
-    # ============================== Dropout ===========================================
-    # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Dropout.cpp
-    @register_meta(aten.native_dropout.default)
-    def meta_native_dropout_default(input: torch.Tensor, p: float, train: bool = False):
-        # notice that mask is bool
-        return new_like(input), new_like(input, dtype=torch.bool)    # (output, mask)
-
-    # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Dropout.cpp
-    @register_meta(aten.native_dropout_backward.default)
-    def meta_native_dropout_backward_default(grad: torch.Tensor, mask: torch.Tensor, scale: float):
-        return new_like(grad)    # (grad_in)
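Note (not part of the diff): the point of these registrations is that the ops can then run on device='meta' tensors, producing correct output shapes and dtypes without allocating real memory. A hedged usage sketch, assuming a torch build where native_dropout has a meta kernel (which is what this file supplies on 1.12):

    import torch

    inp = torch.empty(8, 16, device='meta')
    # the meta kernel returns (output, mask); mask is bool and nothing is materialized
    out, mask = torch.ops.aten.native_dropout(inp, 0.5, True)
    assert out.shape == (8, 16) and mask.dtype == torch.bool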