You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
ColossalAI/colossalai/tensor/colo_tensor.py

103 lines
3.3 KiB

from functools import lru_cache
from typing import Callable, Set
import torch
INPALCE_MAPPING = {
torch.Tensor.add_: torch.Tensor.add,
torch.Tensor.sub_: torch.Tensor.sub,
torch.Tensor.mul_: torch.Tensor.mul,
torch.Tensor.div_: torch.Tensor.div,
}
@lru_cache(None)
def _get_my_nowrap_functions() -> Set[Callable]:
Tensor = torch.Tensor
return {
Tensor._base.__get__,
Tensor.grad.__get__,
Tensor._grad.__get__,
Tensor.data.__get__, # make .data returns torch.Tensor rather than ColoTensor
}
def _convert(output):
if isinstance(output, torch.Tensor) and not isinstance(output, ColoTensor):
output.__class__ = ColoTensor
elif isinstance(output, (list, tuple)):
output = type(output)(_convert(o) for o in output)
return output
def _convert_output(output, func):
if func in _get_my_nowrap_functions():
return output
return _convert(output)
class ColoTensor(torch.Tensor):
"""Data Structure for Tensor in Colossal-AI. It is a subclass of torch.Tensor.
It is only used to trigger the torch function hook.
Args:
data (torch.Tensor): a torch tensor used as the payload the colotensor.
"""
torch_major = int(torch.__version__.split(".")[0])
torch_minor = int(torch.__version__.split(".")[1])
def __new__(cls, data: torch.Tensor) -> "ColoTensor":
"""
The signature of the __new__ has to be consistent with the torch.Tensor.
Args:
data (torch.Tensor): a torch tensor used as the payload the colotensor.
Returns:
ColoTensor: a ColoTensor wrappers the data.
"""
if data is None:
data = torch.empty(0)
return torch.Tensor._make_subclass(cls, data, data.requires_grad)
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
if not all(issubclass(cls, t) for t in types):
return NotImplemented
if cls.torch_major > 1 or (cls.torch_major == 1 and cls.torch_minor >= 12):
# in order to trigger pre-op hook in the forward of checkpoint module
# we have to capture the `backward` function
# and make sure that it does not in `torch._C.DisableTorchFunction()` context
if func is torch.Tensor.backward:
assert len(args) == 1 # only has 1 parameter
backward_tensor = torch.Tensor(args[0])
tensor_kwargs = {k: torch.Tensor(v) if torch.is_tensor(v) else v for k, v in kwargs.items()}
return backward_tensor.backward(**tensor_kwargs)
# replace the in-place function
if func in INPALCE_MAPPING:
func = INPALCE_MAPPING[func]
# set the 'inplace' kwargs to False
if "inplace" in kwargs:
kwargs["inplace"] = False
with torch._C.DisableTorchFunction():
ret = func(*args, **kwargs)
return _convert_output(ret, func)
def __deepcopy__(self, memo):
if id(self) in memo:
return memo[id(self)]
else:
with torch._C.DisableTorchFunction():
data = self.data.clone()
tensor = ColoTensor(data)
memo[id(self)] = tensor
return tensor