Making large AI models cheaper, faster and more accessible
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 

87 lines
2.1 KiB

from contextlib import contextmanager
from typing import Callable, Dict, Tuple
import torch
# Public names of this module. The underscore-prefixed tables are listed
# explicitly so that a star-import still re-exports them.
__all__ = [
    "_LEGACY_TENSOR_CONSTRUCTOR", "_NO_META_FACTORY",
    "_NORMAL_FACTORY", "ConstructorManager",
]
# reference: https://pytorch.org/cppdocs/notes/tensor_creation.html
_NORMAL_FACTORY = [
"arange",
"full",
"empty",
"linspace",
"logspace",
"ones",
"rand",
"randn",
"randint",
"randperm",
"zeros",
"tensor",
]
# factory function that does not support meta tensor backend
_NO_META_FACTORY = [
"eye",
]
_LEGACY_TENSOR_CONSTRUCTOR = {
"FloatTensor": torch.float,
"DoubleTensor": torch.double,
"HalfTensor": torch.half,
"BFloat16Tensor": torch.bfloat16,
"ByteTensor": torch.uint8,
"CharTensor": torch.int8,
"ShortTensor": torch.short,
"IntTensor": torch.int,
"LongTensor": torch.long,
"BoolTensor": torch.bool,
}
class ConstructorManager:
    """Registry that monkey-patches attributes of the ``torch`` module.

    All state is stored on the class itself (not on instances) and every
    method is a ``staticmethod``, so there is one shared set of patches.

    ``overwrites`` maps an attribute name on ``torch`` to a ``(new, old)``
    pair. :meth:`redo` installs the ``new`` callables and :meth:`undo` puts
    the ``old`` ones back. ``changed`` records whether the ``new`` callables
    are currently installed.
    """

    # function name: (new, old)
    overwrites: Dict[str, Tuple[Callable, Callable]] = {}
    # True while the ``new`` callables are installed on ``torch``.
    changed: bool = False

    @staticmethod
    def apply(overwrites: Dict[str, Tuple[Callable, Callable]]):
        """Replace the registered overwrites with *overwrites* and install them.

        Args:
            overwrites: mapping of ``torch`` attribute name to ``(new, old)``.
                The mapping is copied; later mutation by the caller has no effect.

        Raises:
            AssertionError: if a previous set of overwrites is still installed.
                The check happens before any state is modified, so the live
                patches remain registered and can still be undone.
        """
        # Validate before clearing: dropping the registry while patches are
        # live would lose the ``old`` callables needed to restore ``torch``.
        assert not ConstructorManager.changed, "Constructor already changed"
        ConstructorManager.overwrites.clear()
        ConstructorManager.overwrites.update(overwrites)
        ConstructorManager.redo()

    @staticmethod
    def undo():
        """Put the ``old`` callables back on ``torch``.

        Raises:
            AssertionError: if no overwrites are currently installed.
        """
        assert ConstructorManager.changed, "No constructor change to undo"
        for name, (_, old) in ConstructorManager.overwrites.items():
            setattr(torch, name, old)
        ConstructorManager.changed = False

    @staticmethod
    def redo():
        """Install the ``new`` callables on ``torch``.

        Raises:
            AssertionError: if the overwrites are already installed.
        """
        assert not ConstructorManager.changed, "Constructor already changed"
        for name, (new, _) in ConstructorManager.overwrites.items():
            setattr(torch, name, new)
        ConstructorManager.changed = True

    @staticmethod
    @contextmanager
    def disable():
        """Temporarily restore the original ``torch`` callables.

        Inside the ``with`` block the ``old`` callables are installed. On exit
        the ``new`` ones are reinstalled, including when the block raises.
        If nothing was installed on entry, this is a no-op.
        """
        enabled = ConstructorManager.changed
        if enabled:
            ConstructorManager.undo()
        try:
            yield
        finally:
            # Reinstall on every exit path so an exception raised inside the
            # block does not silently leave ``torch`` unpatched.
            if enabled:
                ConstructorManager.redo()

    @staticmethod
    def clear():
        """Uninstall any active overwrites and empty the registry."""
        if ConstructorManager.changed:
            ConstructorManager.undo()
        ConstructorManager.overwrites.clear()