import functools
import torch
import types
from colossalai.utils.cuda import get_current_device
from colossalai.gemini.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage
from colossalai.gemini.stateful_tensor import StatefulTensor, TensorState
from colossalai.gemini.tensor_placement_policy import TensorPlacementPolicy
from typing import Dict, List, Tuple
from colossalai.logging import get_dist_logger
from time import time


class StatefulTensorMgr(object):
    """
    Stateful Tensor Manager, inspired by PatrickStar.

    PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management
    https://arxiv.org/abs/2108.05818

    The manager wraps each registered tensor's ``trans_state`` so that every
    transition into ``TensorState.COMPUTE`` advances an internal compute index.
    ``adjust_layout`` then asks the placement policy to evict CUDA-resident
    tensors in a HOLD state, and moves CPU-resident tensors in the COMPUTE
    state onto the current CUDA device.
    """

    def __init__(self, tensor_placement_policy: TensorPlacementPolicy) -> None:
        self._tensor_placement_policy: TensorPlacementPolicy = tensor_placement_policy
        self._stateful_tensor_list: List[StatefulTensor] = []
        # Tensors in the order they entered COMPUTE during the warmup iteration;
        # passed to the placement policy as ``compute_list``.
        self._compute_list: List[StatefulTensor] = []
        # Incremented on every transition into COMPUTE; reset to -1 by finish_iter.
        self._compute_idx: int = -1
        self._cpu_gpu_move_volume = 0
        self._layout_time = 0
        self._evict_time = 0
        # True until the first finish_iter() call.
        self._warmup = True
        # Per-instance memo for _get_layout_info, keyed by (compute_idx, warmup).
        # A functools.lru_cache on the method would hold a strong reference to
        # every manager instance (and its tensors) for the process lifetime.
        self._layout_info_cache: Dict[Tuple[int, bool], Tuple[List[StatefulTensor], List[StatefulTensor]]] = {}

    def register_stateful_tensor_list(self, tensor_list: List[StatefulTensor]) -> None:
        """Register the tensors this manager tracks. May only be called once.

        Each tensor's ``trans_state`` method is replaced by a bound wrapper
        that calls the original method and then ``self._trans_state``
        bookkeeping.

        Args:
            tensor_list: the stateful tensors to manage. The list object is
                stored by reference, not copied.
        """
        assert self._stateful_tensor_list == [], "Can't register stateful tensors for manager twice"
        self._stateful_tensor_list = tensor_list
        for t in self._stateful_tensor_list:
            assert isinstance(t, StatefulTensor)
            # MethodType binds ``t`` as the first positional argument, so a call
            # t.trans_state(state) becomes self._trans_state(orig_trans_state, t, state).
            t.trans_state = types.MethodType(functools.partial(self._trans_state, t.trans_state), t)

    def start_iter(self):
        """Hook for the start of an iteration. Currently a no-op."""
        pass

    def finish_iter(self):
        """This function must be called when each iteration finishes.

        It ends the warmup phase and resets the per-iteration compute index,
        the move volume, and the timing counters. ``_compute_list`` is kept,
        so the warmup-iteration compute order remains available afterwards.
        """
        self._warmup = False
        self._compute_idx = -1
        self._cpu_gpu_move_volume = 0
        self._layout_time = 0
        self._evict_time = 0

    def adjust_layout(self) -> None:
        """Adjust the layout of the stateful tensors for the current compute step.

        First, the placement policy is asked to evict CUDA-resident HOLD tensors
        so that ``cuda_demand`` fits. Then every CPU-resident tensor in the
        COMPUTE state is moved to the current CUDA device. The eviction volume
        reported by the policy and ``cuda_demand`` are both added to
        ``cpu_gpu_move_volume``.
        """
        # Memory of COMPUTE-state tensors currently accounted on CPU
        # (presumably bytes, as tracked by StatefulTensor.GST_MGR -- confirm).
        cuda_demand = StatefulTensor.GST_MGR.state_mem['cpu'][TensorState.COMPUTE]
        start = time()
        move_to_cuda_tensor_list, hold_cuda_tensor_list = self._get_layout_info(self._compute_idx, self._warmup)
        self._layout_time += time() - start
        vol, evict_time = self._tensor_placement_policy.evict_tensors(hold_cuda_tensor_list,
                                                                      cuda_demand=cuda_demand,
                                                                      warmup=self._warmup,
                                                                      compute_list=self._compute_list,
                                                                      compute_idx=self._compute_idx)
        self._cpu_gpu_move_volume += vol
        self._evict_time += evict_time
        # move COMPUTE tensors to CUDA
        self._cpu_gpu_move_volume += cuda_demand
        for t in move_to_cuda_tensor_list:
            colo_model_data_tensor_move_inline(t, get_current_device())

    @property
    def cpu_gpu_move_volume(self):
        """Volume accumulated by adjust_layout since the last finish_iter()."""
        return self._cpu_gpu_move_volume

    def _trans_state(self, trans_state_func, stateful_tensor, state):
        """Wrapper installed over each tensor's ``trans_state``.

        The original transition is performed first. If the new state is
        COMPUTE, the compute index is advanced, and during warmup the tensor
        is also appended to ``_compute_list``.
        """
        trans_state_func(state)
        if state == TensorState.COMPUTE:
            self._compute_idx += 1
            if self._warmup:
                self._compute_list.append(stateful_tensor)

    def _get_layout_info(self, compute_idx: int, warmup: bool):
        """Split non-FREE tensors into (move_to_cuda, hold_on_cuda) lists.

        ``move_to_cuda`` holds the CPU tensors in the COMPUTE state. ``hold_on_cuda``
        holds the CUDA tensors in HOLD, HOLD_AFTER_BWD, or HOLD_AFTER_FWD.

        The result is memoized per ``(compute_idx, warmup)`` for this instance,
        and the same list objects are returned on every cache hit.
        NOTE(review): this assumes the layout at a given compute index does not
        change across post-warmup iterations -- confirm.

        Raises:
            RuntimeError: if a non-FREE tensor is on a device other than cuda/cpu.
        """
        key = (compute_idx, warmup)
        cached = self._layout_info_cache.get(key)
        if cached is not None:
            return cached

        hold_states = (TensorState.HOLD, TensorState.HOLD_AFTER_BWD, TensorState.HOLD_AFTER_FWD)
        move_to_cuda_tensor_list = []
        hold_cuda_tensor_list = []
        for tensor in self._stateful_tensor_list:
            if tensor.state == TensorState.FREE:
                continue
            device_type = tensor.device.type
            if device_type == 'cuda':
                if tensor.state in hold_states:
                    hold_cuda_tensor_list.append(tensor)
            elif device_type == 'cpu':
                if tensor.state == TensorState.COMPUTE:
                    move_to_cuda_tensor_list.append(tensor)
            else:
                raise RuntimeError(f"Unsupported device type '{device_type}' for stateful tensor")

        result = (move_to_cuda_tensor_list, hold_cuda_tensor_list)
        self._layout_info_cache[key] = result
        return result