from .op_wrapper import _COLOSSAL_OPS
from copy import copy
import torch
from colossalai.tensor import TensorSpec
from .const import TensorType
from colossalai.tensor import distspec
from colossalai.tensor.dist_spec_mgr import DistSpecManager
from colossalai.tensor.distspec import _DistSpec
from torch.overrides import get_default_nowrap_functions
def _convert_output(output):
if isinstance(output, torch.Tensor) and not isinstance(output, ColoTensor):
output = ColoTensor.from_torch_tensor(output)
elif isinstance(output, (list, tuple)):
output = type(output)(_convert_output(o) for o in output)
return output
class ColoTensor(torch.Tensor):
    """Tensor type used by Colossal-AI.

    1. It is a ``torch.Tensor`` subclass, so it can be used wherever a
       regular tensor is expected.
    2. It carries a ``TensorSpec`` (see ``spec``), a ``TensorType`` tag
       (see ``is_model_data``) and a ``_graph_node`` slot.
    3. It hijacks torch functions that receive ColoTensors as arguments via
       ``__torch_function__``: functions registered in ``_COLOSSAL_OPS`` are
       replaced by their customized implementation, and tensor results are
       wrapped back into ``ColoTensor``.
    4. Its payload can be converted from one distribution spec to another
       (``convert_to_dist_spec_`` / ``convert_to_dist_spec``).
    """

    def __new__(cls, data: torch.Tensor, spec: TensorSpec = TensorSpec(distspec.replicate())) -> 'ColoTensor':
        """Allocate the tensor object on top of ``data``.

        Args:
            data: Payload tensor. ``None`` is replaced by an empty tensor.
            spec: Unused here; accepted so that ``__new__`` and ``__init__``
                share the same signature.
        """
        if data is None:
            data = torch.empty(0)
        return torch.Tensor._make_subclass(cls, data, data.requires_grad)

    def __init__(self, data: torch.Tensor, spec: TensorSpec = TensorSpec(distspec.replicate())) -> None:
        """Attach the Colossal-AI metadata to the tensor.

        Args:
            data: Unused here; the payload is bound in ``__new__``.
            spec: Tensor spec to attach. It is shallow-copied, so rebinding
                attributes on ``self._spec`` does not affect the caller's
                object (nor the shared default argument).
        """
        self._spec = copy(spec)
        self._type = TensorType.NONMODEL
        self._graph_node = None

    @property
    def spec(self) -> TensorSpec:
        """The ``TensorSpec`` currently attached to this tensor."""
        return self._spec

    def set_spec(self, spec: TensorSpec) -> None:
        """Replace the attached spec with a shallow copy of ``spec``.

        Only the metadata is updated; the payload is left as is.
        """
        # NOTE(review): the payload is not redistributed to match
        # ``spec.dist_spec`` here -- confirm whether callers are expected to
        # call ``convert_to_dist_spec_`` themselves.
        self._spec = copy(spec)

    def has_spec(self) -> bool:
        """Return True when the attached spec reports ``num_action > 0``."""
        return self._spec.num_action > 0

    def is_model_data(self) -> bool:
        """Return True when this tensor is tagged ``TensorType.MODEL``."""
        return self._type == TensorType.MODEL

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        """Intercept torch functions called with ``ColoTensor`` arguments.

        Args:
            func: The torch function being called.
            types: Types of the arguments that define ``__torch_function__``.
            args: Positional arguments of the call.
            kwargs: Keyword arguments of the call (``None`` means none).

        Returns:
            ``NotImplemented`` when ``cls`` is not a subclass of every type
            in ``types``; otherwise the result of the call, with tensors
            wrapped as ``ColoTensor`` unless ``func`` is one of torch's
            default no-wrap functions.
        """
        if kwargs is None:
            kwargs = {}

        if not all(issubclass(cls, t) for t in types):
            return NotImplemented

        # Swap in the Colossal-AI implementation when one is registered.
        if func in _COLOSSAL_OPS:
            func = _COLOSSAL_OPS[func]

        # Torch-function dispatch is disabled while running the call and
        # wrapping its result so that this hook is not re-entered.
        with torch._C.DisableTorchFunction():
            ret = func(*args, **kwargs)
            if func in get_default_nowrap_functions():
                return ret
            return _convert_output(ret)

    def __repr__(self):
        """Return the ``torch.Tensor`` repr prefixed with ``ColoTensor: ``."""
        return f'ColoTensor: {super().__repr__()}'

    def convert_to_dist_spec_(self, dist_spec: _DistSpec) -> None:
        """Convert the payload to ``dist_spec`` in place.

        The payload is replaced by the result of
        ``DistSpecManager.handle_trans_spec`` (run under
        ``DistSpecManager.no_grad()``) and the attached spec is updated to
        the new ``dist_spec``.
        """
        with DistSpecManager.no_grad():
            self.data = DistSpecManager.handle_trans_spec(self, self.spec.dist_spec, dist_spec)
        self._spec.dist_spec = dist_spec

    def convert_to_dist_spec(self, dist_spec: _DistSpec) -> 'ColoTensor':
        """Return a new ``ColoTensor`` holding the payload under ``dist_spec``.

        Neither ``self.data`` nor ``self._spec`` is rebound by this method;
        the returned tensor gets a copy of this tensor's spec whose
        ``dist_spec`` is set to the requested one.
        """
        spec = copy(self._spec)
        spec.dist_spec = dist_spec
        ret = DistSpecManager.handle_trans_spec(self, self.spec.dist_spec, dist_spec)
        return ColoTensor.from_torch_tensor(ret, spec)

    @staticmethod
    def from_torch_tensor(tensor: torch.Tensor, spec: TensorSpec = TensorSpec(distspec.replicate())) -> 'ColoTensor':
        """View ``tensor`` as a ``ColoTensor`` and attach ``spec`` to it.

        Args:
            tensor: Tensor to wrap; converted with ``Tensor.as_subclass``.
            spec: Tensor spec to attach (copied by ``__init__``).

        Returns:
            The ``ColoTensor`` produced by ``as_subclass``.
        """
        tensor = tensor.as_subclass(ColoTensor)
        # ``as_subclass`` does not run ``__init__``, so the Colossal-AI
        # attributes have to be initialised explicitly.
        tensor.__init__(tensor, spec=spec)
        return tensor

    def __deepcopy__(self, memo):
        """Return a copy backed by a clone of the payload.

        The copy is built through the regular constructor, so it receives a
        copy of this tensor's spec while ``_type`` and ``_graph_node`` are
        reset to their ``__init__`` defaults.
        """
        if id(self) in memo:
            return memo[id(self)]

        with torch._C.DisableTorchFunction():
            data = self.data.clone()
        tensor = ColoTensor(data, spec=copy(self.spec))
        memo[id(self)] = tensor
        return tensor