import torch
from packaging import version

# Names exported by ``from <this module> import *``. Every name listed here has a
# leading underscore, which would normally exclude it from a star import; listing
# it in ``__all__`` exports it explicitly. Keep this a literal list of string
# literals so static analyzers can read it.
__all__ = [
    "_TorchFactoryMethod",
    "_TorchOverrideableFactoryMethod",
    "_TorchNonOverrideableFactoryMethod",
    "_TensorPropertyMethod",
    "_DistCommMethod",
    "_AliasATen",
    "_InplaceATen",
    "_MaybeInplaceATen",
]

# Names of ``torch`` factory functions, split into an "overrideable" and a
# "non-overrideable" group (the distinction is applied by code outside this
# module). Each group is kept in alphabetical order.
_TorchOverrideableFactoryMethod = "empty eye full ones rand randn zeros".split()

# NOTE(review): ``finfo`` is not a tensor factory (it describes a floating-point
# dtype); presumably it is listed so it gets the same treatment as the
# factories — confirm.
_TorchNonOverrideableFactoryMethod = (
    "arange finfo linspace logspace randint randperm tensor"
).split()

# Every factory name: the overrideable group first, then the non-overrideable one.
_TorchFactoryMethod = [*_TorchOverrideableFactoryMethod, *_TorchNonOverrideableFactoryMethod]

# Names of ``torch.Tensor`` attributes exposed as properties (e.g. ``dtype``,
# ``shape``, ``grad``), in a fixed order.
_TensorPropertyMethod = "dtype shape device requires_grad grad grad_fn data".split()

# Collective-communication function names, matching functions of the same name
# in ``torch.distributed``; kept in alphabetical order.
_DistCommMethod = (
    "all_gather all_reduce all_to_all broadcast "
    "gather reduce reduce_scatter scatter"
).split()

if version.parse(torch.__version__) < version.parse("1.12.0"):
    # Before PyTorch 1.12 no ATen overloads are collected: all three lists stay
    # empty and ``torch.ops.aten`` is never accessed. NOTE(review): presumably
    # because the ``aten.<op>.<overload>`` objects used below are not available
    # on older releases — confirm.
    _AliasATen = []
    _InplaceATen = []
    _MaybeInplaceATen = []
else:
    aten = torch.ops.aten

    # Overloads whose result is treated as an alias of (a view onto) the input.
    # TODO: this list is likely incomplete; audit it against the view ops in
    # https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/TensorShape.cpp
    _AliasATen = [
        aten.detach.default,
        aten.detach_.default,
        aten.t.default,
        aten.transpose.int,
        aten.view.default,
        aten._unsafe_view.default,
        aten._reshape_alias.default,
    ]

    # In-place arithmetic overloads. Ops appear in the order add_, sub_, mul_,
    # div_, pow_, and for each op the ``Tensor`` overload precedes ``Scalar``.
    _InplaceATen = [
        getattr(getattr(aten, op_name), overload_name)
        for op_name in ("add_", "sub_", "mul_", "div_", "pow_")
        for overload_name in ("Tensor", "Scalar")
    ]

    # Named `MaybeInplace` because these ops call ``as_strided()`` or
    # ``slice()`` internally, so their output may share storage with the input.
    _MaybeInplaceATen = [
        aten.diagonal.default,
        aten.expand.default,
        aten.select.int,
        aten.slice.Tensor,
        aten.split.Tensor,
        aten.squeeze.default,
        aten.permute.default,
        aten.unsqueeze.default,
        aten.as_strided.default,
    ]