2022-11-16 07:45:57 +00:00
|
|
|
from colossalai.gemini.chunk import ChunkManager
|
|
|
|
from colossalai.utils import get_current_device
|
|
|
|
from colossalai.utils.memory import colo_device_memory_capacity
|
|
|
|
|
|
|
|
from .memstats_collector import MemStatsCollector
|
|
|
|
|
|
|
|
|
|
|
|
class ChunkMemStatsCollector(MemStatsCollector):
    """MemStatsCollector variant that takes model-data sizes from a ChunkManager.

    Instead of measuring model data itself, this collector reads the per-device
    totals exposed by ``chunk_manager.total_mem`` and forwards them to the
    inherited ``self._memstats`` recorder.

    Args:
        chunk_manager: the ChunkManager whose ``total_mem`` mapping (indexed by
            ``'cuda'`` and ``'cpu'``) is sampled.
    """

    def __init__(self, chunk_manager: ChunkManager) -> None:
        super().__init__()
        self._chunk_manager = chunk_manager

    # override
    def sample_model_data(self) -> None:
        """Record the chunk manager's current CUDA and CPU model-data totals.

        Nothing is recorded unless sampling has been started (``self._start_flag``
        is truthy; the flag is managed by the base class).
        """
        if not self._start_flag:
            return
        manager = self._chunk_manager
        # Read both device totals before recording either one, CUDA first.
        readings = [(device, manager.total_mem[device]) for device in ('cuda', 'cpu')]
        for device, amount in readings:
            self._memstats.append_model_data(device, amount)

    @property
    def cuda_margin_mem(self) -> float:
        """Return the current device's memory capacity minus the recorded CUDA peak.

        The subtracted value is ``self._memstats.max_overall_cuda('cuda')``,
        presumably the largest overall CUDA usage sampled so far (confirm
        against MemStats).
        """
        capacity = colo_device_memory_capacity(get_current_device())
        peak_usage = self._memstats.max_overall_cuda('cuda')
        return capacity - peak_usage
|