Making large AI models cheaper, faster and more accessible
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

139 lines
4.9 KiB

import uuid
import torch
from torch.types import _device
from torch.utils._pytree import tree_map
from .._compatibility import compatibility
from .constants import ALIAS_ATEN
__all__ = ["MetaTensor"]
def set_data_ptr(x):
    """Give a tensor with a null address a unique, stable stand-in for ``data_ptr()``.

    Tensors whose ``data_ptr()`` is 0 (e.g. tensors on the ``meta`` device)
    would otherwise all report the same address; presumably this lets callers
    that track tensors by address tell them apart. For such tensors the
    ``data_ptr`` method is shadowed on the instance by one returning a random
    128-bit integer drawn from ``uuid.uuid4()``. Non-tensors, and tensors that
    already report a non-zero address (including ones patched by an earlier
    call), are left untouched.

    Args:
        x: any object; only ``torch.Tensor`` instances are modified.
    """
    if not isinstance(x, torch.Tensor) or x.data_ptr():
        return
    # `.int` keeps the `data_ptr() -> int` contract. A uuid4 integer is never 0
    # (its version bits are set), so the patched value is always truthy.
    fake_ptr = uuid.uuid4().int
    x.data_ptr = lambda: fake_ptr
@compatibility(is_backward_compatible=False)
class MetaTensor(torch.Tensor):
    """
    A wrapping tensor that hacks `torch.autograd` without patching more `torch.ops.aten` ops.

    The wrapper itself holds no memory: every op runs on a `meta` tensor kept
    in ``_tensor``, while the wrapper advertises `fake_device`, the device
    that `MetaTensor` is supposed to run on, to the frontend.
    """

    # The wrapped tensor; __new__ always moves it onto the `meta` device.
    _tensor: torch.Tensor

    @staticmethod
    def __new__(cls, elem, fake_device=None):
        """Wrap ``elem`` in a storage-less MetaTensor.

        Args:
            elem (torch.Tensor): tensor to wrap. A `MetaTensor` is unwrapped
                first, so wrappers never nest.
            fake_device (torch.device | str | None): device the wrapper
                advertises. If falsy, it falls back to ``elem``'s fake device
                (when ``elem`` is a MetaTensor), else ``elem.device``, or
                ``cpu`` when ``elem`` is on ``meta``.

        Returns:
            MetaTensor: the wrapper. Its ``_tensor`` lives on ``meta`` and is
            passed through ``set_data_ptr``.
        """
        # Avoid multiple wrapping
        if isinstance(elem, MetaTensor):
            fake_device = elem.device if fake_device is None else fake_device
            elem = elem._tensor

        # The wrapping tensor (MetaTensor) shouldn't hold any memory for the
        # class in question, but it should still advertise the same device as
        # before, which deceives the frontend for aten selections.
        r = torch.Tensor._make_wrapper_subclass(
            cls,
            elem.size(),
            strides=elem.stride(),
            storage_offset=elem.storage_offset(),
            dtype=elem.dtype,
            layout=elem.layout,
            device=fake_device or (elem.device if elem.device.type != "meta" else torch.device("cpu")),
            requires_grad=elem.requires_grad,
        )

        # ...the real tensor is held as an element on the tensor. Only a
        # tensor not already on `meta` needs to be copied to `meta`.
        r._tensor = elem if elem.is_meta else elem.to(torch.device("meta"))
        set_data_ptr(r._tensor)
        return r

    def __repr__(self):
        if self.grad_fn is not None:
            return f"MetaTensor(..., size={tuple(self.shape)}, device='{self.device}', dtype={self.dtype}, grad_fn={self.grad_fn})"
        return f"MetaTensor(..., size={tuple(self.shape)}, device='{self.device}', dtype={self.dtype})"

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        """Run ``func`` on the underlying `meta` tensors and re-wrap the results.

        Tensor arguments are unwrapped (MetaTensor) or moved to `meta` (plain
        tensors). The outputs advertise the device of the last tensor argument
        seen, or an explicit non-None ``device`` keyword if one is given.
        """
        # `kwargs` defaults to None; the `in` test and item assignment below need a dict.
        kwargs = kwargs or {}
        fake_device = None

        def unwrap(x):
            nonlocal fake_device
            if isinstance(x, MetaTensor):
                fake_device = x.device
                x = x._tensor
            elif isinstance(x, torch.Tensor):
                fake_device = x.device
                x = x.to(torch.device("meta"))
            return x

        args = tree_map(unwrap, args)
        kwargs = tree_map(unwrap, kwargs)

        if "device" in kwargs:
            # An explicit `device=None` means "no override": keep the device
            # inferred from the tensor arguments instead of dropping it.
            if kwargs["device"] is not None:
                fake_device = kwargs["device"]
            kwargs["device"] = torch.device("meta")

        # run aten for backend=CPU but actually on backend=Meta
        out = func(*args, **kwargs)

        # here we keep the uuid of input because ALIAS_ATEN do not generate a
        # physical copy of the input
        if func in ALIAS_ATEN:
            out.data_ptr = args[0].data_ptr

        # Now, we want to continue propagating this tensor, so we rewrap Tensors in
        # our custom tensor subclass
        def wrap(x):
            if not isinstance(x, torch.Tensor):
                return x
            if not x.is_meta:
                x = x.to(torch.device("meta"))
            return MetaTensor(x, fake_device=fake_device)

        return tree_map(wrap, out)

    def to(self, *args, **kwargs) -> torch.Tensor:
        """An extension of `torch.Tensor.to()` to MetaTensor

        Every ``str`` or ``torch.device`` argument is recorded as the new fake
        device and replaced by ``"meta"``, so no real memory is touched. When
        no device is given (e.g. a dtype-only conversion), the current fake
        device is kept.

        Returns:
            result (MetaTensor): MetaTensor

        Usage:
            >>> tensor = MetaTensor(torch.rand(10), fake_device='cuda:100')
            >>> tensor.to(torch.uint8)
            MetaTensor(..., size=(10,), device='cuda:100', dtype=torch.uint8)
            >>> tensor.to(torch.device('cuda:42'))
            MetaTensor(..., size=(10,), device='cuda:42', dtype=torch.float32)
            >>> tensor.to('vulkan')
            MetaTensor(..., size=(10,), device='vulkan', dtype=torch.float32)
        """
        # this imitates c++ function in the way of @overload
        # Start from the current fake device; otherwise `.to(dtype)` would
        # silently reset the advertised device to cpu.
        fake_device = self.device

        def replace(x):
            nonlocal fake_device
            if isinstance(x, str) or isinstance(x, _device):
                fake_device = x
                return "meta"
            return x

        elem = self._tensor.to(*tree_map(replace, args), **tree_map(replace, kwargs))
        return MetaTensor(elem, fake_device=fake_device)

    def cpu(self, *args, **kwargs):
        """Return a MetaTensor advertising the ``cpu`` device (no real copy is made)."""
        if self.device.type == "cpu":
            return self.to(*args, **kwargs)
        return self.to(*args, device="cpu", **kwargs)

    def cuda(self, device=None, non_blocking=False):
        """Return a MetaTensor advertising a CUDA device (no real copy is made).

        Args:
            device (int | str | torch.device | None): target device. An int is
                treated as a CUDA ordinal. Defaults to ``cuda:0``.
            non_blocking (bool): forwarded to `to()`.
        """
        if device is None:
            device = "cuda:0"
        elif isinstance(device, int):
            # `to()` only intercepts str / torch.device; a bare ordinal would
            # reach the wrapped meta tensor and attempt a real device copy.
            device = torch.device("cuda", device)
        return self.to(device=device, non_blocking=non_blocking)