mirror of https://github.com/hpcaitech/ColossalAI
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
210 lines
6.8 KiB
210 lines
6.8 KiB
from enum import Enum
|
|
from typing import Optional
|
|
import torch
|
|
from typing import Union
|
|
|
|
from colossalai.gemini.gemini_context import GeminiMemoryManager
|
|
|
|
|
|
def sizeof_tensor(tensor: torch.Tensor):
|
|
return tensor.numel() * tensor.element_size()
|
|
|
|
|
|
class TensorState(Enum):
|
|
FREE = 0
|
|
HOLD = 1
|
|
HOLD_AFTER_FWD = 2
|
|
HOLD_AFTER_BWD = 3
|
|
COMPUTE = 4
|
|
|
|
|
|
class StatefulTensor(object):
|
|
"""A Structure stores a Torch Tensor and labeled states.
|
|
Inspired from the paper:
|
|
PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management
|
|
|
|
https://arxiv.org/abs/2108.05818
|
|
"""
|
|
# Global Stateful Tensor Manager
|
|
GST_MGR = GeminiMemoryManager(TensorState)
|
|
|
|
def __init__(self, maybe_tensor: Optional[torch.Tensor], state: Optional[TensorState] = TensorState.HOLD) -> None:
|
|
self._state = state
|
|
self._payload = None
|
|
self._payload_size = 0 # byte size of current payload
|
|
|
|
StatefulTensor.GST_MGR.register_new_instance()
|
|
|
|
if self._state == TensorState.FREE:
|
|
# when the state is free, payload should be None
|
|
assert maybe_tensor is None, f"payload has to None if state is {self._state}"
|
|
else:
|
|
# otherwise, payload should not be None
|
|
assert maybe_tensor is not None, f"payload can't be None if state is {self._state}"
|
|
self._payload = maybe_tensor
|
|
self._payload_size = sizeof_tensor(maybe_tensor)
|
|
self.__trans_state_update(TensorState.FREE, state)
|
|
|
|
def data_ptr(self):
|
|
if self._payload is None:
|
|
return 0 # if a tensor has no storage, 0 should be returned
|
|
return self._payload.data_ptr()
|
|
|
|
def set_null(self) -> None:
|
|
# notice that free stateful tensor do not need to become null again
|
|
if self.state != TensorState.FREE:
|
|
self.__trans_state_update(self.state, TensorState.FREE)
|
|
self.__release()
|
|
|
|
def is_null(self) -> bool:
|
|
if self.state == TensorState.FREE:
|
|
# check sanity here
|
|
assert self.payload is None
|
|
return True
|
|
return False
|
|
|
|
def trans_state(self, state: TensorState) -> None:
|
|
if self.state == TensorState.FREE:
|
|
# free stateful tensor can't change state
|
|
assert state == TensorState.FREE, "Free stateful tensor can't change to other states"
|
|
return
|
|
|
|
self.__trans_state_update(self.state, state)
|
|
|
|
if state == TensorState.FREE:
|
|
self.__release()
|
|
else:
|
|
self._state = state
|
|
|
|
def move_to(self, device: Union[torch.device, int]):
|
|
assert self.state is not TensorState.FREE, "Can't move free stateful tensor"
|
|
|
|
if not isinstance(device, torch.device):
|
|
to_device = torch.device('cuda', device)
|
|
else:
|
|
to_device = device
|
|
|
|
from_device_type = self.device.type
|
|
if from_device_type == to_device.type:
|
|
# from device == to device
|
|
return
|
|
|
|
# update manager's information
|
|
self.__trans_device_update(from_device_type, to_device.type)
|
|
self.payload.data = self.payload.data.to(to_device)
|
|
|
|
def payload_copy(self, tensor) -> None:
|
|
self._payload.view(-1).copy_(tensor.view(-1))
|
|
|
|
def payload_reset(self, tensor) -> None:
|
|
|
|
assert tensor is not None, "Can't reset None for stateful tensors, please use set_null() instead"
|
|
|
|
if self.payload is not None:
|
|
# release old payload
|
|
self.__trans_state_update(self.state, TensorState.FREE)
|
|
else:
|
|
# otherwise, set the state to HOLD for new payload
|
|
self._state = TensorState.HOLD
|
|
del self._payload
|
|
|
|
self._payload = tensor
|
|
self._payload_size = sizeof_tensor(tensor)
|
|
# record new payload
|
|
self.__trans_state_update(TensorState.FREE, self.state)
|
|
|
|
def payload_relay(self, rhs):
|
|
# relay the payload of rhs to current stateful tensor
|
|
# can't support null relay right now
|
|
assert not rhs.is_null()
|
|
|
|
# now this function only support stateful tensor that has zero-length payload
|
|
# because it doesn't require memory manager updating
|
|
# you can extend this function by yourself
|
|
assert self.payload_size == 0
|
|
|
|
self._payload = rhs.payload
|
|
self._payload_size = rhs.payload_size
|
|
self._state = TensorState.HOLD
|
|
self.__trans_state_update(rhs.state, TensorState.HOLD)
|
|
|
|
rhs.__release()
|
|
|
|
@property
|
|
def payload(self) -> Optional[torch.Tensor]:
|
|
return self._payload
|
|
|
|
@property
|
|
def payload_size(self) -> int:
|
|
return self._payload_size
|
|
|
|
@property
|
|
def state(self) -> TensorState:
|
|
return self._state
|
|
|
|
@property
|
|
def device(self) -> torch.device:
|
|
return self._payload.device
|
|
|
|
@property
|
|
def dtype(self) -> torch.dtype:
|
|
return self._payload.dtype
|
|
|
|
@property
|
|
def shape(self):
|
|
return self._payload.shape
|
|
|
|
def to(self, device: torch.device):
|
|
raise RuntimeError("Use move_to(...) instead of call .to() on StatefulTensor")
|
|
|
|
def to_(self, device: torch.device):
|
|
raise RuntimeError("Use move_to(...) instead of call .to_() on StatefulTensor")
|
|
|
|
def __release(self):
|
|
# release current payload
|
|
# shouldn't be visible to users
|
|
self._state = TensorState.FREE
|
|
self._payload = None
|
|
self._payload_size = 0
|
|
|
|
def __trans_state_update(self, from_state: TensorState, to_state: TensorState):
|
|
"""Update global manager when changing the state of a tensor
|
|
"""
|
|
manager = StatefulTensor.GST_MGR
|
|
size = self.payload_size
|
|
device_type = self.device.type
|
|
|
|
if from_state != TensorState.FREE:
|
|
manager.state_mem[device_type][from_state] -= size
|
|
else:
|
|
# when from_state is FREE, the tensor is new to manager
|
|
# we should add its memory
|
|
manager.total_mem[device_type] += size
|
|
|
|
if to_state != TensorState.FREE:
|
|
manager.state_mem[device_type][to_state] += size
|
|
else:
|
|
# when to_state is FREE, the tensor will be deleted soon
|
|
# we should sub its memory
|
|
manager.total_mem[device_type] -= size
|
|
|
|
def __trans_device_update(self, from_type: str, to_type: str):
|
|
"""Update global manager when changing the device of a tensor
|
|
"""
|
|
manager = StatefulTensor.GST_MGR
|
|
size = self.payload_size
|
|
state = self.state
|
|
|
|
# update aggregated information
|
|
manager.total_mem[from_type] -= size
|
|
manager.total_mem[to_type] += size
|
|
|
|
# update the information of each state
|
|
manager.state_mem[from_type][state] -= size
|
|
manager.state_mem[to_type][state] += size
|
|
|
|
def __del__(self):
|
|
self.set_null()
|
|
StatefulTensor.GST_MGR.delete_instance()
|
|
del self
|