# Source: ColossalAI/colossalai/utils/memory_tracer/memstats_collector.py
# (167 lines, 5.5 KiB, Python)

from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
2022-04-01 01:22:33 +00:00
from colossalai.utils.memory_utils.utils import colo_device_memory_used
from colossalai.utils import get_current_device
from colossalai.utils.memory_tracer.async_memtracer import AsyncMemoryMonitor
import torch
import time
2022-04-01 01:22:33 +00:00
from typing import List
class SamplingCounter:
    """Counter over the sampling moments of one training iteration.

    ``advance`` is called once per sampling moment.  ``reset`` records how
    many moments have been counted so far as the period length and rewinds
    the counter to zero, so that ``next`` can wrap around within that period.
    """

    def __init__(self) -> None:
        # Index of the current sampling moment.
        self._sampling_cnt = 0
        # Number of moments counted before the last ``reset``;
        # ``None`` until ``reset`` has been called at least once.
        self._max_sampling_cnt = None

    def advance(self) -> None:
        """Move the counter to the next sampling moment."""
        self._sampling_cnt += 1

    def next(self) -> int:
        """Return the index following the current one, wrapping at the period.

        Raises:
            AssertionError: if ``reset`` has never been called, i.e. the
                period length is still unknown.
            ZeroDivisionError: if the recorded period length is 0 (``reset``
                was called before any ``advance``).
        """
        period = self._max_sampling_cnt
        # Raised explicitly instead of via ``assert`` so the check is not
        # stripped under ``python -O``; the exception type is unchanged.
        if period is None:
            raise AssertionError("SamplingCounter.next() called before reset(): sampling period is unknown")
        return (self._sampling_cnt + 1) % period

    def current(self) -> int:
        """Return the index of the current sampling moment."""
        return self._sampling_cnt

    def max(self):
        """Return the recorded period length, or ``None`` if never reset."""
        return self._max_sampling_cnt

    def reset(self) -> None:
        """Record the current count as the period length and rewind to 0."""
        self._max_sampling_cnt = self._sampling_cnt
        self._sampling_cnt = 0
class MemStatsCollector:
    """
    A memory statistics collector.

    It works in two phases.
    Phase 1. Collection Phase: collect memory usage statistics of CPU and GPU.
    This is the first iteration of DNN training.
    Phase 2. Runtime Phase: use the collected stats read-only.
    These are the remaining iterations of DNN training.

    It has a sampling counter which is reset after each DNN training iteration.
    """

    # Divisor that converts a raw recorded value into each supported unit.
    _UNIT_SCALE = {'GB': 1e9, 'MB': 1e6, 'KB': 1e3, 'B': 1}

    def __init__(self) -> None:
        self._sampling_cnter = SamplingCounter()
        self._mem_monitor = AsyncMemoryMonitor()

        # All series below get exactly one entry per recorded sampling moment,
        # so the same index refers to the same moment in every list.
        self._model_data_cuda_list = []
        self._overall_cuda_list = []

        self._model_data_cpu_list = []
        self._overall_cpu_list = []

        # Derived series: overall usage minus model data usage.
        self._non_model_data_cuda_list = []
        self._non_model_data_cpu_list = []

        # ``time.time()`` stamp of every recorded sample.
        self._sampling_time = []

        # True only while in the collection phase.
        self._start_flag = False

    @classmethod
    def _scale_of(cls, unit: str):
        """Return the divisor for ``unit``; raise ``TypeError`` if unsupported."""
        try:
            return cls._UNIT_SCALE[unit]
        except KeyError:
            raise TypeError(f"unsupported unit {unit!r}, expected one of 'GB', 'MB', 'KB', 'B'") from None

    @staticmethod
    def _pick_by_device(device_type: str, cuda_stats, cpu_stats):
        """Return the series matching ``device_type``; raise ``TypeError`` otherwise."""
        if device_type == 'cuda':
            return cuda_stats
        if device_type == 'cpu':
            return cpu_stats
        raise TypeError(f"unsupported device type {device_type!r}, expected 'cuda' or 'cpu'")

    def overall_mem_stats(self, device_type: str):
        """Return the recorded overall memory usage series of a device.

        The internal list itself is returned, not a copy.

        Args:
            device_type: ``'cuda'`` or ``'cpu'``.

        Raises:
            TypeError: if ``device_type`` is neither ``'cuda'`` nor ``'cpu'``.
        """
        return self._pick_by_device(device_type, self._overall_cuda_list, self._overall_cpu_list)

    def model_data_list(self, device_type: str, unit: str = 'B') -> List[float]:
        """Return the recorded model data usage series, scaled to ``unit``.

        Args:
            device_type: ``'cuda'`` or ``'cpu'``.
            unit: ``'GB'``, ``'MB'``, ``'KB'`` or ``'B'`` (decimal multiples).

        Returns:
            A new list; values are floats because of the true division.

        Raises:
            TypeError: if ``unit`` or ``device_type`` is not supported.
        """
        scale = self._scale_of(unit)
        stats = self._pick_by_device(device_type, self._model_data_cuda_list, self._model_data_cpu_list)
        return [elem / scale for elem in stats]

    def non_model_data_list(self, device_type: str, unit: str = 'B') -> List[float]:
        """Return the recorded non model data usage series, scaled to ``unit``.

        Args:
            device_type: ``'cuda'`` or ``'cpu'``.
            unit: ``'GB'``, ``'MB'``, ``'KB'`` or ``'B'`` (decimal multiples).

        Returns:
            A new list; values are floats because of the true division.

        Raises:
            TypeError: if ``unit`` or ``device_type`` is not supported.
        """
        scale = self._scale_of(unit)
        stats = self._pick_by_device(device_type, self._non_model_data_cuda_list, self._non_model_data_cpu_list)
        return [elem / scale for elem in stats]

    def current_non_model_data(self, device_type: str) -> float:
        """Get the non model data of the current sampling moment."""
        return self.non_model_data_list(device_type)[self._sampling_cnter.current()]

    def next_non_model_data(self, device_type: str) -> float:
        """Get the non model data of the next sampling moment."""
        return self.non_model_data_list(device_type)[self._sampling_cnter.next()]

    @property
    def sampling_time(self):
        """Sampling timestamps expressed relative to the first recorded sample."""
        return [t - self._sampling_time[0] for t in self._sampling_time]

    def start_collection(self):
        """Enter the collection phase and start the memory monitor."""
        self._start_flag = True
        self._mem_monitor.start()

    def finish_collection(self):
        """Leave the collection phase; later samples are no longer recorded."""
        self._start_flag = False

    def sample_memstats(self) -> None:
        """
        Sample memory statistics.

        During the collection phase, record the current model data and overall
        memory usage of CUDA and CPU, derive the non model data usage of both,
        and restart the memory monitor.
        The sampling counter is advanced in either phase.
        """
        if self._start_flag:
            sampling_cnt = self._sampling_cnter.current()
            # The counter doubles as the index into the recorded series.
            assert sampling_cnt == len(self._overall_cuda_list), \
                "sampling counter is out of sync with the recorded statistics"

            self._model_data_cuda_list.append(GLOBAL_MODEL_DATA_TRACER.cuda_usage)
            # NOTE(review): finish() presumably returns the CUDA usage observed
            # since the last start() - confirm against AsyncMemoryMonitor.
            self._overall_cuda_list.append(self._mem_monitor.finish())
            # non model data = overall - model data, same as the CPU branch below.
            self._non_model_data_cuda_list.append(self._overall_cuda_list[-1] - self._model_data_cuda_list[-1])

            self._model_data_cpu_list.append(GLOBAL_MODEL_DATA_TRACER.cpu_usage)
            # FIXME(jiaruifang) cpu sys used should also return from self._mem_monitor()
            self._overall_cpu_list.append(colo_device_memory_used(torch.device('cpu')))
            self._non_model_data_cpu_list.append(self._overall_cpu_list[-1] - self._model_data_cpu_list[-1])

            self._sampling_time.append(time.time())
            self._mem_monitor.start()
        # TODO(ver217): refactor sampler
        self._sampling_cnter.advance()

    def reset_sampling_cnter(self) -> None:
        """Reset the sampling counter and call ``finish()`` on the memory monitor."""
        self._sampling_cnter.reset()
        self._mem_monitor.finish()

    def clear(self) -> None:
        """Drop every recorded statistic and leave the collection phase.

        All series are emptied together so they stay index-aligned if a new
        collection is started afterwards.
        """
        self._model_data_cuda_list = []
        self._overall_cuda_list = []

        self._model_data_cpu_list = []
        self._overall_cpu_list = []

        self._non_model_data_cuda_list = []
        self._non_model_data_cpu_list = []
        self._sampling_time = []

        self._start_flag = False
        self._sampling_cnter.reset()
        self._mem_monitor.finish()