ColossalAI/colossalai/fx/profiler/tensor.py


import torch
from torch.utils._pytree import tree_map

__all__ = ['MetaTensor']

class MetaTensor(torch.Tensor):
    """
    A wrapping tensor that hacks `torch.autograd` without patching more `torch.ops.aten` ops.
    """

    _tensor: torch.Tensor

    __slots__ = ['_tensor']
    @staticmethod
    def __new__(cls, elem):
        # The wrapping tensor (MetaTensor) shouldn't hold any
        # memory for the class in question, but it should still
        # advertise the same device as before
        r = torch.Tensor._make_wrapper_subclass(
            cls,
            elem.size(),
            strides=elem.stride(),
            storage_offset=elem.storage_offset(),
            dtype=elem.dtype,
            layout=elem.layout,
            device='cpu',
            requires_grad=elem.requires_grad)    # deceive the frontend for aten selections
        r._tensor = elem
        # ...the real tensor is held as an element on the tensor.
        return r
    def __repr__(self):
        if self.grad_fn:
            return f"MetaTensor({self._tensor}, grad_fn={self.grad_fn})"
        return f"MetaTensor({self._tensor})"
    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):

        def unwrap(x):
            # Plain tensors are first wrapped so they take the same unwrapping
            # path; every MetaTensor is then replaced by its meta-device payload.
            if isinstance(x, torch.Tensor) and not hasattr(x, '_tensor'):
                x = MetaTensor(x)
            return x._tensor.to('meta') if isinstance(x, MetaTensor) else x

        args = tree_map(unwrap, args)
        kwargs = tree_map(unwrap, kwargs or {})    # `kwargs` may be None

        # run aten for backend=CPU but actually on backend=Meta
        out = func(*args, **kwargs)

        # Now, we want to continue propagating this tensor, so we rewrap Tensors in
        # our custom tensor subclass
        def wrap(x):
            return MetaTensor(x) if isinstance(x, torch.Tensor) else x

        return tree_map(wrap, out)
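

# --- Minimal usage sketch (illustrative; not part of the original file) ---
# Assumes a PyTorch build that supports `__torch_dispatch__` (roughly >= 1.11).
# Wrapping tensors in `MetaTensor` routes every aten call to the meta backend,
# so shapes and dtypes propagate without allocating real activation memory.
if __name__ == '__main__':
    x = MetaTensor(torch.empty(4, 8, device='meta', requires_grad=True))
    w = MetaTensor(torch.empty(8, 16, device='meta', requires_grad=True))

    # Both ops below are intercepted by `MetaTensor.__torch_dispatch__`,
    # executed on meta tensors, and rewrapped as `MetaTensor`s.
    y = torch.relu(x @ w)
    print(type(y).__name__, tuple(y.shape), y.dtype)    # MetaTensor (4, 16) torch.float32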