mirror of https://github.com/hpcaitech/ColossalAI
[zero] initialize a stateful tensor manager (#614)
parent cc236916c6
commit 59bf2dc590
@@ -1,3 +1,4 @@
 from .async_memtracer import AsyncMemoryMonitor
+from .memstats_collector import MemStatsCollector
 
-__all__ = ['AsyncMemoryMonitor']
+__all__ = ['AsyncMemoryMonitor', 'MemStatsCollector']
@@ -11,15 +11,21 @@ class SamplingCounter:
 
     def __init__(self) -> None:
         self._samplint_cnt = 0
+        self._max_sampling_cnt = None
 
     def advance(self):
         self._samplint_cnt += 1
 
+    def next(self):
+        assert self._max_sampling_cnt is not None
+        return (self._samplint_cnt + 1) % self._max_sampling_cnt
+
     @property
     def sampling_cnt(self):
         return self._samplint_cnt
 
     def reset(self):
+        self._max_sampling_cnt = self._samplint_cnt
         self._samplint_cnt = 0
 
 
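Not part of the commit: a minimal sketch of the intended SamplingCounter life cycle, assuming the class shown above (the import path follows the file this hunk patches):

    from colossalai.utils.memory_tracer.memstats_collector import SamplingCounter

    cnter = SamplingCounter()
    for _ in range(4):          # warmup iteration: one advance() per sampling moment
        cnter.advance()
    cnter.reset()               # records the period (4) and rewinds to moment 0

    assert cnter.next() == 1    # (0 + 1) % 4: peek at the upcoming moment
    for _ in range(3):
        cnter.advance()
    assert cnter.next() == 0    # (3 + 1) % 4 wraps back to the first moment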
@@ -56,7 +62,7 @@ class MemStatsCollector:
         else:
             raise TypeError
 
-    def model_data_cuda_list(self, device_type: str, unit: str = 'B') -> List[int]:
+    def model_data_list(self, device_type: str, unit: str = 'B') -> List[int]:
         if unit == 'GB':
             scale = 1e9
         elif unit == 'MB':
@@ -75,7 +81,7 @@ class MemStatsCollector:
         else:
             raise TypeError
 
-    def non_model_data_cuda_list(self, device_type: str, unit: str = 'B') -> List[int]:
+    def non_model_data_list(self, device_type: str, unit: str = 'B') -> List[int]:
         """Non model data stats
         """
         if unit == 'GB':
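After the rename, the device is selected by argument instead of being baked into the method name. A hedged call sketch, assuming a collector instance that has finished its warmup iteration:

    model_mb = collector.model_data_list('cuda', unit='MB')    # per-moment model data, in MB
    non_model_b = collector.non_model_data_list('cpu')         # unit defaults to bytes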
@@ -96,6 +102,14 @@ class MemStatsCollector:
         else:
             raise TypeError
 
+    def current_non_model_data(self, device_type: str) -> int:
+        """Get the non-model data of the current sampling moment.
+        """
+        return self.non_model_data_list(device_type)[self._sampling_cnter.sampling_cnt]
+
+    def next_non_model_data(self, device_type: str):
+        return self.non_model_data_list(device_type)[self._sampling_cnter.next()]
+
     @property
     def sampling_time(self):
         return [t - self._sampling_time[0] for t in self._sampling_time]
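The two new accessors pair the collector with SamplingCounter: current_non_model_data reads the footprint at the moment now being replayed, while next_non_model_data peeks one moment ahead, wrapping at the period boundary. A hedged call sketch; the collector instance and its warmup iteration are assumed:

    now = collector.current_non_model_data('cuda')
    soon = collector.next_non_model_data('cuda')
    worst_case = max(now, soon)    # the bound adjust_layout budgets against below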
@@ -0,0 +1,69 @@
+import torch
+from colossalai.context.singleton_meta import SingletonMeta
+from colossalai.utils.cuda import get_current_device
+from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
+from colossalai.zero.sharded_param.tensorful_state import StatefulTensor, TensorState
+from colossalai.zero.shard_utils.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage
+from colossalai.utils.memory_utils.utils import colo_cuda_memory_capacity
+from typing import Set
+from colossalai.utils.memory_tracer import MemStatsCollector
+
+
+class StatefulTensorMgr(SingletonMeta):
+    _stateful_tensor_list: Set[StatefulTensor] = set()
+
+    def register_param(self, param: ShardedParamV2) -> None:
+        for t in param.get_payload_tensors():
+            assert isinstance(t, StatefulTensor)
+            self._stateful_tensor_list.add(t)
+
+    def evict_tensors(self) -> None:
+        pass
+
+    def adjust_layout(self, mem_stats_collector: MemStatsCollector) -> None:
+        """Adjust the layout of stateful tensors according to the information provided
+        by mem_stats_collector, which should belong to a Sharded Model.
+
+        Args:
+            mem_stats_collector (MemStatsCollector): a collector, usually owned by a Sharded Model.
+                It contains the non-model memory footprint of a DNN model.
+        """
+        # find stateful tensors in state COMPUTE
+        move_to_cuda_tensor_list = []
+        cuda_demand = 0
+        used_cuda_model_data = 0
+        hold_cuda_tensor_list = []
+        for tensor in self._stateful_tensor_list:
+            if tensor.state == TensorState.FREE:
+                continue
+
+            if tensor.device.type == 'cuda':
+                used_cuda_model_data += colo_tensor_mem_usage(tensor.payload)[0]
+                if tensor.state in [TensorState.HOLD, TensorState.HOLD_AFTER_BWD, TensorState.HOLD_AFTER_FWD]:
+                    hold_cuda_tensor_list.append(tensor)
+            else:
+                if tensor.state == TensorState.COMPUTE:
+                    move_to_cuda_tensor_list.append(tensor)
+                    cuda_demand += colo_tensor_mem_usage(tensor.payload)[0]
+
+        # max non-model-data CUDA memory consumption over this sampling moment and the next one
+        max_cuda_non_model_data_per_period = max(mem_stats_collector.current_non_model_data('cuda'),
+                                                 mem_stats_collector.next_non_model_data('cuda'))
+        cuda_capacity = colo_cuda_memory_capacity()
+        cuda_model_data_period = cuda_capacity - max_cuda_non_model_data_per_period
+        if cuda_model_data_period < used_cuda_model_data + cuda_demand:
+            # need to move at least (used_cuda_model_data + cuda_demand - cuda_model_data_period)
+            # bytes of HOLD tensors to CPU. Here we use a naive eviction strategy.
+            acc_size = 0
+            for t in hold_cuda_tensor_list:
+                if acc_size > cuda_demand:
+                    break
+                colo_model_data_tensor_move_inline(t, torch.device('cpu'))
+                t_size = colo_tensor_mem_usage(t.payload)[0]
+                acc_size += t_size
+            if acc_size < cuda_demand:
+                raise RuntimeError("Adjust layout failed! Not enough CUDA memory!")
+
+        # move COMPUTE tensors to CUDA
+        for t in move_to_cuda_tensor_list:
+            colo_model_data_tensor_move_inline(t, get_current_device())
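To make the capacity arithmetic in adjust_layout concrete, a worked sketch with invented numbers (all figures hypothetical, in bytes):

    cuda_capacity = 16e9                  # from colo_cuda_memory_capacity()
    max_non_model = max(6e9, 7e9)         # current vs. next non-model footprint
    cuda_model_data_period = cuda_capacity - max_non_model    # 9e9 left for model data

    used_cuda_model_data = 8e9            # HOLD tensors already resident on CUDA
    cuda_demand = 2e9                     # COMPUTE tensors that must move in

    # 9e9 < 8e9 + 2e9, so HOLD tensors are evicted to CPU until at least
    # cuda_demand bytes are freed, then the COMPUTE tensors are moved to CUDA.
    assert cuda_model_data_period < used_cuda_model_data + cuda_demand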
@@ -3,6 +3,7 @@ from colossalai.zero.sharded_param import ShardedTensor
 from typing import Optional, Tuple
 from colossalai.zero.shard_utils.tensor_utils import colo_tensor_mem_usage
 from .tensorful_state import StatefulTensor, TensorState
+from typing import List
 
 
 class ShardedParamV2(object):
@@ -22,6 +23,11 @@ class ShardedParamV2(object):
         if rm_torch_payload:
             self.remove_torch_payload()
 
+    def get_payload_tensors(self) -> List[StatefulTensor]:
+        """returns stateful tensors kept by this class.
+        """
+        return [self._sharded_data_tensor, self.saved_grad]
+
     def remove_torch_payload(self):
         self.param.data = torch.empty([], dtype=self.param.dtype, device=self.param.device)
 
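A sketch of how these payload tensors feed the new StatefulTensorMgr.register_param above; the parameter iterable is hypothetical:

    mgr = StatefulTensorMgr()
    for sharded_param in sharded_params:     # hypothetical list of ShardedParamV2
        mgr.register_param(sharded_param)    # registers _sharded_data_tensor and saved_grad
    # at each sampling moment, the manager can then rebalance CUDA/CPU placement:
    mgr.adjust_layout(mem_stats_collector)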