2022-04-01 06:03:00 +00:00
|
|
|
import time
|
2022-12-07 15:04:02 +00:00
|
|
|
from typing import List, Optional
|
2022-11-07 08:49:03 +00:00
|
|
|
|
2022-11-16 07:45:57 +00:00
|
|
|
import torch
|
2022-11-07 08:49:03 +00:00
|
|
|
|
2022-11-16 07:45:57 +00:00
|
|
|
from colossalai.gemini.memory_tracer import SyncCudaMemoryMonitor
|
|
|
|
from colossalai.gemini.stateful_tensor import StatefulTensor
|
|
|
|
from colossalai.utils.memory import colo_device_memory_used
|
2022-03-14 14:05:30 +00:00
|
|
|
|
2022-12-06 08:43:06 +00:00
|
|
|
from .memory_stats import MemStats
|
|
|
|
|
2022-03-14 14:05:30 +00:00
|
|
|
|
|
|
|
class MemStatsCollector:
    """A memory statistics collector.

    It works in two phases:

    Phase 1. Collection phase: collect memory usage statistics of CPU and GPU
        during the first iteration of DNN training.
    Phase 2. Runtime phase: serve the read-only collected stats during the
        remaining iterations of DNN training.

    A step counter walks through the collected non-model data samples and
    wraps back to the first sample after the last one.

    Args:
        memstats (Optional[MemStats]): an externally supplied ``MemStats``.
            When given, the ``sample_*`` methods do not record into it
            (presumably it is populated elsewhere); when ``None``, a fresh
            ``MemStats`` is created and filled by this collector.
    """

    def __init__(self, memstats: Optional[MemStats] = None) -> None:
        self._mem_monitor = SyncCudaMemoryMonitor()
        # Wall-clock timestamps (time.time(), seconds) of samples taken while collecting.
        self._sampling_time: List[float] = []

        # True between start_collection() and finish_collection().
        self._start_flag = False
        # Index of the sample returned by the next next_period_non_model_data_usage() call.
        self._step_idx = 0
        # Number of collected samples; stays 0 until finish_collection() has run.
        self._step_total = 0

        # An outside MemStats is only read by this collector, never written by sample_*().
        self.use_outside_memstats = memstats is not None
        self._memstats = memstats if memstats is not None else MemStats()

    def next_period_non_model_data_usage(self, device_type: str) -> int:
        """Return the non-model data memory usage recorded for the next sampling period.

        Advances an internal cursor over the collected samples, wrapping back
        to the first sample after the last one.

        Args:
            device_type (str): device type, can be 'cpu' or 'cuda'.

        Returns:
            int: non-model data memory usage recorded for the current sampling period.

        Raises:
            AssertionError: (when assertions are enabled) if called during the
                collection phase, before any collection has finished, or if
                fewer samples exist for ``device_type`` than the step index needs.
        """
        assert not self._start_flag, 'Cannot get mem stats info during collection phase.'
        assert self._step_total > 0, 'Cannot get mem stats info before collection phase.'
        # Fetch the sample list once instead of once per use.
        non_model_data = self._memstats.non_model_data_list(device_type)
        assert len(non_model_data) > self._step_idx, \
            f"{len(non_model_data)} should be > than step idx {self._step_idx}, "\
            f"step total {self._step_total}"
        next_non_model_data = non_model_data[self._step_idx]
        self._step_idx = (self._step_idx + 1) % self._step_total
        return next_non_model_data

    @property
    def sampling_time(self) -> List[float]:
        """Sampling timestamps in seconds, relative to the first sample.

        The first entry is 0.0. An empty list is returned when nothing was sampled.
        """
        return [t - self._sampling_time[0] for t in self._sampling_time]

    def start_collection(self):
        """Enter the collection phase and start the memory monitor."""
        self._start_flag = True
        self._mem_monitor.start()

    def finish_collection(self):
        """Leave the collection phase.

        Takes a final overall-data sample, then fixes the number of steps to
        the number of CUDA non-model data samples, and stops the monitor.
        """
        self.sample_overall_data()
        # Count steps from the recorded non-model data (not from _sampling_time):
        # that is the list next_period_non_model_data_usage() indexes into.
        self._step_total = len(self._memstats.non_model_data_list('cuda'))
        self._start_flag = False
        self._mem_monitor.finish()

    def sample_model_data(self) -> None:
        """Sample model data statistics.

        Only records while collecting with an internally owned ``MemStats``.
        The values are taken from ``StatefulTensor.GST_MGR.total_mem`` for
        'cuda' and 'cpu'.
        """
        if self._start_flag and not self.use_outside_memstats:
            cuda_mem = StatefulTensor.GST_MGR.total_mem['cuda']
            cpu_mem = StatefulTensor.GST_MGR.total_mem['cpu']
            self._memstats.append_model_data('cuda', cuda_mem)
            self._memstats.append_model_data('cpu', cpu_mem)

    def sample_overall_data(self) -> None:
        """Sample overall memory statistics and derive non-model data statistics.

        While collecting with an internally owned ``MemStats``:

        * records the CUDA value returned by ``self._mem_monitor.finish()``;
        * records the CPU memory reported by ``colo_device_memory_used``;
        * asks ``MemStats`` to append the derived non-model data entries;
        * restarts the monitor.

        Nothing, not even the timestamp, is recorded until at least one
        model-data sample exists. Otherwise, while collecting, the sampling
        timestamp is appended.
        """
        if self._start_flag and not self.use_outside_memstats:
            # Overall data recording follows model data recording; skip until one exists.
            # NOTE(review): this reads MemStats' private lists directly; a public
            # accessor on MemStats would decouple this class from its internals.
            if len(self._memstats._model_data_cuda_list) == 0:
                return

            self._memstats.append_overall_data('cuda', self._mem_monitor.finish())
            self._memstats.append_overall_data('cpu', colo_device_memory_used(torch.device('cpu')))

            # Every overall sample must pair with a model data sample so the
            # non-model data (presumably overall minus model) can be derived.
            assert len(self._memstats._model_data_cuda_list) == len(self._memstats._overall_cuda_list)

            self._memstats.append_non_model_data('cuda')
            self._memstats.append_non_model_data('cpu')
            # Restart the monitor for the next sampling period.
            self._mem_monitor.start()

        if self._start_flag:
            self._sampling_time.append(time.time())

    def clear(self) -> None:
        """Reset the collected statistics, sampling timestamps and step counter.

        Note: this also clears ``self._memstats`` when it was supplied from outside.
        """
        self._memstats.clear()
        # Drop stale timestamps too; otherwise the next collection's
        # ``sampling_time`` would be measured from a sample of the previous run.
        self._sampling_time = []
        self._start_flag = False
        self._step_idx = 0
        self._step_total = 0
|