import torch

from colossalai.zero.legacy.gemini.stateful_tensor import StatefulTensor, TensorState


class ShardedTensor(StatefulTensor):
    r"""A tensor sharded across multiple processes.

    Constructed from an existing torch.Tensor instance.
    """

    def __init__(self, tensor: torch.Tensor, state: TensorState = TensorState.HOLD) -> None:
        # Only parameters detached from the autograd graph can be sharded.
        assert tensor.requires_grad is False
        super().__init__(tensor, state)

        # Keep the shape, numel and dtype of the initial tensor.
        self._origin_shape = tensor.shape
        self._origin_numel = tensor.numel()
        self._origin_dtype = tensor.dtype
        self._is_sharded = False

    @property
    def dtype(self) -> torch.dtype:
        # The payload's dtype must never drift from the original tensor's dtype.
        assert self._payload.dtype == self._origin_dtype
        return self._payload.dtype

    @property
    def origin_numel(self) -> int:
        return self._origin_numel

    @property
    def origin_shape(self) -> torch.Size:
        return self._origin_shape

    @property
    def is_sharded(self) -> bool:
        return self._is_sharded

    @is_sharded.setter
    def is_sharded(self, flag: bool) -> None:
        self._is_sharded = flag
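
# A minimal usage sketch (illustrative, not part of the original file): wrap a
# plain tensor, read back the recorded metadata, and toggle the sharded flag.
# The tensor name `t` and the shape are arbitrary; nothing here beyond the
# ShardedTensor API above is assumed.
if __name__ == "__main__":
    t = torch.zeros(4, 4)  # requires_grad defaults to False, as the assert demands
    st = ShardedTensor(t)

    # Metadata captured at construction time survives later payload changes.
    assert st.origin_shape == torch.Size([4, 4])
    assert st.origin_numel == 16
    assert st.dtype == torch.float32

    # The flag is plain bookkeeping; the owner of the tensor flips it after
    # replacing the payload with its local shard.
    assert not st.is_sharded
    st.is_sharded = True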