# Mirror of https://github.com/hpcaitech/ColossalAI
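"""ColoTensor: a torch.Tensor subclass used by Colossal-AI to trigger the torch function hook."""
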
from functools import lru_cache
from typing import Callable, Set

import torch

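# ColoTensor.__torch_function__ uses this table to redirect in-place tensor methods
# to their out-of-place counterparts.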
INPLACE_MAPPING = {
    torch.Tensor.add_: torch.Tensor.add,
    torch.Tensor.sub_: torch.Tensor.sub,
    torch.Tensor.mul_: torch.Tensor.mul,
    torch.Tensor.div_: torch.Tensor.div,
}


@lru_cache(None)
def _get_my_nowrap_functions() -> Set[Callable]:
    """Accessors whose results should not be re-wrapped as ColoTensor."""
    Tensor = torch.Tensor
    return {
        Tensor._base.__get__,
        Tensor.grad.__get__,
        Tensor._grad.__get__,
        Tensor.data.__get__,  # make .data return a torch.Tensor rather than a ColoTensor
    }


def _convert(output):
    """Re-wrap plain torch.Tensor outputs (including those nested in lists/tuples) as ColoTensor."""
    if isinstance(output, torch.Tensor) and not isinstance(output, ColoTensor):
        output.__class__ = ColoTensor
    elif isinstance(output, (list, tuple)):
        output = type(output)(_convert(o) for o in output)
    return output


def _convert_output(output, func):
    if func in _get_my_nowrap_functions():
        return output
    return _convert(output)


class ColoTensor(torch.Tensor):
    """Data structure for a tensor in Colossal-AI. It is a subclass of torch.Tensor.

    It is only used to trigger the torch function hook.

    Args:
        data (torch.Tensor): a torch tensor used as the payload of the ColoTensor.
    """

    torch_major = int(torch.__version__.split(".")[0])
    torch_minor = int(torch.__version__.split(".")[1])

    def __new__(cls, data: torch.Tensor) -> "ColoTensor":
        """
        The signature of __new__ has to be consistent with that of torch.Tensor.

        Args:
            data (torch.Tensor): a torch tensor used as the payload of the ColoTensor.

        Returns:
            ColoTensor: a ColoTensor wrapping the data.
        """
        if data is None:
            data = torch.empty(0)
        return torch.Tensor._make_subclass(cls, data, data.requires_grad)

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}

        if not all(issubclass(cls, t) for t in types):
            return NotImplemented

        if cls.torch_major > 1 or (cls.torch_major == 1 and cls.torch_minor >= 12):
            # In order to trigger the pre-op hook in the forward of a checkpointed module,
            # we have to capture the `backward` function and make sure that it is not run
            # inside the `torch._C.DisableTorchFunction()` context.
            if func is torch.Tensor.backward:
                assert len(args) == 1  # only the tensor itself is passed positionally
                backward_tensor = torch.Tensor(args[0])
                tensor_kwargs = {k: torch.Tensor(v) if torch.is_tensor(v) else v for k, v in kwargs.items()}
                return backward_tensor.backward(**tensor_kwargs)

        # Replace an in-place function with its out-of-place counterpart.
        if func in INPLACE_MAPPING:
            func = INPLACE_MAPPING[func]
        # Set the 'inplace' kwarg to False.
        if "inplace" in kwargs:
            kwargs["inplace"] = False

        with torch._C.DisableTorchFunction():
            ret = func(*args, **kwargs)
        return _convert_output(ret, func)

    def __deepcopy__(self, memo):
        if id(self) in memo:
            return memo[id(self)]
        else:
            with torch._C.DisableTorchFunction():
                data = self.data.clone()
            tensor = ColoTensor(data)
            memo[id(self)] = tensor
            return tensor
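

# A minimal usage sketch (illustration only, based on the behavior of __torch_function__ above):
# op results on a ColoTensor are re-wrapped as ColoTensor, in-place tensor methods are redirected
# to their out-of-place counterparts, and .data stays a plain torch.Tensor
# (see _get_my_nowrap_functions).
if __name__ == "__main__":
    x = ColoTensor(torch.ones(2, 2))
    y = x + 1      # routed through __torch_function__, result is re-wrapped as ColoTensor
    z = x.add_(1)  # in-place call is redirected to torch.Tensor.add
    print(type(y), type(z), type(x.data))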