ColossalAI/colossalai/zero/sharded_param/tensorful_state.py

81 lines
2.2 KiB
Python
Raw Normal View History

2022-03-30 05:51:37 +00:00
from enum import Enum
from typing import Optional
2022-03-30 05:51:37 +00:00
import torch
class TensorState(Enum):
    """Lifecycle states that a StatefulTensor's payload can be labeled with."""

    # No payload is held (the storage has been released).
    FREE = 0
    # Payload is held, not currently in computation.
    HOLD = 1
    # Held after the forward pass — TODO(review) confirm exact lifecycle semantics.
    HOLD_AFTER_FWD = 2
    # Held after the backward pass — TODO(review) confirm exact lifecycle semantics.
    HOLD_AFTER_BWD = 3
    # Payload is actively being used in computation.
    COMPUTE = 4
2022-03-30 05:51:37 +00:00
2022-04-22 04:12:35 +00:00
class StatefulTensor(object):
    """A structure that stores a torch Tensor together with a labeled state.

    Inspired from the paper:
    PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management
    https://arxiv.org/abs/2108.05818
    """

    def __init__(self, tensor: Optional[torch.Tensor], state: Optional[TensorState] = TensorState.HOLD) -> None:
        """Wrap ``tensor`` with an initial ``state``.

        A ``None`` state is normalized to ``TensorState.HOLD`` so the
        ``Optional`` annotation in the signature is actually honored;
        previously a passed-in ``None`` was stored verbatim and would
        break every later state comparison.
        """
        self._state = TensorState.HOLD if state is None else state
        self._payload = tensor
        if self._state == TensorState.FREE:
            # Invariant: a FREE tensor must not keep any storage alive.
            assert self._payload is None, f"payload has to be None if state is {self._state}"

    def data_ptr(self) -> Optional[int]:
        """Return the payload's data pointer, or ``None`` when no payload is held."""
        if self._payload is None:
            return None
        return self._payload.data_ptr()

    @property
    def state(self) -> TensorState:
        """Current labeled state of this tensor."""
        return self._state

    def set_null(self) -> None:
        """Drop the payload and mark this tensor FREE."""
        self._state = TensorState.FREE
        self._payload = None

    def is_null(self) -> bool:
        """Return ``True`` iff this tensor is in the FREE state."""
        if self._state == TensorState.FREE:
            # Invariant: FREE implies no payload.
            assert self._payload is None
            return True
        return False

    def trans_state(self, state: TensorState) -> None:
        """Transition to ``state``; entering FREE releases the payload."""
        self._state = state
        if state == TensorState.FREE:
            self._payload = None

    @property
    def payload(self) -> Optional[torch.Tensor]:
        """The wrapped tensor, or ``None`` when the state is FREE."""
        return self._payload

    def copy_payload(self, tensor: torch.Tensor) -> None:
        """Copy ``tensor``'s data in place into the existing payload.

        Both tensors are flattened first, so only the element counts need
        to agree, not the shapes.
        """
        # Guard that was missing before: copying into a FREE tensor would
        # otherwise raise a bare AttributeError on None.
        assert self._payload is not None, "cannot copy into a StatefulTensor with no payload"
        self._payload.view(-1).copy_(tensor.view(-1))

    def reset_payload(self, tensor: Optional[torch.Tensor]) -> None:
        """Replace the payload with ``tensor`` and reset the state to HOLD."""
        del self._payload
        self._payload = tensor
        self.trans_state(TensorState.HOLD)

    @property
    def device(self) -> torch.device:
        # Assumes a payload is held; raises AttributeError when FREE.
        return self._payload.device

    @property
    def dtype(self) -> torch.dtype:
        # Assumes a payload is held; raises AttributeError when FREE.
        return self._payload.dtype

    @property
    def shape(self) -> torch.Size:
        # Assumes a payload is held; raises AttributeError when FREE.
        return self._payload.shape

    def to(self, device: torch.device):
        """Intentionally disabled: movement must go through colo_model_tensor_move."""
        raise RuntimeError("Use colo_model_tensor_move instead of calling .to() on ShardedTensor")

    def to_(self, device: torch.device):
        """Intentionally disabled: movement must go through colo_model_tensor_move."""
        raise RuntimeError("Use colo_model_tensor_move instead of calling .to_() on ShardedTensor")