# Mirror of https://github.com/hpcaitech/ColossalAI
# 93 lines, 2.0 KiB, Python
import torch
|
|
from packaging import version
|
|
|
|
# Public export list for this module. Every exported name is
# underscore-prefixed, so a star-import only picks these constants up because
# they are listed here explicitly.
__all__ = [
    "_TorchFactoryMethod",
    "_TorchOverrideableFactoryMethod",
    "_TorchNonOverrideableFactoryMethod",
    "_TensorPropertyMethod",
    "_DistCommMethod",
    "_AliasATen",
    "_InplaceATen",
    "_MaybeInplaceATen",
]
|
|
|
|
# Names (as strings) that this module groups as "overrideable" factory methods.
# NOTE(review): what "overrideable" means is decided by the code that consumes
# this list, which is not visible here -- confirm before adding entries.
_TorchOverrideableFactoryMethod = [
    "empty",
    "eye",
    "full",
    "ones",
    "rand",
    "randn",
    "zeros",
]
|
|
|
|
# Names (as strings) that this module groups as "non-overrideable" factory
# methods.
# NOTE(review): the criterion separating this group from the "overrideable"
# one is not visible in this file -- confirm against the consumer of the list.
_TorchNonOverrideableFactoryMethod = [
    "arange",
    "finfo",
    "linspace",
    "logspace",
    "randint",
    "randperm",
    "tensor",
]
|
|
|
|
# Every factory-method name: the overrideable group followed by the
# non-overrideable group, collected into a new list.
_TorchFactoryMethod = [*_TorchOverrideableFactoryMethod, *_TorchNonOverrideableFactoryMethod]
|
|
|
|
# Attribute names (as strings) that this module groups as tensor properties.
_TensorPropertyMethod = [
    "dtype",
    "shape",
    "device",
    "requires_grad",
    "grad",
    "grad_fn",
    "data",
]
|
|
|
|
# Names (as strings) that this module groups as distributed communication
# methods.
_DistCommMethod = [
    "all_gather",
    "all_reduce",
    "all_to_all",
    "broadcast",
    "gather",
    "reduce",
    "reduce_scatter",
    "scatter",
]
|
|
|
|
# The ATen overload lists are only populated when the installed torch is at
# least 1.12.0; otherwise all three fall back to empty lists, so the names are
# defined on every torch version.
# NOTE(review): presumably older releases lack the overload handles referenced
# below -- confirm the reason for the 1.12.0 cut-off.
if version.parse(torch.__version__) >= version.parse("1.12.0"):
    aten = torch.ops.aten
    # TODO: dive deep here
    # refer to https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/TensorShape.cpp
    _AliasATen = [
        aten.detach.default,
        aten.detach_.default,
        aten.t.default,
        aten.transpose.int,
        aten.view.default,
        aten._unsafe_view.default,
        aten._reshape_alias.default,
    ]

    # Overloads of the underscore-suffixed arithmetic ops, in their Tensor and
    # Scalar variants.
    _InplaceATen = [
        aten.add_.Tensor,
        aten.add_.Scalar,
        aten.sub_.Tensor,
        aten.sub_.Scalar,
        aten.mul_.Tensor,
        aten.mul_.Scalar,
        aten.div_.Tensor,
        aten.div_.Scalar,
        aten.pow_.Tensor,
        aten.pow_.Scalar,
    ]

    # use `MaybeInplace` because they call ``as_strided()`` or ``slice()``
    _MaybeInplaceATen = [
        aten.diagonal.default,
        aten.expand.default,
        aten.select.int,
        aten.slice.Tensor,
        aten.split.Tensor,
        aten.squeeze.default,
        aten.permute.default,
        aten.unsqueeze.default,
        aten.as_strided.default,
    ]
else:
    # Keep the three names defined (as empty lists) on older torch.
    _AliasATen = []
    _InplaceATen = []
    _MaybeInplaceATen = []
|