import uuid
from functools import partial

import torch
import torch.distributed as dist
from torch.types import _device
from torch.utils._pytree import tree_map

from ._monkey_patch import (
    _AliasATen,
    _DistCommMethod,
    _InplaceATen,
    _MaybeInplaceATen,
    _TorchOverrideableFactoryMethod,
)

__all__ = ["MetaTensor", "MetaTensorMode"]


def register_storage(r, data_ptr_fn=None):
    if isinstance(r, torch.Tensor):
        if data_ptr_fn is not None:
            r.data_ptr = data_ptr_fn
        elif not r.data_ptr():
            data_ptr = uuid.uuid1()
            r.data_ptr = lambda: data_ptr


def _normalize_tuple(x):
    if not isinstance(x, tuple):
        return (x,)
    return x


# a hack of inplace execution in PyTorch
def _assert_alias(func):
    return func in (_AliasATen + _InplaceATen + _MaybeInplaceATen)  # TODO: check if should be this aggressive


class MetaTensor(torch.Tensor):
    """
    A wrapping tensor that hacks ``torch.autograd`` without patching more ``torch.ops.aten`` ops.

    ``device`` is the device that ``MetaTensor`` is supposed to run on. Meta tensors give you the
    ability to run PyTorch code without having to actually do computation through tensors
    allocated on a ``meta`` device. Because the device is ``meta``, meta tensors do not model
    device propagation. ``MetaTensor`` extends their usage by carrying an additional ``device``
    which tracks the device that would have been used.

    Reference:
        https://github.com/pytorch/pytorch/blob/master/torch/_subclasses/fake_tensor.py
    """

    _tensor: torch.Tensor

    @staticmethod
    def __new__(cls, elem, device=None, data_ptr_fn=None):
        requires_grad = elem.requires_grad
        # Avoid multiple wrapping
        while isinstance(elem, MetaTensor):
            device = elem.device if device is None else device
            elem = elem._tensor

        # The wrapping tensor (MetaTensor) shouldn't hold any
        # memory for the class in question, but it should still
        # advertise the same device as before
        r = torch.Tensor._make_wrapper_subclass(
            cls,
            elem.size(),
            strides=elem.stride(),
            storage_offset=elem.storage_offset(),
            dtype=elem.dtype,
            layout=elem.layout,
            device=device or (elem.device if elem.device.type != "meta" else torch.device("cpu")),
            requires_grad=requires_grad,
        )  # deceive the frontend for aten selections
        r._tensor = elem  # ...the real tensor is held as an element on the tensor.
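        # If the wrapped element still lives on a real device, remember its storage pointer
        # before moving the payload to `meta`, so that alias/inplace results can later be
        # mapped back to the same (fake) storage via `register_storage`.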
        if not r._tensor.is_meta:
            val = elem.data_ptr()
            data_ptr_fn = lambda: val
            r._tensor = r._tensor.to(torch.device("meta"))  # only tensor not on `meta` should be copied to `meta`
        register_storage(r._tensor, data_ptr_fn)
        if isinstance(elem, torch.nn.Parameter):
            r = torch.nn.Parameter(r)
        return r

    def __repr__(self):
        name = "MetaParameter" if getattr(self, "_is_param", False) else "MetaTensor"
        if self.grad_fn:
            return f"{name}(..., size={tuple(self.shape)}, device='{self.device}', dtype={self.dtype}, grad_fn={self.grad_fn})"
        return f"{name}(..., size={tuple(self.shape)}, device='{self.device}', dtype={self.dtype})"

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        device = None

        def unwrap(x):
            nonlocal device
            if isinstance(x, MetaTensor):
                device = x.device
                x = x._tensor
            elif isinstance(x, torch.Tensor):
                device = x.device
                x = x.to(torch.device("meta"))
            return x

        args = tree_map(unwrap, args)
        kwargs = tree_map(unwrap, kwargs)

        if "device" in kwargs:
            device = kwargs["device"]
            kwargs["device"] = torch.device("meta")

        # run aten for backend=CPU but actually on backend=Meta
        # here we detect whether or not the execution generates a physical copy
        # of the input tensor
        ret = func(*args, **kwargs)
        if _assert_alias(func):
            val = args[0].data_ptr()
            tree_map(partial(register_storage, data_ptr_fn=lambda: val), _normalize_tuple(ret))

        # Now, we want to continue propagating this tensor, so we rewrap Tensors in
        # our custom tensor subclass
        def wrap(x):
            return MetaTensor(x, device=device) if isinstance(x, torch.Tensor) else x

        return tree_map(wrap, ret)

    def to(self, *args, **kwargs) -> torch.Tensor:
        """An extension of `torch.Tensor.to()` to MetaTensor

        Returns:
            result (MetaTensor): MetaTensor

        Usage:
            >>> tensor = MetaTensor(torch.rand(10), device='cuda:100')
            >>> tensor.to(torch.uint8)
            MetaTensor(tensor(..., device='meta', size=(10,), dtype=torch.uint8), device='cuda:100')
            >>> tensor.to(torch.device('cuda:42'))
            MetaTensor(tensor(..., device='meta', size=(10,)), device='cuda:42')
            >>> tensor.to('vulkan')
            MetaTensor(tensor(..., device='meta', size=(10,)), device='vulkan')
        """
        # this imitates c++ function in the way of @overload
        device = None

        def replace(x):
            nonlocal device
            if isinstance(x, str) or isinstance(x, _device):
                device = x
                return torch.device("meta")
            return x

        elem = self._tensor.to(*tree_map(replace, args), **tree_map(replace, kwargs))
        return MetaTensor(elem, device=device)

    def cpu(self, *args, **kwargs):
        if self.device.type == "cpu":
            return self.to(*args, **kwargs)
        return self.to(*args, device="cpu", **kwargs)

    def cuda(self, device=None, non_blocking=False):
        if device is not None:
            return self.to(device=device, non_blocking=non_blocking)
        return self.to(device="cuda:0", non_blocking=non_blocking)

    def data_ptr(self):
        return self._tensor.data_ptr()


class MetaTensorMode(object):
    """
    A context manager that enables MetaTensor mode.
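
    Within the context, the ``torch`` factory functions listed in ``_TorchOverrideableFactoryMethod``
    are patched to return ``MetaTensor`` objects whose payload lives on ``torch.device('meta')``
    while each wrapper still reports the requested device (defaulting to CPU), and the
    ``torch.distributed`` methods listed in ``_DistCommMethod`` are replaced with no-ops.
    The original functions are restored on exit.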

    Usage:
        >>> with MetaTensorMode():
        >>>     # all torch.xxx and torch.distributed.xxx will be replaced by patched functions
        >>>     # and the actual execution will be on torch.device('meta')
        >>>     a = torch.rand(100000, 100000)
        >>>     b = torch.rand(100000, 100000)
        >>>     c = torch.mm(a, b)
    """

    def __init__(self):
        self.torch_overrides = {}  # override torch.xxx
        self.dist_overrides = {}  # override torch.distributed.xxx

    def __enter__(self):
        def _dummy(*args, **kwargs):
            pass

        def _new(*args, orig_new=torch.empty, **kwargs):
            return MetaTensor(
                orig_new(*args, **{**kwargs, "device": "meta"}), device=kwargs.get("device", torch.device("cpu"))
            )

        for func in _TorchOverrideableFactoryMethod:
            self.torch_overrides[func] = getattr(torch, func)
            setattr(torch, func, partial(_new, orig_new=getattr(torch, func)))

        for func in _DistCommMethod:
            self.dist_overrides[func] = getattr(dist, func)
            setattr(dist, func, _dummy)

    def __exit__(self, exc_type, exc_value, traceback):
        for func, func_impl in self.torch_overrides.items():
            setattr(torch, func, func_impl)

        for func, func_impl in self.dist_overrides.items():
            setattr(dist, func, func_impl)
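

# Minimal usage sketch (illustrative only, not part of the module's API). It assumes
# `torch.rand` is among the factory methods patched via `_TorchOverrideableFactoryMethod`,
# as in the class docstring above. Construction and matmul run on the `meta` device, so no
# real memory is allocated, while each `MetaTensor` still reports the device it was
# requested on.
if __name__ == "__main__":
    with MetaTensorMode():
        a = torch.rand(4096, 4096, device="cuda:0")
        b = torch.rand(4096, 4096, device="cuda:0")
        c = torch.mm(a, b)
    print(c)  # e.g. MetaTensor(..., size=(4096, 4096), device='cuda:0', dtype=torch.float32)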