mirror of https://github.com/hpcaitech/ColossalAI
[fx] fix meta tensor registration (#3589)
* [meta] fix torch 1.13.1
* [meta] fix torch 2.0.0
* [meta] fix torch 1.13.0
* [meta] polish code

pull/3592/head
parent 36a519b49f
commit dac127d0ee
@@ -274,11 +274,15 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
        aten.prelu.default,
        aten.hardswish.default,
        aten.hardtanh.default,
        aten.prelu_backward.default,
        aten.hardswish_backward.default,
        aten.hardtanh_backward.default,
    ]

    if version.parse(torch.__version__) < version.parse('2.0.0'):
        _unregistered_ewise += [
            aten.prelu_backward.default,
        ]

    @register_meta(_unregistered_ewise)
    def meta_unregistered_ewise(input: torch.Tensor, *args):
        return new_like(input)
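A rough illustration of what these elementwise registrations provide (this example is not from the patch and assumes a PyTorch build where hardswish already has meta support): an elementwise op on a `meta` tensor allocates no storage and only propagates shape and dtype, which is the same contract `meta_unregistered_ewise` supplies for ops without a native meta kernel.

```python
import torch

# Minimal sketch (assumes torch >= 1.12 with built-in meta support for hardswish).
# Running an elementwise op on a 'meta' tensor yields another 'meta' tensor with
# the same shape/dtype and no storage -- mirroring what new_like(input) returns
# for ops that lack a native meta kernel.
x = torch.empty(4, 8, device='meta')
y = torch.nn.functional.hardswish(x)
print(y.shape, y.device)    # torch.Size([4, 8]) meta
```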
@@ -331,11 +335,6 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
    def meta_im2col(input: torch.Tensor, kernel_size, dilation, padding, stride):
        return new_like(input)

    # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml
    @register_meta(aten.eye.m_out)
    def meta_eye(n: int, m: int, out: torch.Tensor):
        return out

    # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml
    @register_meta(aten.roll.default)
    def meta_roll(input: torch.Tensor, shifts, dims):
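For reference, both ops touched here have simple output rules, which is why thin meta kernels suffice; a quick eager-mode check (not part of the patch):

```python
import torch

# torch.roll preserves the input shape, so a new_like-style meta kernel is enough.
x = torch.arange(12).reshape(3, 4)
print(torch.roll(x, shifts=1, dims=1).shape)    # torch.Size([3, 4])

# aten.eye.m_out writes into `out`, so the meta kernel can simply return `out`.
out = torch.empty(3, 4)
print(torch.eye(3, 4, out=out).shape)    # torch.Size([3, 4])
```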
@@ -352,97 +351,9 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
        result_type = torch.result_type(self, other)
        return new_like(condition + self + other, dtype=result_type)

    @register_meta(aten.index.Tensor)
    def meta_index_Tensor(self, indices):
        assert indices, "at least one index must be provided"
        # aten::index is the internal advanced indexing implementation
        # checkIndexTensorTypes and expandTensors
        result: List[Optional[torch.Tensor]] = []
        for i, index in enumerate(indices):
            if index is not None:
                assert index.dtype in [torch.long, torch.int8, torch.bool],\
                    "tensors used as indices must be long, byte or bool tensors"
                if index.dtype in [torch.int8, torch.bool]:
                    nonzero = index.nonzero()
                    k = len(result)
                    assert k + index.ndim <= self.ndim, f"too many indices for tensor of dimension {self.ndim}"
                    for j in range(index.ndim):
                        assert index.shape[j] == self.shape[
                            k +
                            j], f"The shape of the mask {index.shape} at index {i} does not match the shape of the indexed tensor {self.shape} at index {k + j}"
                        result.append(nonzero.select(1, j))
                else:
                    result.append(index)
            else:
                result.append(index)
        indices = result
        assert len(indices) <= self.ndim, f"too many indices for tensor of dimension {self.ndim} (got {len(indices)})"
        # expand_outplace
        import torch._refs as refs

        indices = list(refs._maybe_broadcast(*indices))
        # add missing null tensors
        while len(indices) < self.ndim:
            indices.append(None)

        # hasContiguousSubspace
        # true if all non-null tensors are adjacent
        # See:
        # https://numpy.org/doc/stable/user/basics.indexing.html#combining-advanced-and-basic-indexing
        # https://stackoverflow.com/questions/53841497/why-does-numpy-mixed-basic-advanced-indexing-depend-on-slice-adjacency
        state = 0
        has_contiguous_subspace = False
        for index in indices:
            if state == 0:
                if index is not None:
                    state = 1
            elif state == 1:
                if index is None:
                    state = 2
            else:
                if index is not None:
                    break
        else:
            has_contiguous_subspace = True

        # transposeToFront
        # This is the logic that causes the newly inserted dimensions to show up
        # at the beginning of the tensor, if they're not contiguous
        if not has_contiguous_subspace:
            dims = []
            transposed_indices = []
            for i, index in enumerate(indices):
                if index is not None:
                    dims.append(i)
                    transposed_indices.append(index)
            for i, index in enumerate(indices):
                if index is None:
                    dims.append(i)
                    transposed_indices.append(index)
            self = self.permute(dims)
            indices = transposed_indices

        # AdvancedIndex::AdvancedIndex
        # Now we can assume the indices have contiguous subspace
        # This is simplified from AdvancedIndex which goes to more effort
        # to put the input and indices in a form so that TensorIterator can
        # take them. If we write a ref for this, probably that logic should
        # get implemented
        before_shape: List[int] = []
        after_shape: List[int] = []
        replacement_shape: List[int] = []
        for dim, index in enumerate(indices):
            if index is None:
                if replacement_shape:
                    after_shape.append(self.shape[dim])
                else:
                    before_shape.append(self.shape[dim])
            else:
                replacement_shape = list(index.shape)
        return self.new_empty(before_shape + replacement_shape + after_shape)

    # ============================== Embedding =========================================
    # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Embedding.cpp

    @register_meta(aten.embedding_dense_backward.default)
    def meta_embedding_dense_backward(grad_output: torch.Tensor, indices: torch.Tensor, num_weights, padding_idx,
                                      scale_grad_by_freq):
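The shape logic in `meta_index_Tensor` mirrors NumPy/PyTorch advanced indexing: when the advanced indices are adjacent, the broadcast index shape replaces those dimensions in place; when they are separated by basic indices, it moves to the front (the `transposeToFront` branch). A small eager-mode check, not taken from the patch, shows the shapes it is expected to reproduce:

```python
import torch

# Eager-mode check of the shapes meta_index_Tensor should compute on 'meta' tensors.
x = torch.randn(2, 3, 4, 5)
idx = torch.tensor([0, 1, 1])

# Advanced indices on adjacent dims (1 and 2): replacement shape stays in place.
print(x[:, idx, idx].shape)    # torch.Size([2, 3, 5])

# Advanced indices on non-adjacent dims (0 and 2): replacement shape moves to the front.
print(x[idx, :, idx].shape)    # torch.Size([3, 3, 5])
```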
@@ -459,3 +370,99 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
    @register_meta(aten.native_dropout_backward.default)
    def meta_native_dropout_backward_default(grad: torch.Tensor, mask: torch.Tensor, scale: float):
        return new_like(grad)    # (grad_in)

    if version.parse(torch.__version__) < version.parse('1.13.0'):
        # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml
        @register_meta(aten.eye.m_out)
        def meta_eye(n: int, m: int, out: torch.Tensor):
            return out

        @register_meta(aten.index.Tensor)
        def meta_index_Tensor(self, indices):
            assert indices, "at least one index must be provided"
            # aten::index is the internal advanced indexing implementation
            # checkIndexTensorTypes and expandTensors
            result: List[Optional[torch.Tensor]] = []
            for i, index in enumerate(indices):
                if index is not None:
                    assert index.dtype in [torch.long, torch.int8, torch.bool],\
                        "tensors used as indices must be long, byte or bool tensors"
                    if index.dtype in [torch.int8, torch.bool]:
                        nonzero = index.nonzero()
                        k = len(result)
                        assert k + index.ndim <= self.ndim, f"too many indices for tensor of dimension {self.ndim}"
                        for j in range(index.ndim):
                            assert index.shape[j] == self.shape[
                                k +
                                j], f"The shape of the mask {index.shape} at index {i} does not match the shape of the indexed tensor {self.shape} at index {k + j}"
                            result.append(nonzero.select(1, j))
                    else:
                        result.append(index)
                else:
                    result.append(index)
            indices = result
            assert len(
                indices) <= self.ndim, f"too many indices for tensor of dimension {self.ndim} (got {len(indices)})"
            # expand_outplace
            import torch._refs as refs

            indices = list(refs._maybe_broadcast(*indices))
            # add missing null tensors
            while len(indices) < self.ndim:
                indices.append(None)

            # hasContiguousSubspace
            # true if all non-null tensors are adjacent
            # See:
            # https://numpy.org/doc/stable/user/basics.indexing.html#combining-advanced-and-basic-indexing
            # https://stackoverflow.com/questions/53841497/why-does-numpy-mixed-basic-advanced-indexing-depend-on-slice-adjacency
            state = 0
            has_contiguous_subspace = False
            for index in indices:
                if state == 0:
                    if index is not None:
                        state = 1
                elif state == 1:
                    if index is None:
                        state = 2
                else:
                    if index is not None:
                        break
            else:
                has_contiguous_subspace = True

            # transposeToFront
            # This is the logic that causes the newly inserted dimensions to show up
            # at the beginning of the tensor, if they're not contiguous
            if not has_contiguous_subspace:
                dims = []
                transposed_indices = []
                for i, index in enumerate(indices):
                    if index is not None:
                        dims.append(i)
                        transposed_indices.append(index)
                for i, index in enumerate(indices):
                    if index is None:
                        dims.append(i)
                        transposed_indices.append(index)
                self = self.permute(dims)
                indices = transposed_indices

            # AdvancedIndex::AdvancedIndex
            # Now we can assume the indices have contiguous subspace
            # This is simplified from AdvancedIndex which goes to more effort
            # to put the input and indices in a form so that TensorIterator can
            # take them. If we write a ref for this, probably that logic should
            # get implemented
            before_shape: List[int] = []
            after_shape: List[int] = []
            replacement_shape: List[int] = []
            for dim, index in enumerate(indices):
                if index is None:
                    if replacement_shape:
                        after_shape.append(self.shape[dim])
                    else:
                        before_shape.append(self.shape[dim])
                else:
                    replacement_shape = list(index.shape)
            return self.new_empty(before_shape + replacement_shape + after_shape)
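These fallback kernels are now registered only for torch < 1.13.0; newer releases ship their own meta implementations for `aten.eye.m_out` and `aten.index.Tensor`, which is presumably why this commit gates them by version. A hedged sketch of the behavior assumed on newer versions (not taken from the patch):

```python
import torch
from packaging import version

# Assumed behavior on newer releases: the built-in meta kernel for aten.index
# should already propagate shapes, so the fallback registration above is skipped.
if version.parse(torch.__version__) >= version.parse('1.13.0'):
    x = torch.empty(2, 3, 4, device='meta')
    idx = torch.empty(2, dtype=torch.long, device='meta')    # values irrelevant on 'meta'
    print(x[idx].shape)    # expected: torch.Size([2, 3, 4])
```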