# ColossalAI/colossalai/legacy/zero/sharded_param/sharded_tensor.py

# 40 lines
# 1.1 KiB
# Python

import torch
from colossalai.legacy.zero.gemini.stateful_tensor import StatefulTensor, TensorState
class ShardedTensor(StatefulTensor):
    r"""
    A tensor sharded in multiple processes. Constructed from an existing torch.Tensor instance.

    The shape, element count and dtype of the tensor handed to the constructor are
    recorded and exposed through ``origin_shape``, ``origin_numel`` and ``dtype``.
    This class only stores the ``is_sharded`` flag; nothing in it splits or gathers
    data.
    """

    def __init__(self, tensor: torch.Tensor, state: TensorState = TensorState.HOLD) -> None:
        r"""
        Wrap ``tensor`` and remember its original metadata.

        Args:
            tensor: the tensor to wrap. Its ``requires_grad`` must be ``False``.
            state: forwarded unchanged to ``StatefulTensor.__init__``.

        Raises:
            AssertionError: if ``tensor.requires_grad`` is not ``False``.
        """
        assert tensor.requires_grad is False, "ShardedTensor expects a tensor with requires_grad=False"
        super().__init__(tensor, state)

        # Keep the shape, numel and dtype of the initial tensor.
        self._origin_shape = tensor.shape
        self._origin_numel = tensor.numel()
        self._origin_dtype = tensor.dtype
        # A freshly constructed instance is always marked as not sharded.
        self._is_sharded = False

    @property
    def dtype(self) -> torch.dtype:
        """Return the dtype of the payload, checking it still equals the recorded one.

        NOTE(review): ``_payload`` is not assigned in this class; presumably it is
        set by ``StatefulTensor.__init__`` -- confirm against the base class.

        Raises:
            AssertionError: if the payload dtype differs from the dtype recorded
                at construction time.
        """
        assert self._payload.dtype == self._origin_dtype, (
            f"payload dtype {self._payload.dtype} differs from origin dtype {self._origin_dtype}"
        )
        return self._payload.dtype

    @property
    def origin_numel(self) -> int:
        """Return the number of elements of the tensor passed to the constructor."""
        return self._origin_numel

    @property
    def origin_shape(self) -> torch.Size:
        """Return the shape of the tensor passed to the constructor."""
        return self._origin_shape

    @property
    def is_sharded(self) -> bool:
        """Return the stored sharded flag (``False`` right after construction)."""
        return self._is_sharded

    @is_sharded.setter
    def is_sharded(self, flag: bool) -> None:
        """Store ``flag`` as the sharded flag; no other state is touched."""
        self._is_sharded = flag