mirror of https://github.com/hpcaitech/ColossalAI
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
101 lines
4.2 KiB
101 lines
4.2 KiB
import functools
|
|
import torch
|
|
import types
|
|
from colossalai.utils.cuda import get_current_device
|
|
from colossalai.gemini.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage
|
|
from colossalai.gemini.stateful_tensor import StatefulTensor, TensorState
|
|
from colossalai.gemini.tensor_placement_policy import TensorPlacementPolicy
|
|
from typing import List
|
|
from colossalai.logging import get_dist_logger
|
|
from time import time
|
|
|
|
|
|
class StatefulTensorMgr(object):
    """
    Stateful Tensor Manager, inspired from PatrickStar

    PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management
    https://arxiv.org/abs/2108.05818
    """

    def __init__(self, tensor_placement_policy: TensorPlacementPolicy) -> None:
        self._tensor_placement_policy: TensorPlacementPolicy = tensor_placement_policy
        self._stateful_tensor_list: List[StatefulTensor] = []

        # Tensors in the order they entered COMPUTE during the warmup iteration, and the
        # number of COMPUTE transitions seen so far in the current iteration (-1 = none).
        self._compute_list: List[StatefulTensor] = []
        self._compute_idx: int = -1

        # Per-iteration statistics, reset by finish_iter().
        self._cpu_gpu_move_volume = 0
        self._layout_time = 0
        self._evict_time = 0

        # True until the first finish_iter() call.
        self._warmup = True

        # Memo for _get_layout_info, keyed by (compute_idx, warmup). It lives on the
        # instance (rather than functools.lru_cache on the method, which is a class-level
        # cache holding `self` forever) so that cached StatefulTensor references are
        # released together with this manager.
        self._layout_info_cache = {}

    def register_stateful_tensor_list(self, tensor_list: List[StatefulTensor]) -> None:
        """Register the tensors managed by this manager and hook their state transitions.

        Each tensor's ``trans_state`` is replaced by a bound wrapper that first calls the
        original ``trans_state`` and then updates this manager's COMPUTE bookkeeping
        (see ``_trans_state``).

        Args:
            tensor_list: the stateful tensors to manage. The list object itself is stored,
                not copied.

        Raises:
            AssertionError: (when assertions are enabled) if tensors were already
                registered, or if any element is not a ``StatefulTensor``.
        """
        assert not self._stateful_tensor_list, "Can't register stateful tensors for manager twice"
        # Validate every element before mutating anything, so an invalid list leaves the
        # manager untouched instead of half-registered.
        for t in tensor_list:
            assert isinstance(t, StatefulTensor)
        self._stateful_tensor_list = tensor_list
        for t in self._stateful_tensor_list:
            # partial binds (manager hook, original method); MethodType binds the tensor,
            # so t.trans_state(state) -> self._trans_state(orig_trans_state, t, state).
            t.trans_state = types.MethodType(functools.partial(self._trans_state, t.trans_state), t)

    def start_iter(self):
        """Hook for the start of an iteration; currently a no-op."""
        pass

    def finish_iter(self):
        """This function must be called when each iteration finishes.

        Ends the warmup phase, resets the COMPUTE counter and the per-iteration
        statistics. The ``_compute_list`` recorded during warmup is kept.
        """
        self._warmup = False
        self._compute_idx = -1
        self._cpu_gpu_move_volume = 0
        self._layout_time = 0
        self._evict_time = 0

    def adjust_layout(self) -> None:
        """Adjust the layout of stateful tensors for the current compute step.

        Asks the placement policy to evict CUDA tensors in a HOLD state, making room for
        the COMPUTE-state tensors that are still on CPU. Then moves those tensors to the
        device returned by ``get_current_device()``. The move volume and the layout and
        eviction time are accumulated into the per-iteration statistics.
        """
        # Memory of COMPUTE-state tensors currently on CPU, i.e. what must be brought to
        # CUDA (units as tracked by GST_MGR.state_mem — presumably bytes, verify).
        cuda_demand = StatefulTensor.GST_MGR.state_mem['cpu'][TensorState.COMPUTE]

        start = time()
        move_to_cuda_tensor_list, hold_cuda_tensor_list = self._get_layout_info(self._compute_idx, self._warmup)
        self._layout_time += time() - start

        vol, evict_time = self._tensor_placement_policy.evict_tensors(hold_cuda_tensor_list,
                                                                      cuda_demand=cuda_demand,
                                                                      warmup=self._warmup,
                                                                      compute_list=self._compute_list,
                                                                      compute_idx=self._compute_idx)
        self._cpu_gpu_move_volume += vol
        self._evict_time += evict_time

        # Move the COMPUTE tensors that are on CPU to the current device.
        self._cpu_gpu_move_volume += cuda_demand
        for t in move_to_cuda_tensor_list:
            colo_model_data_tensor_move_inline(t, get_current_device())

    @property
    def cpu_gpu_move_volume(self):
        """Volume moved between CPU and GPU since the last ``finish_iter()``."""
        return self._cpu_gpu_move_volume

    def _trans_state(self, trans_state_func, stateful_tensor, state):
        """Hook installed on each registered tensor's ``trans_state``.

        Performs the original transition via ``trans_state_func(state)``. For a
        transition into COMPUTE it then advances the compute index and, during warmup,
        records ``stateful_tensor`` in ``_compute_list``.
        """
        trans_state_func(state)
        if state == TensorState.COMPUTE:
            self._compute_idx += 1
            if self._warmup:
                self._compute_list.append(stateful_tensor)

    def _get_layout_info(self, compute_idx: int, warmup: bool):
        """Split registered tensors into those to move to CUDA and those CUDA may evict.

        FREE tensors are skipped.

        Returns:
            ``(move_to_cuda_tensor_list, hold_cuda_tensor_list)``:
            - ``move_to_cuda_tensor_list`` holds the CPU tensors in COMPUTE state.
            - ``hold_cuda_tensor_list`` holds the CUDA tensors in HOLD, HOLD_AFTER_BWD or
              HOLD_AFTER_FWD state.

        Results are memoized per ``(compute_idx, warmup)`` on this instance. A repeated
        key returns the same list objects computed the first time, without rescanning
        tensor states.
        NOTE(review): this relies on the layout at a given compute index repeating
        across post-warmup iterations — confirm with the placement policy.

        Raises:
            RuntimeError: if a non-FREE tensor lives on a device other than CPU or CUDA.
        """
        key = (compute_idx, warmup)
        cached = self._layout_info_cache.get(key)
        if cached is not None:
            return cached

        move_to_cuda_tensor_list = []
        hold_cuda_tensor_list = []
        for tensor in self._stateful_tensor_list:
            if tensor.state == TensorState.FREE:
                continue

            if tensor.device.type == 'cuda':
                if tensor.state in [TensorState.HOLD, TensorState.HOLD_AFTER_BWD, TensorState.HOLD_AFTER_FWD]:
                    hold_cuda_tensor_list.append(tensor)
            elif tensor.device.type == 'cpu':
                if tensor.state == TensorState.COMPUTE:
                    move_to_cuda_tensor_list.append(tensor)
            else:
                raise RuntimeError(f"Unsupported device type '{tensor.device.type}' "
                                   f"for stateful tensor in state {tensor.state}")

        result = (move_to_cuda_tensor_list, hold_cuda_tensor_list)
        self._layout_info_cache[key] = result
        return result
|