import functools
import types
from time import time
from typing import List

import torch

from colossalai.gemini.stateful_tensor import StatefulTensor, TensorState
from colossalai.gemini.tensor_placement_policy import TensorPlacementPolicy
from colossalai.gemini.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage
from colossalai.logging import get_dist_logger
from colossalai.utils.cuda import get_current_device


class StatefulTensorMgr(object):
    """
    Stateful Tensor Manager, inspired by PatrickStar.

    PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management
    https://arxiv.org/abs/2108.05818
    """

    def __init__(self, tensor_placement_policy: TensorPlacementPolicy) -> None:
        self._tensor_placement_policy: TensorPlacementPolicy = tensor_placement_policy
        self._stateful_tensor_list: List[StatefulTensor] = []

        # tensors recorded in the order they enter state COMPUTE during warmup
        self._compute_list: List[StatefulTensor] = []
        self._compute_idx: int = -1

        # per-iteration statistics, reset by finish_iter()
        self._cpu_gpu_move_volume = 0
        self._layout_time = 0
        self._evict_time = 0
        self._warmup = True

    def register_stateful_tensor_list(self, tensor_list: List[StatefulTensor]) -> None:
        assert self._stateful_tensor_list == [], "Can't register stateful tensors for manager twice"
        self._stateful_tensor_list = tensor_list
        for t in self._stateful_tensor_list:
            assert isinstance(t, StatefulTensor)
            t.trans_state = types.MethodType(functools.partial(self._trans_state, t.trans_state), t)
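
    # NOTE: the `types.MethodType(functools.partial(...))` assignment above wraps
    # each tensor's bound `trans_state` method so that every state transition is
    # also reported to this manager via `_trans_state`. A dependency-free sketch
    # of this wrapping pattern appears at the bottom of this file.
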
    def start_iter(self):
        pass

    def finish_iter(self):
        """This function must be called at the end of each iteration."""
        self._warmup = False
        self._compute_idx = -1
        self._cpu_gpu_move_volume = 0
        self._layout_time = 0
        self._evict_time = 0

    def adjust_layout(self) -> None:
        """Adjust the layout of stateful tensors according to the information
        provided by the mem_stats_collector, which should belong to a Sharded Model.
        """
        # CUDA memory demanded by tensors that are on CPU but now in state COMPUTE
        cuda_demand = StatefulTensor.GST_MGR.state_mem['cpu'][TensorState.COMPUTE]
        start = time()
        move_to_cuda_tensor_list, hold_cuda_tensor_list = self._get_layout_info(self._compute_idx, self._warmup)
        self._layout_time += time() - start
        vol, evict_time = self._tensor_placement_policy.evict_tensors(hold_cuda_tensor_list,
                                                                      cuda_demand=cuda_demand,
                                                                      warmup=self._warmup,
                                                                      compute_list=self._compute_list,
                                                                      compute_idx=self._compute_idx)
        self._cpu_gpu_move_volume += vol
        self._evict_time += evict_time
        # move COMPUTE tensors to CUDA
        self._cpu_gpu_move_volume += cuda_demand
        for t in move_to_cuda_tensor_list:
            colo_model_data_tensor_move_inline(t, get_current_device())

    @property
    def cpu_gpu_move_volume(self):
        return self._cpu_gpu_move_volume

    def _trans_state(self, trans_state_func, stateful_tensor, state):
        # perform the original state transition, then track compute order
        trans_state_func(state)
        if state == TensorState.COMPUTE:
            self._compute_idx += 1
            if self._warmup:
                # record the compute order during the warmup iteration
                self._compute_list.append(stateful_tensor)

    @functools.lru_cache(maxsize=None)
    def _get_layout_info(self, compute_idx: int, warmup: bool):
        # memoized on (self, compute_idx, warmup): after warmup the compute
        # order is fixed, so the classification for a given step can be reused
        move_to_cuda_tensor_list = []
        hold_cuda_tensor_list = []
        for tensor in self._stateful_tensor_list:
            if tensor.state == TensorState.FREE:
                continue

            if tensor.device.type == 'cuda':
                if tensor.state in [TensorState.HOLD, TensorState.HOLD_AFTER_BWD, TensorState.HOLD_AFTER_FWD]:
                    hold_cuda_tensor_list.append(tensor)
            elif tensor.device.type == 'cpu':
                if tensor.state == TensorState.COMPUTE:
                    move_to_cuda_tensor_list.append(tensor)
            else:
                raise RuntimeError(f"Unexpected device type: {tensor.device.type}")
        return move_to_cuda_tensor_list, hold_cuda_tensor_list
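

# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the original file: the method-wrapping
# pattern used in `register_stateful_tensor_list`, reduced to a dependency-free
# example. `_DemoTensor` and `_DemoManager` are hypothetical stand-ins for
# StatefulTensor and StatefulTensorMgr, kept only to show the mechanics.
if __name__ == "__main__":

    class _DemoTensor:
        def __init__(self):
            self.state = "HOLD"

        def trans_state(self, state):
            self.state = state

    class _DemoManager:
        def __init__(self):
            self.observed = []

        def register(self, tensor):
            # `functools.partial` freezes the tensor's *original* bound
            # `trans_state`; `types.MethodType` rebinds the wrapper as an
            # instance method, so `tensor.trans_state(state)` now calls
            # `self._trans_state(original_func, tensor, state)`.
            tensor.trans_state = types.MethodType(
                functools.partial(self._trans_state, tensor.trans_state), tensor)

        def _trans_state(self, trans_state_func, tensor, state):
            trans_state_func(state)      # the original transition still runs
            self.observed.append(state)  # ...and the manager observes it

    mgr, t = _DemoManager(), _DemoTensor()
    mgr.register(t)
    t.trans_state("COMPUTE")
    assert t.state == "COMPUTE" and mgr.observed == ["COMPUTE"]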