mirror of https://github.com/hpcaitech/ColossalAI
40 lines
1.1 KiB
Python
40 lines
1.1 KiB
Python
![]() |
import torch
|
||
![]() |
|
||
![]() |
from colossalai.legacy.zero.gemini.stateful_tensor import StatefulTensor, TensorState
|
||
![]() |
|
||
|
|
||
![]() |
class ShardedTensor(StatefulTensor):
|
||
![]() |
def __init__(self, tensor: torch.Tensor, state: TensorState = TensorState.HOLD) -> None:
|
||
![]() |
r"""
|
||
![]() |
A tensor sharded in multiple processes. Constructed from an existing torch.Tensor instance.
|
||
![]() |
"""
|
||
![]() |
assert tensor.requires_grad is False
|
||
![]() |
super().__init__(tensor, state)
|
||
![]() |
|
||
![]() |
# kept the shape, numel and dtype of the init tensor.
|
||
![]() |
self._origin_shape = tensor.shape
|
||
|
self._origin_numel = tensor.numel()
|
||
|
self._origin_dtype = tensor.dtype
|
||
![]() |
self._is_sharded = False
|
||
|
|
||
![]() |
@property
|
||
|
def dtype(self) -> torch.dtype:
|
||
![]() |
assert self._payload.dtype == self._origin_dtype
|
||
|
return self._payload.dtype
|
||
![]() |
|
||
![]() |
@property
|
||
![]() |
def origin_numel(self) -> int:
|
||
![]() |
return self._origin_numel
|
||
|
|
||
|
@property
|
||
![]() |
def origin_shape(self) -> int:
|
||
![]() |
return self._origin_shape
|
||
|
|
||
![]() |
@property
|
||
|
def is_sharded(self):
|
||
|
return self._is_sharded
|
||
|
|
||
![]() |
@is_sharded.setter
|
||
|
def is_sharded(self, flag: bool):
|
||
|
self._is_sharded = flag
|