mirror of https://github.com/hpcaitech/ColossalAI
[fx] fix meta tensor registration (#3589)
* [meta] fix torch 1.13.1 * [meta] fix torch 2.0.0 * [meta] fix torch 1.13.0 * [meta] polish codepull/3592/head
parent
36a519b49f
commit
dac127d0ee
|
@ -274,11 +274,15 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
|
|||
aten.prelu.default,
|
||||
aten.hardswish.default,
|
||||
aten.hardtanh.default,
|
||||
aten.prelu_backward.default,
|
||||
aten.hardswish_backward.default,
|
||||
aten.hardtanh_backward.default,
|
||||
]
|
||||
|
||||
if version.parse(torch.__version__) < version.parse('2.0.0'):
|
||||
_unregistered_ewise += [
|
||||
aten.prelu_backward.default,
|
||||
]
|
||||
|
||||
@register_meta(_unregistered_ewise)
|
||||
def meta_unregistered_ewise(input: torch.Tensor, *args):
|
||||
return new_like(input)
|
||||
|
@ -331,11 +335,6 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
|
|||
def meta_im2col(input: torch.Tensor, kernel_size, dilation, padding, stride):
|
||||
return new_like(input)
|
||||
|
||||
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml
|
||||
@register_meta(aten.eye.m_out)
|
||||
def meta_eye(n: int, m: int, out: torch.Tensor):
|
||||
return out
|
||||
|
||||
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml
|
||||
@register_meta(aten.roll.default)
|
||||
def meta_roll(input: torch.Tensor, shifts, dims):
|
||||
|
@ -352,6 +351,32 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
|
|||
result_type = torch.result_type(self, other)
|
||||
return new_like(condition + self + other, dtype=result_type)
|
||||
|
||||
# ============================== Embedding =========================================
|
||||
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Embedding.cpp
|
||||
|
||||
@register_meta(aten.embedding_dense_backward.default)
|
||||
def meta_embedding_dense_backward(grad_output: torch.Tensor, indices: torch.Tensor, num_weights, padding_idx,
|
||||
scale_grad_by_freq):
|
||||
return new((num_weights, grad_output.size(-1)), dtype=grad_output.dtype, layout=grad_output.layout)
|
||||
|
||||
# ============================== Dropout ===========================================
|
||||
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Dropout.cpp
|
||||
@register_meta(aten.native_dropout.default)
|
||||
def meta_native_dropout_default(input: torch.Tensor, p: float, train: bool = False):
|
||||
# notice that mask is bool
|
||||
return new_like(input), new_like(input, dtype=torch.bool) # (output, mask)
|
||||
|
||||
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Dropout.cpp
|
||||
@register_meta(aten.native_dropout_backward.default)
|
||||
def meta_native_dropout_backward_default(grad: torch.Tensor, mask: torch.Tensor, scale: float):
|
||||
return new_like(grad) # (grad_in)
|
||||
|
||||
if version.parse(torch.__version__) < version.parse('1.13.0'):
|
||||
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml
|
||||
@register_meta(aten.eye.m_out)
|
||||
def meta_eye(n: int, m: int, out: torch.Tensor):
|
||||
return out
|
||||
|
||||
@register_meta(aten.index.Tensor)
|
||||
def meta_index_Tensor(self, indices):
|
||||
assert indices, "at least one index must be provided"
|
||||
|
@ -376,7 +401,8 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
|
|||
else:
|
||||
result.append(index)
|
||||
indices = result
|
||||
assert len(indices) <= self.ndim, f"too many indices for tensor of dimension {self.ndim} (got {len(indices)})"
|
||||
assert len(
|
||||
indices) <= self.ndim, f"too many indices for tensor of dimension {self.ndim} (got {len(indices)})"
|
||||
# expand_outplace
|
||||
import torch._refs as refs
|
||||
|
||||
|
@ -440,22 +466,3 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
|
|||
else:
|
||||
replacement_shape = list(index.shape)
|
||||
return self.new_empty(before_shape + replacement_shape + after_shape)
|
||||
|
||||
# ============================== Embedding =========================================
|
||||
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Embedding.cpp
|
||||
@register_meta(aten.embedding_dense_backward.default)
|
||||
def meta_embedding_dense_backward(grad_output: torch.Tensor, indices: torch.Tensor, num_weights, padding_idx,
|
||||
scale_grad_by_freq):
|
||||
return new((num_weights, grad_output.size(-1)), dtype=grad_output.dtype, layout=grad_output.layout)
|
||||
|
||||
# ============================== Dropout ===========================================
|
||||
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Dropout.cpp
|
||||
@register_meta(aten.native_dropout.default)
|
||||
def meta_native_dropout_default(input: torch.Tensor, p: float, train: bool = False):
|
||||
# notice that mask is bool
|
||||
return new_like(input), new_like(input, dtype=torch.bool) # (output, mask)
|
||||
|
||||
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Dropout.cpp
|
||||
@register_meta(aten.native_dropout_backward.default)
|
||||
def meta_native_dropout_backward_default(grad: torch.Tensor, mask: torch.Tensor, scale: float):
|
||||
return new_like(grad) # (grad_in)
|
||||
|
|
Loading…
Reference in New Issue