You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
ColossalAI/colossalai/zero/sharded_param/tensorful_state.py

86 lines
2.5 KiB

from enum import Enum
from typing import Optional
import torch
from colossalai.tensor import ColoTensor
class TensorState(Enum):
    """Lifecycle labels attached to a StatefulTensor's payload."""

    # No payload: StatefulTensor drops its torch tensor when it enters this state.
    FREE = 0
    # Default state of a new StatefulTensor; also the state set by reset_payload().
    HOLD = 1
    # NOTE(review): presumably "payload kept after the forward pass" -- confirm with callers.
    HOLD_AFTER_FWD = 2
    # NOTE(review): presumably "payload kept after the backward pass" -- confirm with callers.
    HOLD_AFTER_BWD = 3
    # NOTE(review): presumably "payload currently used by a computation" -- confirm with callers.
    COMPUTE = 4
class StatefulTensor(ColoTensor):
    """A structure that stores a torch tensor (the payload) together with a labeled state.

    Inspired from the paper:
    PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management
    https://arxiv.org/abs/2108.05818

    The payload is read through ``torch_tensor()`` and released through
    ``del_torch_tensor()``; neither is defined in this class, so both are
    expected to come from the ``ColoTensor`` base class.
    """

    def __init__(self, tensor: Optional[torch.Tensor], state: Optional[TensorState] = TensorState.HOLD) -> None:
        """Wrap ``tensor`` and tag it with ``state``.

        Args:
            tensor: The payload to wrap. When ``None`` the base class is
                initialised with size ``0`` instead.
            state: Initial state; defaults to ``TensorState.HOLD``.

        Raises:
            AssertionError: If ``state`` is ``TensorState.FREE`` while the
                payload still holds elements.
        """
        if tensor is not None:
            # ``tensor.pin_memory`` is a bound method (always truthy), not a flag;
            # ``is_pinned()`` reports whether the tensor really lives in pinned memory.
            super().__init__(tensor.size(),
                             dtype=tensor.dtype,
                             requires_grad=tensor.requires_grad,
                             pin_memory=tensor.is_pinned(),
                             torch_tensor=tensor)
        else:
            super().__init__(0)

        self._state = state
        if self._state == TensorState.FREE:
            assert self.torch_tensor().numel() == 0, f"payload has to be empty if state is {self._state}"

    def data_ptr(self):
        """Return the payload's ``data_ptr()``, or ``None`` when the payload has no elements."""
        if self.torch_tensor().numel() == 0:
            return None
        return self.torch_tensor().data_ptr()

    @property
    def state(self) -> TensorState:
        """The current state label."""
        return self._state

    def set_null(self) -> None:
        """Mark the tensor as ``FREE`` and release its payload."""
        self._state = TensorState.FREE
        self.del_torch_tensor()

    def is_null(self) -> bool:
        """Return ``True`` when the state is ``FREE``.

        Raises:
            AssertionError: If the state is ``FREE`` but the payload still
                holds elements.
        """
        if self._state == TensorState.FREE:
            assert self.torch_tensor().numel() == 0
            return True
        return False

    def trans_state(self, state: TensorState) -> None:
        """Switch to ``state``; entering ``FREE`` also releases the payload."""
        self._state = state
        if state == TensorState.FREE:
            self.del_torch_tensor()

    @property
    def payload(self) -> Optional[torch.Tensor]:
        """The wrapped torch tensor, as returned by ``torch_tensor()``."""
        return self.torch_tensor()

    def copy_payload(self, tensor) -> None:
        """Copy the elements of ``tensor`` into the existing payload in place.

        Both tensors are flattened with ``view(-1)`` before the copy, so the
        payload keeps its own storage and shape.
        """
        # ``torch_tensor`` is a method: it has to be called to obtain the tensor.
        self.torch_tensor().view(-1).copy_(tensor.view(-1))

    def reset_payload(self, tensor) -> None:
        """Replace the payload with ``tensor`` and move to the ``HOLD`` state."""
        # NOTE(review): writes the ``_torch_tensor`` attribute directly; presumably
        # this is the storage slot read by ``torch_tensor()`` in ColoTensor -- confirm.
        self._torch_tensor = tensor
        self.trans_state(TensorState.HOLD)

    @property
    def device(self) -> torch.device:
        """Device of the payload."""
        return self.torch_tensor().device

    @property
    def dtype(self) -> torch.dtype:
        """Data type of the payload."""
        return self.torch_tensor().dtype

    @property
    def shape(self):
        """Shape of the payload."""
        return self.torch_tensor().shape

    def to(self, device: torch.device):
        """Unsupported: always raises ``RuntimeError``."""
        raise RuntimeError(f"Use colo_model_tensor_move instead of calling .to() on {type(self).__name__}")

    def to_(self, device: torch.device):
        """Unsupported: always raises ``RuntimeError``."""
        raise RuntimeError(f"Use colo_model_tensor_move instead of calling .to_() on {type(self).__name__}")