You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
ColossalAI/colossalai/_analyzer/_subclasses/_monkey_patch.py

94 lines
2.1 KiB

import torch
import torch.distributed as dist
from packaging import version
__all__ = [
"_TorchFactoryMethod",
"_TorchOverrideableFactoryMethod",
"_TorchNonOverrideableFactoryMethod",
"_TensorPropertyMethod",
"_DistCommMethod",
"_AliasATen",
"_InplaceATen",
"_MaybeInplaceATen",
]
_TorchOverrideableFactoryMethod = [
"empty",
"eye",
"full",
"ones",
"rand",
"randn",
"zeros",
]
_TorchNonOverrideableFactoryMethod = [
"arange",
"finfo",
"linspace",
"logspace",
"randint",
"randperm",
"tensor",
]
_TorchFactoryMethod = _TorchOverrideableFactoryMethod + _TorchNonOverrideableFactoryMethod
_TensorPropertyMethod = ["dtype", "shape", "device", "requires_grad", "grad", "grad_fn", "data"]
_DistCommMethod = [
"all_gather",
"all_reduce",
"all_to_all",
"broadcast",
"gather",
"reduce",
"reduce_scatter",
"scatter",
]
if version.parse(torch.__version__) >= version.parse('1.12.0'):
aten = torch.ops.aten
# TODO: dive deep here
# refer to https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/TensorShape.cpp
_AliasATen = [
aten.detach.default,
aten.detach_.default,
aten.t.default,
aten.transpose.int,
aten.view.default,
aten._unsafe_view.default,
aten._reshape_alias.default,
]
_InplaceATen = [
aten.add_.Tensor,
aten.add_.Scalar,
aten.sub_.Tensor,
aten.sub_.Scalar,
aten.mul_.Tensor,
aten.mul_.Scalar,
aten.div_.Tensor,
aten.div_.Scalar,
aten.pow_.Tensor,
aten.pow_.Scalar,
]
# use `MaybeInplace` because they call ``as_strided()`` or ``slice()``
_MaybeInplaceATen = [
aten.diagonal.default,
aten.expand.default,
aten.select.int,
aten.slice.Tensor,
aten.split.Tensor,
aten.squeeze.default,
aten.permute.default,
aten.unsqueeze.default,
aten.as_strided.default,
]
else:
_AliasATen = []
_InplaceATen = []
_MaybeInplaceATen = []