2022-10-18 02:44:23 +00:00
|
|
|
import uuid
|
2022-09-27 02:26:52 +00:00
|
|
|
from copy import deepcopy
|
2022-10-11 03:03:35 +00:00
|
|
|
from typing import Optional
|
2022-10-18 02:44:23 +00:00
|
|
|
|
2022-08-31 08:30:16 +00:00
|
|
|
import torch
|
2022-10-18 02:44:23 +00:00
|
|
|
from torch.types import _bool, _device, _dtype
|
|
|
|
from torch.utils._pytree import tree_flatten, tree_map
|
|
|
|
|
|
|
|
from .._compatibility import compatibility
|
|
|
|
from .constants import ALIAS_ATEN
|
2022-08-31 08:30:16 +00:00
|
|
|
|
|
|
|
# Public API of this module: only `MetaTensor` is exported by a star-import;
# `set_data_ptr` is left out of the exported names.
__all__ = ['MetaTensor']
|
|
|
|
|
|
|
|
|
2022-11-01 02:43:15 +00:00
|
|
|
def set_data_ptr(x):
    """Attach a surrogate ``data_ptr`` to a tensor whose real pointer is falsy.

    Non-tensor inputs, and tensors whose ``data_ptr()`` already returns a
    truthy value, are left untouched. Otherwise a fresh ``uuid.uuid4()`` is
    generated and ``x.data_ptr`` is shadowed on the instance with a callable
    that returns that UUID.

    Args:
        x: Any object; only ``torch.Tensor`` instances are modified.
    """
    # Guard clause: nothing to do for non-tensors or tensors that already
    # report a usable pointer.
    if not isinstance(x, torch.Tensor) or x.data_ptr():
        return

    surrogate = uuid.uuid4()
    # Shadow the bound method on this instance only.
    x.data_ptr = lambda: surrogate
|
2022-10-11 03:03:35 +00:00
|
|
|
|
|
|
|
|
2022-10-18 02:44:23 +00:00
|
|
|
@compatibility(is_backward_compatible=False)
class MetaTensor(torch.Tensor):
    """
    A wrapping tensor that hacks `torch.autograd` without patching more `torch.ops.aten` ops.
    `fake_device` is the device that `MetaTensor` is supposed to run on.

    The wrapped tensor is stored in `_tensor` and is always moved to the `meta`
    device on construction, while the wrapper itself advertises `fake_device`.
    """

    # The wrapped tensor; converted to the `meta` device in `__new__`.
    _tensor: torch.Tensor

    __slots__ = ['_tensor']

    @staticmethod
    def __new__(cls, elem, fake_device=None):
        """Wrap `elem` in a `MetaTensor` that advertises `fake_device`.

        Args:
            elem: The tensor to wrap. If it is already a `MetaTensor`, its inner
                `_tensor` is reused instead of wrapping twice.
            fake_device: Device the wrapper should report. When `None`, the
                device of `elem` is used.

        Returns:
            The newly created `MetaTensor`.
        """
        # Avoid multiple wrapping
        if isinstance(elem, MetaTensor):
            fake_device = elem.device if fake_device is None else fake_device
            elem = elem._tensor

        # The wrapping tensor (MetaTensor) shouldn't hold any
        # memory for the class in question, but it should still
        # advertise the same device as before
        r = torch.Tensor._make_wrapper_subclass(
            cls,
            elem.size(),
            strides=elem.stride(),
            storage_offset=elem.storage_offset(),
            dtype=elem.dtype,
            layout=elem.layout,
            device=fake_device if fake_device is not None else elem.device,
            requires_grad=elem.requires_grad)    # deceive the frontend for aten selections

        # ...the real tensor is held as an element on the tensor.
        r._tensor = elem
        # only tensor not on `meta` should be copied to `meta`
        if not r._tensor.is_meta:
            r._tensor = r._tensor.to(torch.device('meta'))
        # Give the inner tensor a surrogate `data_ptr` if it does not report one.
        set_data_ptr(r._tensor)
        return r

    def __repr__(self):
        """Return a string showing the inner tensor, the fake device and, if set, `grad_fn`."""
        if self.grad_fn:
            return f"MetaTensor({self._tensor}, fake_device='{self.device}', grad_fn={self.grad_fn})"
        return f"MetaTensor({self._tensor}, fake_device='{self.device}')"

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        """Run `func` on the unwrapped `meta` tensors and re-wrap its outputs.

        Args:
            func: The op being dispatched.
            types: Unused here.
            args: Positional arguments of the op; may contain `MetaTensor`s.
            kwargs: Keyword arguments of the op, or `None`.

        Returns:
            The output of `func` with every tensor wrapped as a `MetaTensor`
            that advertises the resolved fake device.
        """
        # `kwargs` defaults to None; normalise it so the lookups below cannot fail.
        kwargs = {} if kwargs is None else kwargs
        fake_device = None

        def unwrap(x):
            # Strip the wrapper (or move a plain tensor to `meta`) and remember
            # the device the tensor claimed to live on.
            nonlocal fake_device
            if isinstance(x, MetaTensor):
                fake_device = x.device
                x = x._tensor
            elif isinstance(x, torch.Tensor):
                fake_device = x.device
                x = x.to(torch.device('meta'))
            return x

        args = tree_map(unwrap, args)
        kwargs = tree_map(unwrap, kwargs)

        # Resolve an explicit `device=` AFTER unwrapping so that the requested
        # target device takes precedence over the devices of the input tensors.
        # `kwargs` is the mapping produced by `tree_map` at this point, so the
        # caller's dict is not modified.
        if 'device' in kwargs:
            if kwargs['device'] is not None:
                fake_device = kwargs['device']
            kwargs['device'] = torch.device('meta')

        # run aten for backend=CPU but actually on backend=Meta
        out = func(*args, **kwargs)

        # here we keep the uuid of input because ALIAS_ATEN do not generate a physical copy
        # of the input
        if func in ALIAS_ATEN:
            out.data_ptr = args[0].data_ptr

        # Now, we want to continue propagating this tensor, so we rewrap Tensors in
        # our custom tensor subclass
        def wrap(x):
            if not isinstance(x, torch.Tensor):
                return x
            if not x.is_meta:
                x = x.to(torch.device('meta'))
            return MetaTensor(x, fake_device=fake_device)

        return tree_map(wrap, out)

    def to(self, *args, **kwargs) -> torch.Tensor:
        """An extension of `torch.Tensor.to()` to MetaTensor

        Returns:
            result (MetaTensor): MetaTensor

        Usage:
            >>> tensor = MetaTensor(torch.rand(10), fake_device='cuda:100')
            >>> tensor.to(torch.uint8)
            MetaTensor(tensor(..., device='meta', size=(10,), dtype=torch.uint8), fake_device='cuda:100')
            >>> tensor.to(torch.device('cuda:42'))
            MetaTensor(tensor(..., device='meta', size=(10,)), fake_device='cuda:42')
            >>> tensor.to('vulkan')
            MetaTensor(tensor(..., device='meta', size=(10,)), fake_device='vulkan')
        """
        # this imitates c++ function in the way of @overload
        device = None
        for arg in args:
            if isinstance(arg, (str, _device)):
                device = arg
        # A keyword `device` wins over a positional one.
        if 'device' in kwargs:
            device = kwargs['device']
        result = super().to(*args, **kwargs)
        # Re-wrap so the result advertises the requested device.
        if device is not None:
            result = MetaTensor(result, fake_device=device)
        return result

    def cpu(self, *args, **kwargs):
        """Return this tensor converted via `to`, targeting `'cpu'` unless it already reports a CPU device."""
        if self.device.type == 'cpu':
            return self.to(*args, **kwargs)
        return self.to(*args, device='cpu', **kwargs)

    def cuda(self, *args, **kwargs):
        """Return this tensor converted via `to`, targeting a CUDA device.

        If the caller passes `device=...`, that value is forwarded unchanged.
        Otherwise `'cuda'` is requested unless the tensor already reports a
        CUDA device.
        """
        # Only inject the default target when the caller did not name one;
        # passing both would raise "multiple values for keyword argument 'device'".
        if self.device.type != 'cuda' and kwargs.get('device') is None:
            kwargs['device'] = 'cuda'
        return self.to(*args, **kwargs)
|