# mirror of https://github.com/hpcaitech/ColossalAI
import torch
from torch.utils._pytree import tree_map, tree_flatten

__all__ = ['MetaTensor']


class MetaTensor(torch.Tensor):
    """
    A wrapping tensor that hacks `torch.autograd` without patching more `torch.ops.aten` ops.
    """

    _tensor: torch.Tensor

    __slots__ = ['_tensor']

    @staticmethod
    def __new__(cls, elem):
        # The wrapping tensor (MetaTensor) shouldn't hold any
        # memory for the class in question, but it should still
        # advertise the same device as before
        r = torch.Tensor._make_wrapper_subclass(
            cls, elem.size(),
            strides=elem.stride(), storage_offset=elem.storage_offset(),
            dtype=elem.dtype, layout=elem.layout,
            device='cpu', requires_grad=elem.requires_grad
        )    # deceive the frontend for aten selections
        r._tensor = elem
        # ...the real tensor is held as an element on the tensor.
        return r

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):

        def unwrap(x):
            # wrap any plain tensor first so that it, too, gets moved to the meta device
            if isinstance(x, torch.Tensor) and not hasattr(x, '_tensor'):
                x = MetaTensor(x)
            return x._tensor.to('meta') if isinstance(x, MetaTensor) else x

        args = tree_map(unwrap, args)
        kwargs = tree_map(unwrap, kwargs)

        # run aten for backend=CPU but actually on backend=Meta
        out = func(*args, **kwargs)

        # Now, we want to continue propagating this tensor, so we rewrap Tensors in
        # our custom tensor subclass
        def wrap(x):
            return MetaTensor(x) if isinstance(x, torch.Tensor) else x

        return tree_map(wrap, out)
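

# A minimal usage sketch (not part of the original module); the shapes and the
# matmul below are illustrative assumptions. Any aten op dispatched on a
# MetaTensor runs on the meta backend, so no real memory is allocated for the
# result, while shape and dtype are still propagated.
if __name__ == '__main__':
    x = MetaTensor(torch.empty(4, 8))
    w = MetaTensor(torch.empty(8, 2))
    out = x @ w            # routed through __torch_dispatch__ on backend=Meta
    print(type(out))       # MetaTensor
    print(out.shape)       # torch.Size([4, 2])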