mirror of https://github.com/hpcaitech/ColossalAI

revert zero tensors back (#829)
parent 294a6060d0
commit 595bedf767
@@ -20,8 +20,8 @@ class ShardedTensor(StatefulTensor):

     @property
     def dtype(self) -> torch.dtype:
-        assert self.torch_tensor().dtype == self._origin_dtype
-        return self.torch_tensor().dtype
+        assert self._payload.dtype == self._origin_dtype
+        return self._payload.dtype

     @property
     def origin_numel(self) -> int:
@@ -1,7 +1,6 @@
 from enum import Enum
 from typing import Optional
 import torch
-from colossalai.tensor import ColoTensor


 class TensorState(Enum):
@@ -12,7 +11,7 @@ class TensorState(Enum):
     COMPUTE = 4


-class StatefulTensor(ColoTensor):
+class StatefulTensor(object):
     """A structure that stores a Torch Tensor and its labeled states.
     Inspired by the paper:
     PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management
@@ -21,20 +20,15 @@
     """

     def __init__(self, tensor: Optional[torch.Tensor], state: Optional[TensorState] = TensorState.HOLD) -> None:
-        if tensor is not None:
-            super().__init__(tensor.size(), dtype=tensor.dtype, requires_grad=tensor.requires_grad, \
-                pin_memory=tensor.pin_memory, torch_tensor=tensor)
-        else:
-            super().__init__(0)
-
         self._state = state
+        self._payload = tensor
         if self._state == TensorState.FREE:
-            assert self.torch_tensor().numel() == 0, f"payload has to None if state is {self._state}"
+            assert self._payload is None, f"payload has to be None if state is {self._state}"

     def data_ptr(self):
-        if self.torch_tensor().numel() == 0:
+        if self._payload is None:
             return None
-        return self.torch_tensor().data_ptr()
+        return self._payload.data_ptr()

     @property
     def state(self) -> TensorState:
@@ -42,41 +36,42 @@

     def set_null(self) -> None:
         self._state = TensorState.FREE
-        self.del_torch_tensor()
+        self._payload = None

     def is_null(self) -> bool:
         if self._state == TensorState.FREE:
-            assert self.torch_tensor().numel() == 0
+            assert self._payload is None
             return True
         return False

     def trans_state(self, state: TensorState) -> None:
         self._state = state
         if state == TensorState.FREE:
-            self.del_torch_tensor()
+            self._payload = None

     @property
     def payload(self) -> Optional[torch.Tensor]:
-        return self.torch_tensor()
+        return self._payload

     def copy_payload(self, tensor) -> None:
-        self.torch_tensor.view(-1).copy_(tensor.view(-1))
+        self._payload.view(-1).copy_(tensor.view(-1))

     def reset_payload(self, tensor) -> None:
-        self._torch_tensor = tensor
+        del self._payload
+        self._payload = tensor
         self.trans_state(TensorState.HOLD)

     @property
     def device(self) -> torch.device:
-        return self.torch_tensor().device
+        return self._payload.device

     @property
     def dtype(self) -> torch.dtype:
-        return self.torch_tensor().dtype
+        return self._payload.dtype

     @property
     def shape(self):
-        return self.torch_tensor().shape
+        return self._payload.shape

     def to(self, device: torch.device):
         raise RuntimeError("Use colo_model_tensor_move instead of calling .to() on ShardedTensor")
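To make the effect of the revert concrete, below is a minimal, self-contained sketch of the StatefulTensor lifecycle after this commit: the class is once again a plain wrapper around a torch.Tensor payload plus a TensorState label, rather than a ColoTensor subclass. It condenses the class shown in the diff; the TensorState member values other than COMPUTE = 4, and the trimmed method set, are assumptions for illustration, not the full upstream API.

    from enum import Enum
    from typing import Optional

    import torch


    class TensorState(Enum):
        # Only COMPUTE = 4 is visible in the diff; the other values are assumed.
        FREE = 0
        HOLD = 1
        COMPUTE = 4


    class StatefulTensor:
        """Sketch of the reverted design: a torch.Tensor payload plus a state label."""

        def __init__(self, tensor: Optional[torch.Tensor],
                     state: Optional[TensorState] = TensorState.HOLD) -> None:
            self._state = state
            self._payload = tensor
            if self._state == TensorState.FREE:
                # FREE means no storage is held at all.
                assert self._payload is None, f"payload has to be None if state is {self._state}"

        @property
        def payload(self) -> Optional[torch.Tensor]:
            return self._payload

        def trans_state(self, state: TensorState) -> None:
            self._state = state
            if state == TensorState.FREE:
                # Dropping the reference lets PyTorch release the storage.
                self._payload = None

        def reset_payload(self, tensor: torch.Tensor) -> None:
            self._payload = tensor
            self.trans_state(TensorState.HOLD)


    # Typical lifecycle: HOLD -> COMPUTE -> FREE, then rematerialize via reset_payload.
    st = StatefulTensor(torch.zeros(4, 4))
    st.trans_state(TensorState.COMPUTE)   # payload is in use by an operator
    st.trans_state(TensorState.FREE)      # computation done, storage released
    assert st.payload is None
    st.reset_payload(torch.ones(4, 4))    # back to HOLD with fresh storage
    assert st.payload is not None

Keeping the payload as a plain attribute (instead of routing through a ColoTensor base class) is what lets set_null and trans_state free memory by simply dropping the reference, which is the point of this revert.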