from .op_wrapper import _COLOSSAL_OPS
from .const import TensorType
from copy import copy

import torch
from torch.overrides import get_default_nowrap_functions

from colossalai.tensor import TensorSpec
from colossalai.tensor import distspec
from colossalai.tensor.dist_spec_mgr import DistSpecManager
from colossalai.tensor.distspec import _DistSpec


def _convert_output(output):
    """Recursively wrap plain torch.Tensor outputs (including those nested in
    lists and tuples) as ColoTensor."""
    if isinstance(output, torch.Tensor) and not isinstance(output, ColoTensor):
        output = ColoTensor.from_torch_tensor(output)
    elif isinstance(output, (list, tuple)):
        output = type(output)(_convert_output(o) for o in output)
    return output


class ColoTensor(torch.Tensor):
    """Data structure for tensors in Colossal-AI. It is a subclass of torch.Tensor.

    Args:
        data (torch.Tensor): a torch tensor used as the payload of the ColoTensor.
        spec (TensorSpec, optional): the tensor spec at initialization.
            Defaults to TensorSpec(distspec.replicate()).

    The signature of ``__init__`` has to be consistent with that of ``__new__``
    except for the first argument. The class can be initialized with a torch
    tensor in the following ways.

    1. Directly init.

    >>> colo_t1 = ColoTensor(torch.randn(2, 3), spec=TensorSpec(distspec.replicate()))
    >>> # If initialized in a sharded model, the tensor passed in is one shard of the global tensor.
    >>> shard_spec = distspec.shard(process_group=gpc.get_group(ParallelMode.DATA),
    >>>                             dims=[0],
    >>>                             num_partitions=[world_size])
    >>> tensor_spec = TensorSpec(shard_spec)
    >>> colo_t2 = ColoTensor.from_torch_tensor(t_ref.clone(), tensor_spec)

    2. Use the static method ``from_torch_tensor``.

    >>> colo_t = ColoTensor.from_torch_tensor(torch.randn(2, 3), spec=TensorSpec(distspec.replicate()))
    """

    def __new__(cls, data: torch.Tensor, spec: TensorSpec = TensorSpec(distspec.replicate())) -> 'ColoTensor':
        """__new__

        The signature of ``__new__`` has to be consistent with that of torch.Tensor.

        Args:
            data (torch.Tensor): a torch tensor used as the payload of the ColoTensor.
            spec (TensorSpec, optional): the tensor spec at initialization.
                Defaults to TensorSpec(distspec.replicate()).

        Returns:
            ColoTensor: a ColoTensor wrapping the data.
        """
        if data is None:
            data = torch.empty(0)
        return torch.Tensor._make_subclass(cls, data, data.requires_grad)
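    # Note on the two-phase construction: ``__new__`` above only builds the
    # tensor shell (``_make_subclass`` shares the storage of ``data`` rather
    # than copying it); the distributed spec is attached in ``__init__`` below.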
""" if data is None: data = torch.empty(0) return torch.Tensor._make_subclass(cls, data, data.requires_grad) def __init__(self, data: torch.Tensor, spec: TensorSpec = TensorSpec(distspec.replicate())) -> None: self._tensor_spec = copy(spec) self._type = TensorType.NONMODEL self._graph_node = None @property def spec(self) -> TensorSpec: return self._tensor_spec def set_spec(self, spec: TensorSpec) -> None: spec = copy(spec) self._convert_to_dist_spec(spec.dist_spec) self._tensor_spec = spec def has_spec(self) -> bool: return self._tensor_spec.parallel_action is not None def is_model_data(self) -> bool: return self._type == TensorType.MODEL @classmethod def __torch_function__(cls, func, types, args=(), kwargs=None): if kwargs is None: kwargs = {} if not all(issubclass(cls, t) for t in types): return NotImplemented global _COLOSSAL_OPS if func in _COLOSSAL_OPS: func = _COLOSSAL_OPS[func] with torch._C.DisableTorchFunction(): ret = func(*args, **kwargs) if func in get_default_nowrap_functions(): return ret else: return _convert_output(ret) def __repr__(self): return f'ColoTensor: {super().__repr__()}' def is_model_data(self) -> bool: return self._type == TensorType.MODEL def _convert_to_dist_spec(self, dist_spec: _DistSpec) -> None: with DistSpecManager.no_grad(): self.data = DistSpecManager.handle_trans_spec(self, self.spec.dist_spec, dist_spec) self._tensor_spec.dist_spec = dist_spec def convert_to_dist_spec(self, dist_spec: _DistSpec) -> 'ColoTensor': tensor_spec = copy(self._tensor_spec) tensor_spec.dist_spec = dist_spec ret = DistSpecManager.handle_trans_spec(self, self.spec.dist_spec, dist_spec) return ColoTensor.from_torch_tensor(ret, tensor_spec) @staticmethod def from_torch_tensor(tensor: torch.Tensor, spec: TensorSpec = TensorSpec(distspec.replicate())) -> 'ColoTensor': tensor = tensor.as_subclass(ColoTensor) tensor.__init__(tensor, spec=spec) return tensor def __deepcopy__(self, memo): if id(self) in memo: return memo[id(self)] else: with torch._C.DisableTorchFunction(): data = self.data.clone() tensor = ColoTensor(data, spec=copy(self.spec)) memo[id(self)] = tensor return tensor