You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
ColossalAI/colossalai/gemini/memory_tracer/memstats_collector.py

101 lines
3.6 KiB

import time
from typing import List, Optional
import torch
from colossalai.gemini.memory_tracer import SyncCudaMemoryMonitor
from colossalai.gemini.stateful_tensor import StatefulTensor
from colossalai.utils.memory import colo_device_memory_used
from .memory_stats import MemStats
class MemStatsCollector:
    """A memory statistics collector.

    It works in two phases.

    Phase 1. Collection Phase: collect memory usage statistics of CPU and GPU.
        The first iteration of DNN training.
    Phase 2. Runtime Phase: use the read-only collected stats.
        The remaining iterations of DNN training.

    It keeps a sampling counter (``_step_idx``) that wraps around once every
    recorded sample of an iteration has been consumed.

    Args:
        memstats: an externally owned statistics container. When given, the
            collector stores it and skips its own CUDA sampling; otherwise a
            fresh ``MemStats`` is created and filled by the collector.
    """

    def __init__(self, memstats: Optional['MemStats'] = None) -> None:
        self._mem_monitor = SyncCudaMemoryMonitor()
        # Wall-clock timestamps (time.time()) of every sample taken.
        self._sampling_time = []

        self._start_flag = False    # True only while the collection phase runs
        self._step_idx = 0    # index of the next sample handed out at runtime
        self._step_total = 0    # number of samples recorded by the collection

        self.use_outside_memstats = memstats is not None
        self._memstats = memstats if self.use_outside_memstats else MemStats()

    def next_period_non_model_data_usage(self, device_type: str) -> int:
        """Return the non model data memory usage recorded for the next Op run.

        Each call hands out the entry at the current step index and then
        advances the index, wrapping around at ``_step_total``.

        Args:
            device_type (str): device type, can be 'cpu' or 'cuda'.

        Returns:
            int: max non model data memory usage of current sampling period

        Raises:
            AssertionError: if called during the collection phase, before any
                collection has finished, or when fewer samples were recorded
                than the current step index requires.
        """
        assert not self._start_flag, 'Cannot get mem stats info during collection phase.'
        assert self._step_total > 0, 'Cannot get mem stats info before collection phase.'

        non_model_data = self._memstats.non_model_data_list(device_type)
        assert len(non_model_data) > self._step_idx, \
            f"{len(non_model_data)} should be > than step idx {self._step_idx}, "\
            f"step total {self._step_total}"

        next_non_model_data = non_model_data[self._step_idx]
        self._step_idx = (self._step_idx + 1) % self._step_total
        return next_non_model_data

    @property
    def sampling_time(self):
        """Sample timestamps expressed as offsets from the first sample.

        Returns an empty list when nothing has been sampled yet.
        """
        if not self._sampling_time:
            return []
        origin = self._sampling_time[0]
        return [stamp - origin for stamp in self._sampling_time]

    def start_collection(self):
        """Enter the collection phase and start the CUDA memory monitor."""
        self._start_flag = True
        self._mem_monitor.start()

    def finish_collection(self):
        """Take a final sample, freeze the step count and leave the collection phase."""
        # Must run before the flag is dropped: sampling is a no-op otherwise.
        self.sample_overall_data()
        # The step count is the number of recorded CUDA non model data samples.
        self._step_total = len(self._memstats.non_model_data_list('cuda'))
        self._start_flag = False
        print(f'finish_collection {self._step_total}')

    # deprecated
    def record_model_data_volume(self) -> None:
        """Sample model data statistics.

        Only acts during the collection phase and when the collector owns
        its ``MemStats``.
        """
        if self._start_flag and not self.use_outside_memstats:
            # The following code work for ZeroInitContext, which is deprecated in v0.1.12
            cuda_mem = StatefulTensor.GST_MGR.total_mem['cuda']
            self._memstats.record_max_cuda_model_data(cuda_mem)

    def sample_overall_data(self) -> None:
        """Sample overall and non model data cuda memory statistics.

        Does nothing outside the collection phase. With an outside
        ``MemStats`` only the sampling timestamp is recorded.
        """
        if not self._start_flag:
            return

        if not self.use_outside_memstats:
            # Close the current monitoring window, record it, then open a new one.
            cuda_overall = self._mem_monitor.finish()
            self._memstats.record_max_cuda_overall_data(cuda_overall)
            self._memstats.calc_max_cuda_non_model_data()
            self._mem_monitor.start()

        self._sampling_time.append(time.time())

    def clear(self) -> None:
        """Drop all collected state so a new collection phase can start cleanly."""
        self._memstats.clear()
        # Timestamps belong to the finished collection; keeping them would make
        # `sampling_time` of a later collection relative to a stale origin and
        # let the list grow without bound across collect/clear cycles.
        self._sampling_time = []
        self._start_flag = False
        self._step_idx = 0
        self._step_total = 0