2022-04-08 09:51:34 +00:00
|
|
|
import functools
|
2022-04-06 08:18:49 +00:00
|
|
|
import torch
|
2022-04-08 09:51:34 +00:00
|
|
|
import types
|
2022-04-06 08:18:49 +00:00
|
|
|
from colossalai.utils.cuda import get_current_device
|
|
|
|
from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
|
|
|
|
from colossalai.zero.sharded_param.tensorful_state import StatefulTensor, TensorState
|
|
|
|
from colossalai.zero.shard_utils.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage
|
|
|
|
from colossalai.utils.memory_utils.utils import colo_cuda_memory_capacity
|
2022-04-11 07:10:39 +00:00
|
|
|
from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
|
2022-04-08 09:51:34 +00:00
|
|
|
from typing import Dict, List
|
2022-04-06 08:18:49 +00:00
|
|
|
from colossalai.utils.memory_tracer import MemStatsCollector
|
2022-04-08 09:51:34 +00:00
|
|
|
from colossalai.logging import get_dist_logger
|
2022-04-06 08:18:49 +00:00
|
|
|
|
|
|
|
|
2022-04-08 09:51:34 +00:00
|
|
|
class StatefulTensorMgr(object):
    """
    Stateful Tensor Manager, inspired from PatrickStar.

    Decides the device placement (CPU or CUDA) of the stateful tensors it
    manages, using runtime memory statistics gathered by a MemStatsCollector.

    PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management
    https://arxiv.org/abs/2108.05818
    """

    def __init__(self, mem_stats_collector: MemStatsCollector, warmup_cuda_available_ratio: float = 0.2) -> None:
        """
        Args:
            mem_stats_collector (MemStatsCollector): a collector, usually owned by a
                Sharded Model. It contains the non-model memory footprint of a DNN model.
            warmup_cuda_available_ratio (float): fraction of CUDA memory reserved for
                non-model data during warmup iterations, when no sampled memory
                statistics are available yet. Defaults to 0.2.
        """
        self._stateful_tensor_list: List[StatefulTensor] = []
        self._mem_stats_collector = mem_stats_collector
        self._logger = get_dist_logger("StatefulTensorMgr")

        # Warmup lasts until reset() is first called; during warmup a fixed
        # CUDA-memory fraction stands in for the missing collector statistics.
        self._warmup = True
        self._warmup_cuda_available_ratio = warmup_cuda_available_ratio

        # Order in which tensors enter the COMPUTE state, recorded during the
        # warmup iteration and replayed afterwards to pick eviction victims.
        self._compute_list: List[StatefulTensor] = []
        self._compute_idx: int = -1

    def register_stateful_param(self, param: ShardedParamV2) -> None:
        """Track every payload tensor of ``param`` and hook its ``trans_state``
        so state transitions are observed by this manager."""
        assert isinstance(param, ShardedParamV2)
        for t in param.get_payload_tensors():
            assert isinstance(t, StatefulTensor)
            self._stateful_tensor_list.append(t)
            # Wrap the tensor's own trans_state so _trans_state runs on every
            # state change (used to build self._compute_list during warmup).
            t.trans_state = types.MethodType(functools.partial(self._trans_state, t.trans_state), t)

    def adjust_layout(self) -> None:
        """Adjust the layout of stateful tensors according to the information provided
        by mem_stats_collector, which should belong to a Sharded Model.

        Moves every CPU-resident COMPUTE tensor to CUDA, evicting HOLD-state CUDA
        tensors to CPU first when available CUDA model-data memory is insufficient.

        Raises:
            RuntimeError: if a managed tensor lives on an unexpected device, or
                (via evict_tensors) if not enough CUDA memory could be freed.
        """
        # Classify tensors: COMPUTE tensors on CPU must move to CUDA (their
        # total size is cuda_demand); HOLD-ish tensors already on CUDA are
        # eviction candidates.
        move_to_cuda_tensor_list = []
        cuda_demand = 0
        used_cuda_model_data = GLOBAL_MODEL_DATA_TRACER.cuda_usage
        hold_cuda_tensor_list = []
        for tensor in self._stateful_tensor_list:
            if tensor.state == TensorState.FREE:
                continue

            if tensor.device.type == 'cuda':
                if tensor.state in (TensorState.HOLD, TensorState.HOLD_AFTER_BWD, TensorState.HOLD_AFTER_FWD):
                    hold_cuda_tensor_list.append(tensor)
            elif tensor.device.type == 'cpu':
                if tensor.state == TensorState.COMPUTE:
                    move_to_cuda_tensor_list.append(tensor)
                    cuda_demand += colo_tensor_mem_usage(tensor.payload)[1]
            else:
                # Fixed: previously a bare `raise RuntimeError` with no message.
                raise RuntimeError(f"Unexpected device type '{tensor.device.type}' for a stateful tensor")
        cuda_capacity = colo_cuda_memory_capacity()

        if self._warmup:
            # We designate a part of CUDA memory for model data in warmup iterations.
            max_cuda_non_model_data_per_period = cuda_capacity * self._warmup_cuda_available_ratio
        else:
            # max non-model-data cuda memory consumption of this sampling moment and the next sampling moment.
            max_cuda_non_model_data_per_period = self._mem_stats_collector.max_non_model_data('cuda')

        total_cuda_model_data = cuda_capacity - max_cuda_non_model_data_per_period
        avail_cuda_model_data = total_cuda_model_data - used_cuda_model_data

        if avail_cuda_model_data < cuda_demand:
            # Free (cuda_demand - avail_cuda_model_data) bytes of HOLD tensors.
            self.evict_tensors(hold_cuda_tensor_list, cuda_demand - avail_cuda_model_data)
        # move COMPUTE tensors to CUDA
        for t in move_to_cuda_tensor_list:
            colo_model_data_tensor_move_inline(t, get_current_device())

    def reset(self):
        """This function must be called when each iteration finishes."""
        self._warmup = False
        self._compute_idx = -1

    def evict_tensors(self, hold_cuda_tensor_list, to_free_cuda_model_data) -> None:
        """Move candidate tensors from CUDA to CPU until at least
        ``to_free_cuda_model_data`` bytes of CUDA model data are freed.

        After warmup, candidates are evicted farthest-next-use first (largest
        index in self._compute_list), in the spirit of the Belady/OPT policy.

        Args:
            hold_cuda_tensor_list: CUDA-resident tensors in a HOLD-like state.
            to_free_cuda_model_data: number of bytes that must be freed.

        Raises:
            RuntimeError: if the candidates cannot cover the requested amount.
        """
        freed_cuda_model_data = 0
        to_free_tensor_list = hold_cuda_tensor_list
        if not self._warmup:
            # For each candidate, find the index of its next COMPUTE use after
            # the current position; tensors never used again keep the sentinel
            # len(self._compute_list) and are evicted first.
            next_compute_idx: Dict[StatefulTensor, int] = {t: len(self._compute_list) for t in hold_cuda_tensor_list}
            for i in range(len(self._compute_list) - 1, self._compute_idx, -1):
                if self._compute_list[i] in next_compute_idx:
                    next_compute_idx[self._compute_list[i]] = i
            # Fixed: keep the sorted pairs in a new variable instead of
            # rebinding next_compute_idx (annotated as a Dict) to a list.
            sorted_by_next_use = sorted(next_compute_idx.items(), key=lambda pair: pair[1], reverse=True)
            to_free_tensor_list = [t for t, _ in sorted_by_next_use]
        for t in to_free_tensor_list:
            # Fixed: use >= so eviction stops once the target is met; the old
            # `>` evicted one extra tensor when the freed amount matched exactly.
            if freed_cuda_model_data >= to_free_cuda_model_data:
                break
            freed_cuda_model_data += colo_tensor_mem_usage(t)[0]
            colo_model_data_tensor_move_inline(t, torch.device('cpu'))
        if freed_cuda_model_data < to_free_cuda_model_data:
            raise RuntimeError(
                f"Adjust layout failed! No enough CUDA memory! Need {to_free_cuda_model_data}, freed {freed_cuda_model_data}"
            )

    def _trans_state(self, trans_state_func, stateful_tensor, state):
        """Wrapper installed around each tensor's trans_state: forwards the
        call, then records COMPUTE transitions to build the compute order
        (appended to self._compute_list only during warmup)."""
        trans_state_func(state)
        if state == TensorState.COMPUTE:
            self._compute_idx += 1
            if self._warmup:
                self._compute_list.append(stateful_tensor)