diff --git a/colossalai/tensor/colo_tensor.py b/colossalai/tensor/colo_tensor.py
index 6ed82aea9..f40034dc1 100644
--- a/colossalai/tensor/colo_tensor.py
+++ b/colossalai/tensor/colo_tensor.py
@@ -15,12 +15,12 @@ class ColoTensor(object):
         return super(ColoTensor, cls).__new__(cls)
 
     def __init__(
-            self,
-            *size: Tuple[int],
-            dtype=None,
-            requires_grad=False,
-            pin_memory=False,
-            torch_tensor=None,
+        self,
+        *size: Tuple[int],
+        dtype=None,
+        requires_grad=False,
+        pin_memory=False,
+        torch_tensor=torch.empty(0),
     ):
         self._size = size
         self._dtype = dtype
@@ -37,8 +37,12 @@ class ColoTensor(object):
                             torch_tensor=tensor)
         return colo_t
 
+    def del_torch_tensor(self) -> None:
+        self._size = (0,)
+        self._torch_tensor = torch.empty(self._size)
+
     def torch_tensor(self) -> torch.Tensor:
-        if self._torch_tensor == None:
+        if self._torch_tensor is None or self._torch_tensor.numel() == 0:
             self._torch_tensor = torch.empty(*self._size,
                                              dtype=self._dtype,
                                              requires_grad=self._requires_grad,
diff --git a/colossalai/zero/sharded_param/sharded_tensor.py b/colossalai/zero/sharded_param/sharded_tensor.py
index fde273320..e1f48b318 100644
--- a/colossalai/zero/sharded_param/sharded_tensor.py
+++ b/colossalai/zero/sharded_param/sharded_tensor.py
@@ -20,8 +20,8 @@ class ShardedTensor(StatefulTensor):
 
     @property
     def dtype(self) -> torch.dtype:
-        assert self._payload.dtype == self._origin_dtype
-        return self._payload.dtype
+        assert self.torch_tensor().dtype == self._origin_dtype
+        return self.torch_tensor().dtype
 
     @property
     def origin_numel(self) -> int:
diff --git a/colossalai/zero/sharded_param/tensorful_state.py b/colossalai/zero/sharded_param/tensorful_state.py
index a108963e5..d62f85b0e 100644
--- a/colossalai/zero/sharded_param/tensorful_state.py
+++ b/colossalai/zero/sharded_param/tensorful_state.py
@@ -1,6 +1,7 @@
 from enum import Enum
 from typing import Optional
 import torch
+from colossalai.tensor import ColoTensor
 
 
 class TensorState(Enum):
@@ -11,7 +12,7 @@ class TensorState(Enum):
     COMPUTE = 4
 
 
-class StatefulTensor(object):
+class StatefulTensor(ColoTensor):
     """A Structure stores a Torch Tensor and labeled states.
     Inspired from the paper:
     PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management
@@ -20,15 +21,20 @@ class StatefulTensor(object):
     """
 
     def __init__(self, tensor: Optional[torch.Tensor], state: Optional[TensorState] = TensorState.HOLD) -> None:
+        if tensor is not None:
+            super().__init__(*tensor.size(), dtype=tensor.dtype, requires_grad=tensor.requires_grad,
+                             pin_memory=tensor.is_pinned(), torch_tensor=tensor)
+        else:
+            super().__init__(0)
+
         self._state = state
-        self._payload = tensor
         if self._state == TensorState.FREE:
-            assert self._payload is None, f"payload has to None if state is {self._state}"
+            assert self.torch_tensor().numel() == 0, f"payload has to be empty if state is {self._state}"
 
     def data_ptr(self):
-        if self._payload is None:
+        if self.torch_tensor().numel() == 0:
             return None
-        return self._payload.data_ptr()
+        return self.torch_tensor().data_ptr()
 
     @property
     def state(self) -> TensorState:
@@ -36,42 +42,41 @@ class StatefulTensor(object):
 
     def set_null(self) -> None:
         self._state = TensorState.FREE
-        self._payload = None
+        self.del_torch_tensor()
 
     def is_null(self) -> bool:
         if self._state == TensorState.FREE:
-            assert self._payload is None
+            assert self.torch_tensor().numel() == 0
             return True
         return False
 
     def trans_state(self, state: TensorState) -> None:
         self._state = state
         if state == TensorState.FREE:
-            self._payload = None
+            self.del_torch_tensor()
 
     @property
     def payload(self) -> Optional[torch.Tensor]:
-        return self._payload
+        return self.torch_tensor()
 
     def copy_payload(self, tensor) -> None:
-        self._payload.view(-1).copy_(tensor.view(-1))
+        self.torch_tensor().view(-1).copy_(tensor.view(-1))
 
     def reset_payload(self, tensor) -> None:
-        del self._payload
-        self._payload = tensor
+        self._torch_tensor = tensor
         self.trans_state(TensorState.HOLD)
 
     @property
     def device(self) -> torch.device:
-        return self._payload.device
+        return self.torch_tensor().device
 
     @property
     def dtype(self) -> torch.dtype:
-        return self._payload.dtype
+        return self.torch_tensor().dtype
 
     @property
     def shape(self):
-        return self._payload.shape
+        return self.torch_tensor().shape
 
     def to(self, device: torch.device):
         raise RuntimeError("Use colo_model_tensor_move install of call .to() on ShardedTensor")
diff --git a/tests/test_tensor/test_op.py b/tests/test_tensor/test_op.py
index 6cd45df44..71ce01dd6 100644
--- a/tests/test_tensor/test_op.py
+++ b/tests/test_tensor/test_op.py
@@ -60,8 +60,8 @@ def test_no_wrap_op():
     assert torch.sum(input=t) == torch.sum(input=t_ref)
 
 def test_lazy_init_tensor():
-    lazy_t = ColoTensor((2, 3), dtype=torch.float32, requires_grad=True)
-    assert lazy_t._torch_tensor == None
+    lazy_t = ColoTensor(2, 3, dtype=torch.float32, requires_grad=True)
+    assert lazy_t._torch_tensor.numel() == 0
     assert lazy_t.torch_tensor().numel() == 6
 
 def check_all():
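
Illustrative usage, not part of the patch: a minimal sketch of the lazy-init contract this diff establishes, assuming a colossalai build with the patch applied; the import paths are taken from the files touched above, and the assertions mirror test_lazy_init_tensor.

import torch
from colossalai.tensor import ColoTensor
from colossalai.zero.sharded_param.tensorful_state import StatefulTensor, TensorState

# A ColoTensor built without a payload keeps a zero-element placeholder
# instead of None; the real 2x3 buffer is only allocated on first access.
lazy_t = ColoTensor(2, 3, dtype=torch.float32, requires_grad=True)
assert lazy_t._torch_tensor.numel() == 0     # placeholder only, nothing materialized yet
assert lazy_t.torch_tensor().numel() == 6    # allocated lazily from the stored size/dtype

# StatefulTensor now routes payload access through torch_tensor();
# freeing it shrinks the payload back to the empty placeholder.
st = StatefulTensor(torch.zeros(4))
assert st.payload.numel() == 4
st.set_null()                                # transitions to TensorState.FREE
assert st.state == TensorState.FREE
assert st.is_null() and st.data_ptr() is None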