import torch

from colossalai.gemini.stateful_tensor import StatefulTensor, TensorState


class ShardedTensor(StatefulTensor):

    def __init__(self, tensor: torch.Tensor, state: TensorState = TensorState.HOLD) -> None:
        r"""
        A tensor sharded across multiple processes, constructed from an existing torch.Tensor instance.
        """
        # Only wrap tensors that do not require gradients.
        assert tensor.requires_grad is False
        super().__init__(tensor, state)

        # Keep the shape, numel and dtype of the initial tensor.
        self._origin_shape = tensor.shape
        self._origin_numel = tensor.numel()
        self._origin_dtype = tensor.dtype
        self._is_sharded = False

    @property
    def dtype(self) -> torch.dtype:
        # The payload must keep the dtype recorded at construction time.
        assert self._payload.dtype == self._origin_dtype
        return self._payload.dtype

    @property
    def origin_numel(self) -> int:
        return self._origin_numel

    @property
    def origin_shape(self) -> torch.Size:
        return self._origin_shape

    @property
    def is_sharded(self) -> bool:
        return self._is_sharded

    @is_sharded.setter
    def is_sharded(self, flag: bool):
        self._is_sharded = flag
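

# Minimal usage sketch, assuming StatefulTensor keeps the wrapped tensor in
# `_payload` (as the dtype property above relies on) and that construction
# with the default TensorState.HOLD succeeds on CPU. The sharding step itself
# is out of scope here; a sharding strategy would replace the payload and
# then flip the flag.
if __name__ == "__main__":
    t = torch.empty(1024, dtype=torch.float16)  # fresh tensors have requires_grad == False
    st = ShardedTensor(t)
    assert st.is_sharded is False       # newly wrapped tensors start unsharded
    assert st.origin_numel == 1024      # bookkeeping recorded at construction
    assert st.dtype == torch.float16    # dtype property checks the payload invariant
    st.is_sharded = True                # set after the payload has been redistributed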