mirror of https://github.com/hpcaitech/ColossalAI
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
103 lines
3.3 KiB
103 lines
3.3 KiB
from functools import lru_cache
|
|
from typing import Callable, Set
|
|
|
|
import torch
|
|
|
|
# Lookup table from each in-place ``torch.Tensor`` arithmetic method
# (``add_``, ``sub_``, ``mul_``, ``div_``) to its out-of-place counterpart.
# NOTE: the identifier is misspelled ("INPALCE"); it is kept as-is because
# ``ColoTensor.__torch_function__`` looks the table up under this exact name.
INPALCE_MAPPING = {
    getattr(torch.Tensor, f"{op_name}_"): getattr(torch.Tensor, op_name)
    for op_name in ("add", "sub", "mul", "div")
}
|
|
|
|
|
|
@lru_cache(None)
def _get_my_nowrap_functions() -> Set[Callable]:
    """Return the ``torch.Tensor`` attribute getters whose results stay unwrapped.

    The set holds the ``__get__`` of the ``_base``, ``grad``, ``_grad`` and
    ``data`` descriptors. ``_convert_output`` returns the result of any
    function found in this set as-is instead of converting it to ColoTensor.
    The result is memoized by ``lru_cache``, so the set is built only once.

    Returns:
        Set[Callable]: the descriptor getters that must not be re-wrapped.
    """
    tensor_cls = torch.Tensor
    unwrapped_descriptors = (
        tensor_cls._base,
        tensor_cls.grad,
        tensor_cls._grad,
        tensor_cls.data,  # make .data returns torch.Tensor rather than ColoTensor
    )
    getters = set()
    for descriptor in unwrapped_descriptors:
        getters.add(descriptor.__get__)
    return getters
|
|
|
|
|
|
def _convert(output):
    """Turn plain ``torch.Tensor`` results into ``ColoTensor``.

    A tensor that is not already a ColoTensor is re-classed in place (its
    ``__class__`` is reassigned, no data is copied). Lists and tuples are
    rebuilt with the same container type, converting every element
    recursively. Any other value is returned untouched.

    Args:
        output: the value to convert.

    Returns:
        The converted value: the same tensor object for tensors, a new
        container for lists/tuples, and the input itself otherwise.
    """
    # A value cannot be both a tensor and a list/tuple, so the order of the
    # two checks does not matter.
    if isinstance(output, (list, tuple)):
        return type(output)(_convert(item) for item in output)
    if isinstance(output, torch.Tensor) and not isinstance(output, ColoTensor):
        output.__class__ = ColoTensor
    return output
|
|
|
|
|
|
def _convert_output(output, func):
    """Convert the result of ``func`` unless ``func`` is on the no-wrap list.

    Args:
        output: the value returned by ``func``.
        func: the callable that produced ``output``.

    Returns:
        ``output`` unchanged when ``func`` is one of the getters returned by
        ``_get_my_nowrap_functions()``; otherwise ``_convert(output)``.
    """
    return output if func in _get_my_nowrap_functions() else _convert(output)
|
|
|
|
|
|
class ColoTensor(torch.Tensor):
    """Tensor data structure of Colossal-AI, a subclass of ``torch.Tensor``.

    It is only used to trigger the torch function hook.

    Args:
        data (torch.Tensor): a torch tensor used as the payload of the ColoTensor.
    """

    # Major / minor components of the installed torch version, parsed once
    # when the class body runs; consulted in ``__torch_function__``.
    torch_major, torch_minor = (int(torch.__version__.split(".")[idx]) for idx in (0, 1))

    def __new__(cls, data: torch.Tensor) -> "ColoTensor":
        """Create a ColoTensor wrapping ``data``.

        The signature of ``__new__`` has to be consistent with ``torch.Tensor``.

        Args:
            data (torch.Tensor): a torch tensor used as the payload of the
                ColoTensor. ``None`` is replaced by an empty tensor.

        Returns:
            ColoTensor: a ColoTensor that wraps the data and carries over its
            ``requires_grad`` flag.
        """
        payload = torch.empty(0) if data is None else data
        return torch.Tensor._make_subclass(cls, payload, payload.requires_grad)

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        """Run ``func`` with torch-function dispatch disabled and wrap the result.

        Args:
            func: the torch callable being dispatched.
            types: the types taking part in the dispatch.
            args: positional arguments for ``func``.
            kwargs: keyword arguments for ``func``; ``None`` means no keywords.

        Returns:
            ``NotImplemented`` when ``cls`` is not a subclass of every entry
            of ``types``; otherwise the result of the call, passed through
            ``_convert_output``.
        """
        call_kwargs = {} if kwargs is None else kwargs

        if any(not issubclass(cls, t) for t in types):
            return NotImplemented

        if (cls.torch_major, cls.torch_minor) >= (1, 12) and func is torch.Tensor.backward:
            # in order to trigger pre-op hook in the forward of checkpoint module
            # we have to capture the `backward` function
            # and make sure that it is not run in the
            # `torch._C.DisableTorchFunction()` context
            assert len(args) == 1  # only has 1 parameter
            plain_self = torch.Tensor(args[0])
            plain_kwargs = {
                key: torch.Tensor(value) if torch.is_tensor(value) else value
                for key, value in call_kwargs.items()
            }
            return plain_self.backward(**plain_kwargs)

        # replace the in-place function with its out-of-place counterpart
        func = INPALCE_MAPPING.get(func, func)
        # set the 'inplace' kwargs to False
        if "inplace" in call_kwargs:
            call_kwargs["inplace"] = False

        with torch._C.DisableTorchFunction():
            result = func(*args, **call_kwargs)
            return _convert_output(result, func)

    def __deepcopy__(self, memo):
        """Return a ColoTensor holding a clone of this tensor's data.

        Args:
            memo: the deepcopy memo dict, keyed by ``id`` of already-copied
                objects.

        Returns:
            The copy previously recorded in ``memo`` for this object, or a
            new ColoTensor built from ``self.data.clone()`` (which is then
            recorded in ``memo``).
        """
        key = id(self)
        if key in memo:
            return memo[key]

        # Clone with dispatch disabled so the clone does not go through
        # ``__torch_function__``.
        with torch._C.DisableTorchFunction():
            cloned = self.data.clone()
        duplicate = ColoTensor(cloned)
        memo[key] = duplicate
        return duplicate
|