|
|
|
import uuid
|
|
|
|
from functools import partial
|
|
|
|
|
|
|
|
import torch
|
|
|
|
import torch.distributed as dist
|
|
|
|
from torch.types import _device
|
|
|
|
from torch.utils._pytree import tree_map
|
|
|
|
|
|
|
|
from ._monkey_patch import _AliasATen, _DistCommMethod, _InplaceATen, _MaybeInplaceATen, _TorchOverrideableFactoryMethod
|
|
|
|
|
|
|
|
# Names exported by ``from <this module> import *``: the wrapper tensor class
# and the context manager that enables it.
__all__ = ["MetaTensor", "MetaTensorMode"]
|
|
|
|
|
|
|
|
|
|
|
|
def register_storage(r, data_ptr_fn=None):
    """Override ``r.data_ptr`` so that ``r`` reports a chosen storage identity.

    Args:
        r: Any object. Only ``torch.Tensor`` instances are touched; everything
            else is ignored.
        data_ptr_fn: Optional zero-argument callable. When given, it is
            installed as ``r.data_ptr`` unconditionally.

    Returns:
        None. The tensor is modified in place.
    """
    if not isinstance(r, torch.Tensor):
        return

    # An explicit callable always wins.
    if data_ptr_fn is not None:
        r.data_ptr = data_ptr_fn
        return

    # A truthy pointer is kept as is.
    if r.data_ptr():
        return

    # Falsy pointer: hand out a unique token so that distinct tensors do not
    # all report the same value.
    unique_ptr = uuid.uuid1()
    r.data_ptr = lambda: unique_ptr
|
|
|
|
|
|
|
|
|
|
|
|
def _normalize_tuple(x):
|
|
|
|
if not isinstance(x, tuple):
|
|
|
|
return (x,)
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
|
|
# a hack of inplace execution in PyTorch
def _assert_alias(func):
    """Return whether ``func`` is listed in one of the alias/in-place op collections.

    Despite the name, nothing is asserted: this is a plain membership test over
    the concatenation of ``_AliasATen``, ``_InplaceATen`` and
    ``_MaybeInplaceATen``.
    """
    # TODO: check if should be this aggressive
    aliasing_ops = _AliasATen + _InplaceATen + _MaybeInplaceATen
    return func in aliasing_ops
|
|
|
|
|
|
|
|
|
|
|
|
class MetaTensor(torch.Tensor):
    """
    A wrapping tensor that hacks ``torch.autograd`` without patching more ``torch.ops.aten`` ops.
    `device` is the device that ``MetaTensor`` is supposed to run on. Meta tensors give you the
    ability to run PyTorch code without having to actually do computation through tensors
    allocated on a `meta` device. Because the device is `meta`, meta tensors do not model
    device propagation. ``MetaTensor`` extends its usage by carrying an additional `device`
    which tracks devices that would have been used.

    Reference:
        https://github.com/pytorch/pytorch/blob/master/torch/_subclasses/fake_tensor.py
    """

    # The wrapped tensor; always placed on the `meta` device by ``__new__``.
    _tensor: torch.Tensor

    @staticmethod
    def __new__(cls, elem, device=None, data_ptr_fn=None):
        """Wrap ``elem`` in a ``MetaTensor`` that advertises ``device``.

        Args:
            elem: Tensor to wrap. Nested ``MetaTensor`` wrappers are peeled off first.
            device: Device to advertise. When ``None``, the device of an unwrapped
                ``MetaTensor`` is reused; otherwise ``elem.device`` is used, with
                ``cpu`` substituted if ``elem`` already lives on `meta`.
            data_ptr_fn: Optional callable installed as the inner tensor's
                ``data_ptr``. Ignored when ``elem`` is not on `meta`, in which case
                the real ``elem.data_ptr()`` value is recorded instead.

        Returns:
            The wrapper, re-wrapped in ``torch.nn.Parameter`` if ``elem`` was a parameter.
        """
        requires_grad = elem.requires_grad
        # Avoid multiple wrapping
        while isinstance(elem, MetaTensor):
            device = elem.device if device is None else device
            elem = elem._tensor

        # The wrapping tensor (MetaTensor) shouldn't hold any
        # memory for the class in question, but it should still
        # advertise the same device as before
        r = torch.Tensor._make_wrapper_subclass(
            cls,
            elem.size(),
            strides=elem.stride(),
            storage_offset=elem.storage_offset(),
            dtype=elem.dtype,
            layout=elem.layout,
            device=device or (elem.device if elem.device.type != "meta" else torch.device("cpu")),
            requires_grad=requires_grad,
        )  # deceive the frontend for aten selections
        r._tensor = elem
        # ...the real tensor is held as an element on the tensor.
        if not r._tensor.is_meta:
            # Remember the real pointer before the data is dropped by the move to `meta`.
            val = elem.data_ptr()
            data_ptr_fn = lambda: val
            r._tensor = r._tensor.to(torch.device("meta"))

        # only tensor not on `meta` should be copied to `meta`
        register_storage(r._tensor, data_ptr_fn)
        if isinstance(elem, torch.nn.Parameter):
            r = torch.nn.Parameter(r)
        return r

    def __repr__(self):
        """Return a short description: size, advertised device, dtype and ``grad_fn`` if set."""
        name = "MetaParameter" if getattr(self, "_is_param", False) else "MetaTensor"
        if self.grad_fn:
            return f"{name}(..., size={tuple(self.shape)}, device='{self.device}', dtype={self.dtype}, grad_fn={self.grad_fn})"
        return f"{name}(..., size={tuple(self.shape)}, device='{self.device}', dtype={self.dtype})"

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        """Run ``func`` on the unwrapped `meta` tensors and re-wrap the results.

        The device advertised by the outputs is the device of the last tensor
        argument visited, unless an explicit ``device`` keyword is present, in
        which case that keyword wins.
        """
        # The signature allows ``kwargs=None``; normalise it so the membership
        # test and the ``**kwargs`` expansion below cannot fail on ``None``.
        kwargs = {} if kwargs is None else kwargs
        device = None

        def unwrap(x):
            nonlocal device
            if isinstance(x, MetaTensor):
                device = x.device
                x = x._tensor
            elif isinstance(x, torch.Tensor):
                device = x.device
                x = x.to(torch.device("meta"))
            return x

        args = tree_map(unwrap, args)
        kwargs = tree_map(unwrap, kwargs)

        if "device" in kwargs:
            device = kwargs["device"]
            kwargs["device"] = torch.device("meta")

        # run aten for backend=CPU but actually on backend=Meta
        # here we detect whether or not the execution generates a physical copy
        # of the input tensor
        ret = func(*args, **kwargs)

        if _assert_alias(func):
            # Outputs of aliasing / in-place ops share the first argument's pointer.
            val = args[0].data_ptr()
            tree_map(partial(register_storage, data_ptr_fn=lambda: val), _normalize_tuple(ret))

        # Now, we want to continue propagating this tensor, so we rewrap Tensors in
        # our custom tensor subclass
        def wrap(x):
            return MetaTensor(x, device=device) if isinstance(x, torch.Tensor) else x

        return tree_map(wrap, ret)

    def to(self, *args, **kwargs) -> torch.Tensor:
        """An extension of `torch.Tensor.to()` to MetaTensor

        Any ``str`` or ``torch.device`` argument is taken as the new advertised
        device and replaced by `meta` for the inner call. When no device is
        given (e.g. a dtype-only conversion) the current device is kept.

        Returns:
            result (MetaTensor): MetaTensor

        Usage:
            >>> tensor = MetaTensor(torch.rand(10), device='cuda:100')
            >>> tensor.to(torch.uint8)
            MetaTensor(..., size=(10,), device='cuda:100', dtype=torch.uint8)
            >>> tensor.to(torch.device('cuda:42'))
            MetaTensor(..., size=(10,), device='cuda:42', dtype=torch.float32)
            >>> tensor.to('vulkan')
            MetaTensor(..., size=(10,), device='vulkan', dtype=torch.float32)
        """
        # this imitates c++ function in the way of @overload
        device = None

        def replace(x):
            nonlocal device
            if isinstance(x, (str, _device)):
                device = x
                return torch.device("meta")
            return x

        elem = self._tensor.to(*tree_map(replace, args), **tree_map(replace, kwargs))
        # Without this fallback a dtype-only `.to()` would silently reset the
        # advertised device to cpu, because `elem` itself lives on `meta`.
        return MetaTensor(elem, device=self.device if device is None else device)

    def cpu(self, *args, **kwargs):
        """Return a ``MetaTensor`` advertising ``cpu``; extra arguments are forwarded to ``to``."""
        if self.device.type == "cpu":
            return self.to(*args, **kwargs)
        return self.to(*args, device="cpu", **kwargs)

    def cuda(self, device=None, non_blocking=False):
        """Return a ``MetaTensor`` advertising a CUDA device.

        Args:
            device: Target device as a ``torch.device``, a string, or an integer
                CUDA index. Defaults to ``"cuda:0"`` when ``None``.
            non_blocking: Forwarded to ``to``.
        """
        if device is None:
            device = "cuda:0"
        elif isinstance(device, int):
            # `to` only recognises str / torch.device as a device; a bare index
            # would reach the inner meta tensor and request a real device move.
            device = torch.device("cuda", device)
        return self.to(device=device, non_blocking=non_blocking)

    def data_ptr(self):
        """Return the pointer reported by the wrapped tensor."""
        return self._tensor.data_ptr()
|
|
|
|
|
|
|
|
|
|
|
|
class MetaTensorMode:
    """
    A context manager that enables MetaTensor mode.

    On entry, every name in ``_TorchOverrideableFactoryMethod`` is replaced on
    ``torch`` by a wrapper that allocates on `meta` and returns a ``MetaTensor``,
    and every name in ``_DistCommMethod`` is replaced on ``torch.distributed``
    by a no-op. On exit the saved originals are put back.

    Usage:
        >>> with MetaTensorMode():
        ...     # all torch.xxx and torch.distributed.xxx will be replaced by patched functions
        ...     # and the actual execution will be on torch.device('meta')
        ...     a = torch.rand(100000, 100000)
        ...     b = torch.rand(100000, 100000)
        ...     c = torch.mm(a, b)
    """

    def __init__(self):
        self.torch_overrides = {}  # override torch.xxx
        self.dist_overrides = {}  # override torch.distributed.xxx

    def __enter__(self):
        """Install the patches and return this mode object."""

        def _dummy(*args, **kwargs):
            pass

        def _new(*args, orig_new=torch.empty, **kwargs):
            # Allocate on `meta` but remember the device the caller asked for.
            return MetaTensor(
                orig_new(*args, **{**kwargs, "device": "meta"}), device=kwargs.get("device", torch.device("cpu"))
            )

        try:
            for func in _TorchOverrideableFactoryMethod:
                self.torch_overrides[func] = getattr(torch, func)
                setattr(torch, func, partial(_new, orig_new=getattr(torch, func)))

            for func in _DistCommMethod:
                self.dist_overrides[func] = getattr(dist, func)
                setattr(dist, func, _dummy)
        except BaseException:
            # `__exit__` is not invoked when `__enter__` raises, so undo whatever
            # was already patched before propagating the error.
            self._restore()
            raise

        # Returning self makes `with MetaTensorMode() as mode:` bind the mode.
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Restore the original functions; exceptions from the body are not suppressed."""
        self._restore()

    def _restore(self):
        """Put back every function recorded in the override tables."""
        for func, func_impl in self.torch_overrides.items():
            setattr(torch, func, func_impl)

        for func, func_impl in self.dist_overrides.items():
            setattr(dist, func, func_impl)
|