revert zero tensors back (#829)

pull/834/head
Jiarui Fang 2022-04-22 12:12:35 +08:00 committed by GitHub
parent 294a6060d0
commit 595bedf767
2 changed files with 17 additions and 22 deletions


@@ -20,8 +20,8 @@ class ShardedTensor(StatefulTensor):
     @property
     def dtype(self) -> torch.dtype:
-        assert self.torch_tensor().dtype == self._origin_dtype
-        return self.torch_tensor().dtype
+        assert self._payload.dtype == self._origin_dtype
+        return self._payload.dtype
     @property
     def origin_numel(self) -> int:
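
The change above reads the dtype straight off the stored payload and asserts it still matches the dtype recorded at construction. A minimal, runnable sketch of that invariant (the _ShardedTensorSketch class and its fields are illustrative stand-ins, not the library's implementation):

import torch

class _ShardedTensorSketch:
    """Illustrative stand-in: a payload tensor plus a recorded origin dtype."""

    def __init__(self, tensor: torch.Tensor) -> None:
        self._payload = tensor
        self._origin_dtype = tensor.dtype  # recorded once at construction

    @property
    def dtype(self) -> torch.dtype:
        # Fail loudly if a payload swap silently changed the dtype.
        assert self._payload.dtype == self._origin_dtype
        return self._payload.dtype

t = _ShardedTensorSketch(torch.zeros(4, dtype=torch.float16))
assert t.dtype == torch.float16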


@@ -1,7 +1,6 @@
 from enum import Enum
 from typing import Optional
 import torch
-from colossalai.tensor import ColoTensor
 class TensorState(Enum):
@@ -12,7 +11,7 @@ class TensorState(Enum):
     COMPUTE = 4
-class StatefulTensor(ColoTensor):
+class StatefulTensor(object):
     """A structure that stores a Torch Tensor with labeled states.
     Inspired by the paper:
     PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management
@@ -21,20 +20,15 @@ class StatefulTensor(ColoTensor):
     """
     def __init__(self, tensor: Optional[torch.Tensor], state: Optional[TensorState] = TensorState.HOLD) -> None:
-        if tensor is not None:
-            super().__init__(tensor.size(), dtype=tensor.dtype, requires_grad=tensor.requires_grad, \
-                pin_memory=tensor.pin_memory, torch_tensor=tensor)
-        else:
-            super().__init__(0)
         self._state = state
+        self._payload = tensor
         if self._state == TensorState.FREE:
-            assert self.torch_tensor().numel() == 0, f"payload has to be None if state is {self._state}"
+            assert self._payload is None, f"payload has to be None if state is {self._state}"
     def data_ptr(self):
-        if self.torch_tensor().numel() == 0:
+        if self._payload is None:
             return None
-        return self.torch_tensor().data_ptr()
+        return self._payload.data_ptr()
     @property
     def state(self) -> TensorState:
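
With the ColoTensor base class gone, a FREE tensor is represented by _payload is None rather than by a zero-element tensor, so data_ptr has to handle the missing-storage case explicitly. A small runnable sketch of that behaviour (_PayloadHolder is a hypothetical stand-in, not the library's class):

from typing import Optional
import torch

class _PayloadHolder:
    """Hypothetical stand-in mirroring the reverted _payload handling."""

    def __init__(self, tensor: Optional[torch.Tensor]) -> None:
        self._payload = tensor

    def data_ptr(self) -> Optional[int]:
        # A freed tensor owns no storage, so there is no address to return.
        if self._payload is None:
            return None
        return self._payload.data_ptr()

assert _PayloadHolder(None).data_ptr() is None
assert isinstance(_PayloadHolder(torch.ones(2)).data_ptr(), int)
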
@@ -42,41 +36,42 @@ class StatefulTensor(ColoTensor):
     def set_null(self) -> None:
         self._state = TensorState.FREE
-        self.del_torch_tensor()
+        self._payload = None
     def is_null(self) -> bool:
         if self._state == TensorState.FREE:
-            assert self.torch_tensor().numel() == 0
+            assert self._payload is None
             return True
         return False
     def trans_state(self, state: TensorState) -> None:
         self._state = state
         if state == TensorState.FREE:
-            self.del_torch_tensor()
+            self._payload = None
     @property
     def payload(self) -> Optional[torch.Tensor]:
-        return self.torch_tensor()
+        return self._payload
     def copy_payload(self, tensor) -> None:
-        self.torch_tensor.view(-1).copy_(tensor.view(-1))
+        self._payload.view(-1).copy_(tensor.view(-1))
     def reset_payload(self, tensor) -> None:
-        self._torch_tensor = tensor
+        del self._payload
+        self._payload = tensor
         self.trans_state(TensorState.HOLD)
     @property
     def device(self) -> torch.device:
-        return self.torch_tensor().device
+        return self._payload.device
     @property
     def dtype(self) -> torch.dtype:
-        return self.torch_tensor().dtype
+        return self._payload.dtype
     @property
     def shape(self):
-        return self.torch_tensor().shape
+        return self._payload.shape
     def to(self, device: torch.device):
         raise RuntimeError("Use colo_model_tensor_move instead of calling .to() on ShardedTensor")
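
Taken together, the reverted methods form a small payload lifecycle: a HOLD tensor owns its payload, trans_state(TensorState.FREE) (or set_null) drops the reference, and reset_payload installs a fresh tensor and returns the wrapper to HOLD. A runnable sketch of that lifecycle; the class is a stand-in, and the enum values other than COMPUTE = 4 are assumptions, since the diff does not show them:

from enum import Enum
from typing import Optional
import torch

class TensorState(Enum):
    FREE = 0     # assumed value; the diff only shows COMPUTE = 4
    HOLD = 1     # assumed value
    COMPUTE = 4

class _StatefulTensorSketch:
    """Stand-in mirroring the reverted state transitions."""

    def __init__(self, tensor: Optional[torch.Tensor],
                 state: TensorState = TensorState.HOLD) -> None:
        self._state = state
        self._payload = tensor

    def trans_state(self, state: TensorState) -> None:
        self._state = state
        if state == TensorState.FREE:
            self._payload = None  # dropping the reference releases the memory

    def reset_payload(self, tensor: torch.Tensor) -> None:
        self._payload = tensor
        self.trans_state(TensorState.HOLD)

    @property
    def payload(self) -> Optional[torch.Tensor]:
        return self._payload

t = _StatefulTensorSketch(torch.zeros(4))
t.trans_state(TensorState.FREE)       # payload dropped
assert t.payload is None
t.reset_payload(torch.ones(4))        # new payload, state back to HOLD
assert t._state is TensorState.HOLD

Note the distinction the diff preserves: copy_payload writes into the existing storage with copy_, while reset_payload rebinds _payload to a new tensor object.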