from functools import lru_cache
from typing import Callable, Optional, Set

import torch

# Maps in-place tensor methods to their out-of-place counterparts.
# ``ColoTensor.__torch_function__`` uses this table to swap an in-place call
# for the equivalent out-of-place one before executing it.
# NOTE: the identifier is misspelled ("INPALCE"); it is kept unchanged because
# it is a module-level name that other modules may import.
INPALCE_MAPPING = {
    torch.Tensor.add_: torch.Tensor.add,
    torch.Tensor.sub_: torch.Tensor.sub,
    torch.Tensor.mul_: torch.Tensor.mul,
    torch.Tensor.div_: torch.Tensor.div,
}


@lru_cache(None)
def _get_my_nowrap_functions() -> Set[Callable]:
    """Return the attribute getters whose results must not be wrapped.

    The set is built once and cached. Results of these getters are returned
    by ``_convert_output`` untouched instead of being turned into
    ``ColoTensor`` instances.
    """
    Tensor = torch.Tensor
    return {
        Tensor._base.__get__,
        Tensor.grad.__get__,
        Tensor._grad.__get__,
        Tensor.data.__get__,  # make .data return torch.Tensor rather than ColoTensor
    }


def _convert(output):
    """Recursively re-class tensors found in ``output`` as ``ColoTensor``.

    Args:
        output: the value returned by a torch function. It may be a tensor,
            a (possibly nested) list/tuple of values, or anything else.

    Returns:
        ``output`` itself when it is a tensor (its ``__class__`` is switched
        in place) or an unsupported type; otherwise a new container of the
        same type holding the converted items.
    """
    if isinstance(output, torch.Tensor):
        if not isinstance(output, ColoTensor):
            # Re-class in place so no data is copied.
            output.__class__ = ColoTensor
        return output

    if isinstance(output, (list, tuple)):
        if isinstance(output, tuple) and hasattr(output, "_fields"):
            # ``collections.namedtuple`` classes take their fields as
            # positional arguments, not as a single iterable.
            return type(output)(*(_convert(item) for item in output))
        # list, tuple and other tuple subclasses without ``_fields``
        # are rebuilt from a single iterable.
        return type(output)(_convert(item) for item in output)

    return output


def _convert_output(output, func):
    """Wrap ``output`` as ``ColoTensor`` unless ``func`` is a no-wrap getter.

    Args:
        output: the raw result of calling ``func``.
        func: the torch function that produced ``output``.

    Returns:
        ``output`` unchanged if ``func`` is in ``_get_my_nowrap_functions()``,
        otherwise the result of ``_convert(output)``.
    """
    if func in _get_my_nowrap_functions():
        return output
    return _convert(output)


class ColoTensor(torch.Tensor):
    """Data Structure for Tensor in Colossal-AI. It is a subclass of torch.Tensor.

    It is only used to trigger the torch function hook.

    Args:
        data (torch.Tensor): a torch tensor used as the payload of the colotensor.
    """

    # Major/minor components of the installed torch version, parsed once at
    # class creation and used by ``__torch_function__`` for a version check.
    torch_major = int(torch.__version__.split(".")[0])
    torch_minor = int(torch.__version__.split(".")[1])

    def __new__(cls, data: Optional[torch.Tensor]) -> "ColoTensor":
        """
        The signature of the __new__ has to be consistent with the torch.Tensor.

        Args:
            data (torch.Tensor): a torch tensor used as the payload of the
                colotensor. ``None`` is accepted and replaced by an empty tensor.

        Returns:
            ColoTensor: a ColoTensor wrapping the data.
        """
        if data is None:
            data = torch.empty(0)
        return torch.Tensor._make_subclass(cls, data, data.requires_grad)

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        """Intercept torch functions that involve a ``ColoTensor``.

        Args:
            func: the torch function being called.
            types: the types of the arguments that implement
                ``__torch_function__``.
            args: positional arguments of the call.
            kwargs: keyword arguments of the call (``None`` means no kwargs).

        Returns:
            ``NotImplemented`` if ``cls`` is not a subclass of every type in
            ``types``; the result of ``Tensor.backward`` for backward calls on
            torch >= 1.12; otherwise the result of ``func`` with tensors
            converted to ``ColoTensor`` (see ``_convert_output``).
        """
        if kwargs is None:
            kwargs = {}

        if not all(issubclass(cls, t) for t in types):
            return NotImplemented

        if cls.torch_major > 1 or (cls.torch_major == 1 and cls.torch_minor >= 12):
            # in order to trigger pre-op hook in the forward of checkpoint module
            # we have to capture the `backward` function
            # and make sure that it is not run inside the
            # `torch._C.DisableTorchFunction()` context
            if func is torch.Tensor.backward:
                # `backward` is expected to arrive with `self` as its only
                # positional argument; everything else comes through kwargs.
                assert len(args) == 1  # only has 1 parameter
                # Re-wrap the tensor (and any tensor kwargs) as plain
                # torch.Tensor objects and run backward on those.
                backward_tensor = torch.Tensor(args[0])
                tensor_kwargs = {k: torch.Tensor(v) if torch.is_tensor(v) else v for k, v in kwargs.items()}
                return backward_tensor.backward(**tensor_kwargs)

        # replace the in-place function
        if func in INPALCE_MAPPING:
            func = INPALCE_MAPPING[func]
        # set the 'inplace' kwargs to False
        if "inplace" in kwargs:
            kwargs["inplace"] = False

        # Run the real op with the override disabled to avoid re-entering
        # this method, then wrap the result.
        with torch._C.DisableTorchFunction():
            ret = func(*args, **kwargs)
            return _convert_output(ret, func)

    def __deepcopy__(self, memo):
        """Return a ``ColoTensor`` holding a clone of this tensor's data.

        Args:
            memo: the memo dictionary supplied by ``copy.deepcopy``; the copy
                is registered under ``id(self)`` so repeated references to
                this tensor resolve to the same copy.

        Returns:
            ColoTensor: the memoized copy if one exists, otherwise a new one.
        """
        if id(self) in memo:
            return memo[id(self)]
        else:
            with torch._C.DisableTorchFunction():
                # NOTE(review): the copy is built from `.data`, so it presumably
                # does not carry over `requires_grad` -- confirm this is intended.
                data = self.data.clone()
            tensor = ColoTensor(data)
            memo[id(self)] = tensor
            return tensor