# mirror of https://github.com/hpcaitech/ColossalAI
from enum import Enum
from typing import Optional, Union

import torch

from colossalai.gemini.gemini_context import GeminiMemoryManager


def sizeof_tensor(tensor: torch.Tensor):
    return tensor.numel() * tensor.element_size()


class TensorState(Enum):
    FREE = 0
    HOLD = 1
    HOLD_AFTER_FWD = 2
    HOLD_AFTER_BWD = 3
    COMPUTE = 4

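# NOTE (inferred from how the states are used below, not from upstream docs):
# a StatefulTensor is created in HOLD by default, switched to COMPUTE while an
# operator is actually reading or writing its payload, parked in
# HOLD_AFTER_FWD / HOLD_AFTER_BWD between the forward and backward passes, and
# released to FREE once the payload is dropped. A FREE tensor keeps no payload
# and trans_state() refuses to move it to another state; only payload_reset()
# can revive it with a new payload.
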

class StatefulTensor(object):
    """A structure that stores a torch Tensor together with its labeled state.
    Inspired by the paper:
    PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management

    https://arxiv.org/abs/2108.05818

    A minimal usage sketch is provided at the bottom of this file.
    """
    # Global Stateful Tensor Manager
    GST_MGR = GeminiMemoryManager(TensorState)

    def __init__(self, maybe_tensor: Optional[torch.Tensor], state: TensorState = TensorState.HOLD) -> None:
        self._state = state
        self._payload = None
        self._payload_size = 0    # byte size of the current payload

        StatefulTensor.GST_MGR.register_new_instance()

        if self._state == TensorState.FREE:
            # when the state is FREE, the payload must be None
            assert maybe_tensor is None, f"payload has to be None if state is {self._state}"
        else:
            # otherwise, the payload must not be None
            assert maybe_tensor is not None, f"payload can't be None if state is {self._state}"
            self._payload = maybe_tensor
            self._payload_size = sizeof_tensor(maybe_tensor)
            self.__trans_state_update(TensorState.FREE, state)

    def data_ptr(self):
        if self._payload is None:
            return 0    # a tensor without storage reports a data pointer of 0
        return self._payload.data_ptr()

    def set_null(self) -> None:
        # a stateful tensor that is already FREE does not need to become null again
        if self.state != TensorState.FREE:
            self.__trans_state_update(self.state, TensorState.FREE)
            self.__release()

    def is_null(self) -> bool:
        if self.state == TensorState.FREE:
            # sanity check: a FREE tensor must not hold a payload
            assert self.payload is None
            return True
        return False

    def trans_state(self, state: TensorState) -> None:
        if self.state == TensorState.FREE:
            # a free stateful tensor can't change its state
            assert state == TensorState.FREE, "Free stateful tensor can't change to other states"
            return

        self.__trans_state_update(self.state, state)

        if state == TensorState.FREE:
            self.__release()
        else:
            self._state = state

    def move_to(self, device: Union[torch.device, int]):
        assert self.state is not TensorState.FREE, "Can't move a free stateful tensor"

        if not isinstance(device, torch.device):
            to_device = torch.device('cuda', device)
        else:
            to_device = device

        from_device_type = self.device.type
        if from_device_type == to_device.type:
            # same device type: the manager tracks memory per device type, so there is nothing to do
            return

        # update the manager's information
        self.__trans_device_update(from_device_type, to_device.type)
        self.payload.data = self.payload.data.to(to_device)

    def payload_copy(self, tensor: torch.Tensor) -> None:
        self._payload.view(-1).copy_(tensor.view(-1))

    def payload_reset(self, tensor: torch.Tensor) -> None:
        assert tensor is not None, "Can't reset a stateful tensor to None; use set_null() instead"

        if self.payload is not None:
            # release the old payload
            self.__trans_state_update(self.state, TensorState.FREE)
        else:
            # otherwise, set the state to HOLD for the new payload
            self._state = TensorState.HOLD
        del self._payload

        self._payload = tensor
        self._payload_size = sizeof_tensor(tensor)
        # record the new payload
        self.__trans_state_update(TensorState.FREE, self.state)

    def payload_relay(self, rhs):
        # relay the payload of rhs to the current stateful tensor
        # null relay is not supported right now
        assert not rhs.is_null()

        # for now this function only supports stateful tensors with a zero-length payload,
        # because that case doesn't require updating the memory manager;
        # extend it yourself if you need more
        assert self.payload_size == 0

        self._payload = rhs.payload
        self._payload_size = rhs.payload_size
        self._state = TensorState.HOLD
        self.__trans_state_update(rhs.state, TensorState.HOLD)

        rhs.__release()

    @property
    def payload(self) -> Optional[torch.Tensor]:
        return self._payload

    @property
    def payload_size(self) -> int:
        return self._payload_size

    @property
    def state(self) -> TensorState:
        return self._state

    @property
    def device(self) -> torch.device:
        return self._payload.device

    @property
    def dtype(self) -> torch.dtype:
        return self._payload.dtype

    @property
    def shape(self):
        return self._payload.shape

    def to(self, device: torch.device):
        raise RuntimeError("Use move_to(...) instead of calling .to() on a StatefulTensor")

    def to_(self, device: torch.device):
        raise RuntimeError("Use move_to(...) instead of calling .to_() on a StatefulTensor")

    def __release(self):
        # release the current payload
        # shouldn't be visible to users
        self._state = TensorState.FREE
        self._payload = None
        self._payload_size = 0

    def __trans_state_update(self, from_state: TensorState, to_state: TensorState):
        """Update the global manager when changing the state of a tensor."""
        manager = StatefulTensor.GST_MGR
        size = self.payload_size
        device_type = self.device.type

        if from_state != TensorState.FREE:
            manager.state_mem[device_type][from_state] -= size
        else:
            # when from_state is FREE, the tensor is new to the manager,
            # so we add its memory to the total
            manager.total_mem[device_type] += size

        if to_state != TensorState.FREE:
            manager.state_mem[device_type][to_state] += size
        else:
            # when to_state is FREE, the tensor will be deleted soon,
            # so we subtract its memory from the total
            manager.total_mem[device_type] -= size

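    # Worked example of the bookkeeping above (illustrative; it assumes the
    # manager's counters start at zero): for a 4 MB payload on 'cuda', a
    # HOLD -> COMPUTE transition moves 4 MB from state_mem['cuda'][HOLD] to
    # state_mem['cuda'][COMPUTE] and leaves total_mem['cuda'] unchanged, while
    # a HOLD -> FREE transition removes 4 MB from both state_mem['cuda'][HOLD]
    # and total_mem['cuda'].
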
    def __trans_device_update(self, from_type: str, to_type: str):
        """Update the global manager when changing the device of a tensor."""
        manager = StatefulTensor.GST_MGR
        size = self.payload_size
        state = self.state

        # update the aggregated information
        manager.total_mem[from_type] -= size
        manager.total_mem[to_type] += size

        # update the information of each state
        manager.state_mem[from_type][state] -= size
        manager.state_mem[to_type][state] += size

    def __del__(self):
        self.set_null()
        StatefulTensor.GST_MGR.delete_instance()
        del self
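

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the upstream module). It shows a typical
# call sequence on a CPU tensor and assumes the GeminiMemoryManager imported
# above exposes the `total_mem` / `state_mem` counters exactly as they are used
# by StatefulTensor, starting from zero. Run the file directly as a smoke test.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # wrap a CPU tensor; the global manager starts tracking its memory
    t = StatefulTensor(torch.zeros(1024, dtype=torch.float32), TensorState.HOLD)
    print(StatefulTensor.GST_MGR.total_mem['cpu'])    # expected: 4096 (1024 floats * 4 bytes)

    # mark the payload as in use by the current operator, then park it again
    t.trans_state(TensorState.COMPUTE)
    t.trans_state(TensorState.HOLD_AFTER_FWD)

    # swap in a larger payload; the bookkeeping follows the new size
    t.payload_reset(torch.zeros(2048, dtype=torch.float32))
    print(t.payload_size)    # expected: 8192

    # on a machine with CUDA, t.move_to(0) would migrate both the payload and
    # the manager's counters from 'cpu' to 'cuda'

    # drop the payload; the tensor becomes FREE and the counters fall back
    t.set_null()
    print(t.is_null())    # expected: True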