mirror of https://github.com/hpcaitech/ColossalAI
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
88 lines
2.1 KiB
88 lines
2.1 KiB
1 year ago
|
from contextlib import contextmanager
|
||
|
from typing import Callable, Dict, Tuple
|
||
|
|
||
|
import torch
|
||
|
|
||
|
# Public API of this module. The underscore-prefixed constants are exported on
# purpose so that ``from ... import *`` still brings them into scope.
__all__ = [
    "_LEGACY_TENSOR_CONSTRUCTOR", "_NO_META_FACTORY", "_NORMAL_FACTORY", "ConstructorManager",
]
|
||
|
|
||
|
# reference: https://pytorch.org/cppdocs/notes/tensor_creation.html
|
||
|
_NORMAL_FACTORY = [
|
||
|
"arange",
|
||
|
"full",
|
||
|
"empty",
|
||
|
"linspace",
|
||
|
"logspace",
|
||
|
"ones",
|
||
|
"rand",
|
||
|
"randn",
|
||
|
"randint",
|
||
|
"randperm",
|
||
|
"zeros",
|
||
|
"tensor",
|
||
|
]
|
||
|
|
||
|
# factory function that does not support meta tensor backend
|
||
|
_NO_META_FACTORY = [
|
||
|
"eye",
|
||
|
]
|
||
|
|
||
|
_LEGACY_TENSOR_CONSTRUCTOR = {
|
||
|
"FloatTensor": torch.float,
|
||
|
"DoubleTensor": torch.double,
|
||
|
"HalfTensor": torch.half,
|
||
|
"BFloat16Tensor": torch.bfloat16,
|
||
|
"ByteTensor": torch.uint8,
|
||
|
"CharTensor": torch.int8,
|
||
|
"ShortTensor": torch.short,
|
||
|
"IntTensor": torch.int,
|
||
|
"LongTensor": torch.long,
|
||
|
"BoolTensor": torch.bool,
|
||
|
}
|
||
|
|
||
|
|
||
|
class ConstructorManager:
    """Class-level registry that swaps attributes of the ``torch`` module in and out.

    There are no instances; all state lives on the class. ``overwrites`` maps an
    attribute name on ``torch`` to a ``(new, old)`` pair of callables, and
    ``changed`` records whether the ``new`` callables are currently installed.
    """

    # function name: (new, old)
    overwrites: Dict[str, Tuple[Callable, Callable]] = {}
    # True while the ``new`` callables from ``overwrites`` are set on ``torch``.
    changed: bool = False

    @staticmethod
    def apply(overwrites: Dict[str, Tuple[Callable, Callable]]):
        """Replace the registered overwrites with *overwrites* and install them.

        Args:
            overwrites: maps a ``torch`` attribute name to ``(new, old)``. ``new``
                is installed immediately; ``old`` is what :meth:`undo` restores.

        Raises:
            AssertionError: if overwrites are already installed. The check runs
                before the registry is modified, so the currently installed
                overwrites can still be reverted with :meth:`undo` / :meth:`clear`.
        """
        # Validate first: clearing the registry while torch is patched would drop
        # the ``old`` callables and leave torch permanently patched.
        assert not ConstructorManager.changed, "Constructor already changed"
        ConstructorManager.overwrites.clear()
        ConstructorManager.overwrites.update(overwrites)
        ConstructorManager.redo()

    @staticmethod
    def undo():
        """Put every registered ``old`` callable back on ``torch``.

        Raises:
            AssertionError: if no overwrites are currently installed.
        """
        assert ConstructorManager.changed, "No constructor change to undo"
        for name, (_new, old) in ConstructorManager.overwrites.items():
            setattr(torch, name, old)
        ConstructorManager.changed = False

    @staticmethod
    def redo():
        """Install every registered ``new`` callable on ``torch``.

        Raises:
            AssertionError: if the overwrites are already installed.
        """
        assert not ConstructorManager.changed, "Constructor already changed"
        for name, (new, _old) in ConstructorManager.overwrites.items():
            setattr(torch, name, new)
        ConstructorManager.changed = True

    @staticmethod
    @contextmanager
    def disable():
        """Temporarily restore the original callables for the duration of a ``with`` block.

        If overwrites are installed, they are removed on entry and reinstalled on
        exit, including when the block raises. If none are installed, this is a
        no-op.
        """
        enabled = ConstructorManager.changed
        if enabled:
            ConstructorManager.undo()
        try:
            yield
        finally:
            # Reinstall even on exceptions so the patched state is not silently lost.
            if enabled:
                ConstructorManager.redo()

    @staticmethod
    def clear():
        """Revert any installed overwrites and empty the registry."""
        if ConstructorManager.changed:
            ConstructorManager.undo()
        ConstructorManager.overwrites.clear()
|