2022-04-21 03:42:37 +00:00
|
|
|
import torch
|
2022-04-21 06:15:48 +00:00
|
|
|
from .op_wrapper import _COLOSSAL_OPS
|
2022-04-21 03:42:37 +00:00
|
|
|
|
|
|
|
|
2022-04-21 06:15:48 +00:00
|
|
|
class ColoTensor(object):
    """Wrapper around a ``torch.Tensor`` that takes part in torch function dispatch.

    Because the class defines ``__torch_function__``, torch routes calls whose
    arguments include a ``ColoTensor`` to that hook. The hook either forwards
    the call to a registered Colossal implementation (``_COLOSSAL_OPS``) or
    unwraps every ``ColoTensor`` back to its ``torch.Tensor`` and calls the
    original torch function.
    """

    def __new__(cls, *args, **kwargs):
        # Constructor arguments are consumed by __init__; object.__new__ takes none.
        return super(ColoTensor, cls).__new__(cls)

    def __init__(self, t: torch.Tensor) -> None:
        """Wrap ``t``; the tensor is stored by reference, not copied."""
        self._torch_tensor = t

    def torch_tensor(self) -> torch.Tensor:
        """Return the wrapped ``torch.Tensor``."""
        return self._torch_tensor

    @staticmethod
    def _unwrap(obj):
        """Replace ColoTensors in ``obj`` (recursing into plain lists/tuples) with their torch tensors."""
        if isinstance(obj, ColoTensor):
            return obj.torch_tensor()
        # Only exact list/tuple are rebuilt: subclasses such as namedtuples may not
        # accept a single iterable in their constructor.
        if type(obj) in (list, tuple):
            return type(obj)(ColoTensor._unwrap(item) for item in obj)
        return obj

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        """Dispatch ``func`` to a Colossal op if one is registered, else run it on plain tensors.

        Args:
            func: the torch function being called.
            types: types of the arguments that implement ``__torch_function__``.
            args: positional arguments of the call.
            kwargs: keyword arguments of the call; torch may pass ``None``.

        Returns:
            The registered Colossal op's result when ``func`` is in ``_COLOSSAL_OPS``
            and a ``ColoTensor`` appears among the top-level positional or keyword
            arguments; otherwise the result of ``func`` called with every
            ``ColoTensor`` replaced by its wrapped ``torch.Tensor``.
        """
        # torch may call the hook with kwargs=None; normalize once so both paths
        # can iterate it safely.
        if kwargs is None:
            kwargs = {}

        has_colo_arg = any(isinstance(value, ColoTensor) for value in (*args, *kwargs.values()))
        if func in _COLOSSAL_OPS and has_colo_arg:
            return _COLOSSAL_OPS[func](types, args, kwargs, None)

        # No hijacked implementation applies: convert ColoTensors (keyed by value,
        # not by keyword name) back to torch tensors and call the original function.
        args = [cls._unwrap(arg) for arg in args]
        kwargs = {key: cls._unwrap(value) for key, value in kwargs.items()}
        return func(*args, **kwargs)
|