"""Bookkeeping for temporarily replacing tensor factory functions on ``torch``.

``ConstructorManager`` records a ``{name: (new, old)}`` table and swaps the
corresponding attributes of the ``torch`` module between their replacement
(``new``) and original (``old``) values.
"""
from contextlib import contextmanager
from typing import Callable, Dict, Tuple

import torch

__all__ = [
    "_LEGACY_TENSOR_CONSTRUCTOR",
    "_NO_META_FACTORY",
    "_NORMAL_FACTORY",
    "ConstructorManager",
]

# Names of tensor factory functions exposed on the ``torch`` module.
# reference: https://pytorch.org/cppdocs/notes/tensor_creation.html
_NORMAL_FACTORY = [
    "arange",
    "full",
    "empty",
    "linspace",
    "logspace",
    "ones",
    "rand",
    "randn",
    "randint",
    "randperm",
    "zeros",
    "tensor",
]

# Factory functions that do not support the meta tensor backend.
_NO_META_FACTORY = [
    "eye",
]

# Legacy ``torch.*Tensor`` constructor names mapped to the dtype each one creates.
_LEGACY_TENSOR_CONSTRUCTOR = {
    "FloatTensor": torch.float,
    "DoubleTensor": torch.double,
    "HalfTensor": torch.half,
    "BFloat16Tensor": torch.bfloat16,
    "ByteTensor": torch.uint8,
    "CharTensor": torch.int8,
    "ShortTensor": torch.short,
    "IntTensor": torch.int,
    "LongTensor": torch.long,
    "BoolTensor": torch.bool,
}


class ConstructorManager:
    """Class-level (process-wide) registry of patched ``torch`` attributes.

    All state lives on the class itself; every method is a ``staticmethod``
    operating on that shared state.
    """

    # attribute name on ``torch`` -> (replacement, original)
    overwrites: Dict[str, Tuple[Callable, Callable]] = {}
    # True while the replacements are installed on ``torch``.
    changed: bool = False

    @staticmethod
    def apply(overwrites: Dict[str, Tuple[Callable, Callable]]):
        """Replace the tracked table with *overwrites* and install the replacements.

        Args:
            overwrites: Mapping of ``torch`` attribute name to a
                ``(new, old)`` pair. ``new`` is installed now; ``old`` is
                what ``undo`` restores.

        Raises:
            AssertionError: If replacements are already installed. The check
                happens before the table is touched, so the currently active
                patches remain tracked and can still be undone.
        """
        assert not ConstructorManager.changed, "Constructor already changed"
        # Copy first so passing the class table itself cannot clear the input.
        incoming = dict(overwrites)
        ConstructorManager.overwrites.clear()
        ConstructorManager.overwrites.update(incoming)
        ConstructorManager.redo()

    @staticmethod
    def undo():
        """Restore the original attribute for every tracked name.

        Raises:
            AssertionError: If no replacements are currently installed.
        """
        assert ConstructorManager.changed, "No constructor change to undo"
        for name, (_, old) in ConstructorManager.overwrites.items():
            setattr(torch, name, old)
        ConstructorManager.changed = False

    @staticmethod
    def redo():
        """Install the replacement attribute for every tracked name.

        Raises:
            AssertionError: If replacements are already installed.
        """
        assert not ConstructorManager.changed, "Constructor already changed"
        for name, (new, _) in ConstructorManager.overwrites.items():
            setattr(torch, name, new)
        ConstructorManager.changed = True

    @staticmethod
    @contextmanager
    def disable():
        """Context manager that temporarily restores the original attributes.

        If replacements were installed on entry, they are removed for the
        duration of the ``with`` block and re-installed on exit, including
        when the block raises. If nothing was installed, this is a no-op.
        """
        enabled = ConstructorManager.changed
        if enabled:
            ConstructorManager.undo()
        try:
            yield
        finally:
            # Re-install even on error so callers are not left unpatched.
            if enabled:
                ConstructorManager.redo()

    @staticmethod
    def clear():
        """Restore originals (if installed) and forget all tracked names."""
        if ConstructorManager.changed:
            ConstructorManager.undo()
        ConstructorManager.overwrites.clear()