You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
ColossalAI/colossalai/_analyzer/_subclasses/meta_tensor.py

206 lines
7.4 KiB

import uuid
from functools import partial
import torch
import torch.distributed as dist
from torch.types import _device
from torch.utils._pytree import tree_map
from ._monkey_patch import _AliasATen, _DistCommMethod, _InplaceATen, _MaybeInplaceATen, _TorchOverrideableFactoryMethod
# Names exported by ``from <this module> import *``.
__all__ = [
    "MetaTensor",
    "MetaTensorMode",
]
def register_storage(r, data_ptr_fn=None):
    """Install a (possibly fake) ``data_ptr`` on tensor ``r`` so storage identity can be tracked.

    Args:
        r: any object; anything that is not a ``torch.Tensor`` is ignored.
        data_ptr_fn: zero-argument callable installed as ``r.data_ptr``. When omitted, a
            fresh ``uuid.uuid1()`` value is installed instead, but only if ``r.data_ptr()``
            is currently falsy (e.g. 0). A tensor that already reports a pointer keeps it.

    Note:
        The generated fake pointer is a ``uuid.UUID`` object, not an ``int``.
    """
    if not isinstance(r, torch.Tensor):
        return
    if data_ptr_fn is None:
        if r.data_ptr():
            # Real or previously registered pointer: leave it alone.
            return
        fake_ptr = uuid.uuid1()
        data_ptr_fn = lambda: fake_ptr
    r.data_ptr = data_ptr_fn
def _normalize_tuple(x):
if not isinstance(x, tuple):
return (x,)
return x
# A hack to model in-place / aliasing execution in PyTorch: ops found in any of the
# alias, in-place or maybe-in-place tables are reported as aliasing by this predicate.
def _assert_alias(func):
    """Return ``True`` if ``func`` is listed in ``_AliasATen``, ``_InplaceATen`` or ``_MaybeInplaceATen``.

    Despite the name, this never raises; it is a plain membership predicate.
    """
    # TODO: check if should be this aggressive
    return any(func in table for table in (_AliasATen, _InplaceATen, _MaybeInplaceATen))
class MetaTensor(torch.Tensor):
    """
    A wrapping tensor that hacks ``torch.autograd`` without patching more ``torch.ops.aten`` ops.
    `device` is the device that ``MetaTensor`` is supposed to run on. Meta tensors give you the
    ability to run PyTorch code without having to actually do computation through tensors
    allocated on a `meta` device. Because the device is `meta`, meta tensors do not model
    device propagation. ``MetaTensor`` extends its usage by carrying an additional `device`
    which tracks devices that would have been used.

    Reference:
        https://github.com/pytorch/pytorch/blob/master/torch/_subclasses/fake_tensor.py
    """

    # The tensor that aten ops actually run on; moved to `meta` in ``__new__`` if needed.
    _tensor: torch.Tensor

    @staticmethod
    def __new__(cls, elem, device=None, data_ptr_fn=None):
        """Wrap ``elem`` in a ``MetaTensor`` that advertises ``device``.

        Args:
            elem: tensor to wrap. Nested ``MetaTensor`` wrappers are unwrapped first.
            device: device the wrapper advertises. If ``None``, the outermost ``MetaTensor``
                wrapper's device is used, else ``elem.device``, or ``cpu`` when ``elem`` is
                itself on ``meta``.
            data_ptr_fn: optional zero-argument callable installed as the inner tensor's
                ``data_ptr``. It is replaced when ``elem`` is not on ``meta``, so that the
                meta copy reports ``elem``'s original address.

        Returns:
            The new ``MetaTensor``, or a ``torch.nn.Parameter`` built from it when the
            unwrapped ``elem`` is a parameter.
        """
        requires_grad = elem.requires_grad

        # Avoid multiple wrapping; keep the outermost wrapper's device unless one was given.
        while isinstance(elem, MetaTensor):
            device = elem.device if device is None else device
            elem = elem._tensor

        # The wrapping tensor (MetaTensor) shouldn't hold any
        # memory for the class in question, but it should still
        # advertise the same device as before
        r = torch.Tensor._make_wrapper_subclass(
            cls,
            elem.size(),
            strides=elem.stride(),
            storage_offset=elem.storage_offset(),
            dtype=elem.dtype,
            layout=elem.layout,
            device=device or (elem.device if elem.device.type != "meta" else torch.device("cpu")),
            requires_grad=requires_grad,
        )  # deceive the frontend for aten selections

        # ...the real tensor is held as an element on the tensor.
        r._tensor = elem
        if not r._tensor.is_meta:
            # Only tensors not on `meta` are copied to `meta`; remember the original
            # address so the meta copy reports the same data_ptr.
            val = elem.data_ptr()
            data_ptr_fn = lambda: val
            r._tensor = r._tensor.to(torch.device("meta"))
        register_storage(r._tensor, data_ptr_fn)

        if isinstance(elem, torch.nn.Parameter):
            r = torch.nn.Parameter(r)
        return r

    def __repr__(self):
        # NOTE(review): `_is_param` is not set in this class; presumably it is set when
        # `torch.nn.Parameter` wraps a tensor subclass. Confirm for the supported torch versions.
        name = "MetaParameter" if getattr(self, "_is_param", False) else "MetaTensor"
        if self.grad_fn is not None:
            return f"{name}(..., size={tuple(self.shape)}, device='{self.device}', dtype={self.dtype}, grad_fn={self.grad_fn})"
        return f"{name}(..., size={tuple(self.shape)}, device='{self.device}', dtype={self.dtype})"

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        """Run ``func`` on unwrapped meta tensors and rewrap tensor outputs as ``MetaTensor``.

        The outputs advertise the device of the last tensor argument seen while unwrapping
        (positional args first, then keyword args). An explicit ``device=`` keyword overrides
        this, and is itself replaced by ``meta`` before ``func`` runs. For ops accepted by
        ``_assert_alias``, every tensor output is given the first argument's ``data_ptr``.
        """
        device = None

        def unwrap(x):
            nonlocal device
            if isinstance(x, MetaTensor):
                device = x.device
                x = x._tensor
            elif isinstance(x, torch.Tensor):
                device = x.device
                x = x.to(torch.device("meta"))
            return x

        args = tree_map(unwrap, args)
        # `kwargs` defaults to None; normalize it so the membership test below cannot raise.
        kwargs = tree_map(unwrap, kwargs if kwargs is not None else {})

        if "device" in kwargs:
            device = kwargs["device"]
            kwargs["device"] = torch.device("meta")

        # run aten for backend=CPU but actually on backend=Meta
        # here we detect whether or not the execution generates a physical copy
        # of the input tensor
        ret = func(*args, **kwargs)

        if _assert_alias(func):
            val = args[0].data_ptr()
            tree_map(partial(register_storage, data_ptr_fn=lambda: val), _normalize_tuple(ret))

        # Now, we want to continue propagating this tensor, so we rewrap Tensors in
        # our custom tensor subclass
        def wrap(x):
            return MetaTensor(x, device=device) if isinstance(x, torch.Tensor) else x

        return tree_map(wrap, ret)

    def to(self, *args, **kwargs) -> torch.Tensor:
        """An extension of `torch.Tensor.to()` to MetaTensor

        Any ``str`` or ``torch.device`` argument is recorded as the new advertised device and
        replaced by ``meta`` before the call is forwarded to the inner tensor, so no data moves.
        If no device is given (e.g. a dtype-only conversion), the current device is kept.

        Returns:
            result (MetaTensor): MetaTensor

        Usage:
            >>> tensor = MetaTensor(torch.rand(10), device='cuda:100')
            >>> tensor.to(torch.uint8)
            MetaTensor(..., size=(10,), device='cuda:100', dtype=torch.uint8)
            >>> tensor.to(torch.device('cuda:42'))
            MetaTensor(..., size=(10,), device='cuda:42', dtype=torch.float32)
            >>> tensor.to('vulkan')
            MetaTensor(..., size=(10,), device='vulkan', dtype=torch.float32)
        """
        # this imitates c++ function in the way of @overload
        device = None

        def replace(x):
            nonlocal device
            if isinstance(x, (str, _device)):
                device = x
                return torch.device("meta")
            return x

        elem = self._tensor.to(*tree_map(replace, args), **tree_map(replace, kwargs))
        if device is None and not isinstance(elem, MetaTensor) and elem.is_meta:
            # Otherwise ``MetaTensor.__new__`` would advertise ``cpu`` for a meta-backed
            # ``elem`` and silently drop the current device.
            device = self.device
        return MetaTensor(elem, device=device)

    def cpu(self, *args, **kwargs):
        """Return a ``MetaTensor`` advertising ``cpu``; the data stays on ``meta``."""
        if self.device.type == "cpu":
            return self.to(*args, **kwargs)
        return self.to(*args, device="cpu", **kwargs)

    def cuda(self, device=None, non_blocking=False):
        """Return a ``MetaTensor`` advertising a CUDA device; the data stays on ``meta``.

        Args:
            device: target device as a ``torch.device``, a string, or an integer CUDA index.
                Defaults to ``"cuda:0"``.
            non_blocking: forwarded to ``to()``.
        """
        if device is None:
            device = "cuda:0"
        elif isinstance(device, int):
            # ``to()`` only intercepts strings / ``torch.device`` values. A bare index would be
            # forwarded to the inner meta tensor, which would attempt a real copy off ``meta``.
            device = torch.device("cuda", device)
        return self.to(device=device, non_blocking=non_blocking)

    def data_ptr(self):
        """Return the data pointer reported by the inner meta tensor (possibly a registered fake)."""
        return self._tensor.data_ptr()
class MetaTensorMode(object):
    """
    A context manager that enables MetaTensor mode.

    While active, every ``torch`` function named in ``_TorchOverrideableFactoryMethod`` is
    replaced by a wrapper that allocates on ``meta`` and returns a ``MetaTensor``. That
    ``MetaTensor`` advertises the requested ``device`` (default ``cpu``). Every
    ``torch.distributed`` function named in ``_DistCommMethod`` is replaced by a no-op that
    returns ``None``. The originals are restored on exit, or immediately if entering fails
    part-way.

    Usage:
        >>> with MetaTensorMode():
        >>>     # all torch.xxx and torch.distributed.xxx will be replaced by patched functions
        >>>     # and the actual execution will be on torch.device('meta')
        >>>     a = torch.rand(100000, 100000)
        >>>     b = torch.rand(100000, 100000)
        >>>     c = torch.mm(a, b)
    """

    def __init__(self):
        self.torch_overrides = {}  # override torch.xxx
        self.dist_overrides = {}  # override torch.distributed.xxx

    def __enter__(self):
        """Install the patches and return ``self``."""

        def _dummy(*args, **kwargs):
            pass

        def _new(*args, orig_new=torch.empty, **kwargs):
            # Allocate on `meta`, but advertise the caller's requested device.
            return MetaTensor(
                orig_new(*args, **{**kwargs, "device": "meta"}), device=kwargs.get("device", torch.device("cpu"))
            )

        try:
            for func in _TorchOverrideableFactoryMethod:
                orig = getattr(torch, func)
                self.torch_overrides[func] = orig
                setattr(torch, func, partial(_new, orig_new=orig))
            for func in _DistCommMethod:
                self.dist_overrides[func] = getattr(dist, func)
                setattr(dist, func, _dummy)
        except BaseException:
            # __exit__ is not called when __enter__ raises, so undo the patches applied so far.
            self._restore()
            raise
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Returns None, so any exception raised inside the `with` body propagates.
        self._restore()

    def _restore(self):
        """Reinstall every saved original function and forget it."""
        for func, func_impl in self.torch_overrides.items():
            setattr(torch, func, func_impl)
        for func, func_impl in self.dist_overrides.items():
            setattr(dist, func, func_impl)
        self.torch_overrides.clear()
        self.dist_overrides.clear()