2022-03-15 03:29:46 +00:00
|
|
|
from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
|
2022-04-01 01:22:33 +00:00
|
|
|
from colossalai.utils.memory_utils.utils import colo_device_memory_used
|
2022-03-14 14:05:30 +00:00
|
|
|
from colossalai.utils import get_current_device
|
2022-04-03 13:48:06 +00:00
|
|
|
from colossalai.utils.memory_tracer.async_memtracer import AsyncMemoryMonitor
|
2022-03-14 14:05:30 +00:00
|
|
|
import torch
|
2022-04-01 06:03:00 +00:00
|
|
|
import time
|
2022-04-01 01:22:33 +00:00
|
|
|
from typing import List
|
2022-03-14 14:05:30 +00:00
|
|
|
|
|
|
|
|
|
|
|
class SamplingCounter:
    """Counts memory-sampling moments within one training iteration.

    ``advance()`` is called once per sample. ``reset()`` ends an iteration: it
    remembers how many samples that iteration produced (``max()``) and restarts
    the count from 0, which lets ``next()`` look one sample ahead, wrapping around.
    """

    def __init__(self) -> None:
        # Number of samples taken so far in the current iteration.
        self._sampling_cnt = 0
        # Number of samples in the last completed iteration; None until the first reset().
        self._max_sampling_cnt = None

    def advance(self):
        """Move on to the next sampling moment."""
        self._sampling_cnt += 1

    def next(self):
        """Return the index of the next sampling moment, wrapping around to 0.

        Only meaningful after ``reset()`` has recorded an iteration with at least
        one sample; otherwise there is no cycle to wrap around.
        """
        # A falsy max covers both "never reset" (None) and "empty iteration" (0),
        # the latter previously surfacing as an opaque ZeroDivisionError.
        assert self._max_sampling_cnt, \
            'SamplingCounter.next() requires a completed iteration with at least one sample'
        return (self._sampling_cnt + 1) % self._max_sampling_cnt

    def current(self):
        """Return the index of the current sampling moment."""
        return self._sampling_cnt

    def max(self):
        """Return the sample count of the last completed iteration (None before the first reset)."""
        return self._max_sampling_cnt

    def reset(self):
        """Finish the current iteration: remember its sample count and restart from 0."""
        self._max_sampling_cnt = self._sampling_cnt
        self._sampling_cnt = 0
|
|
|
|
|
|
|
|
|
|
|
|
class MemStatsCollector:
    """
    A memory statistics collector.

    It works in two phases.
        Phase 1. Collection phase: collect memory usage statistics of CPU and GPU.
            This happens during the first iteration of DNN training.
        Phase 2. Runtime phase: use the read-only collected stats.
            This covers the remaining iterations of DNN training.

    It owns a sampling counter which is reset after each DNN training iteration.
    """

    # Divisors converting byte counts into the unit requested by the *_list accessors.
    _UNIT_SCALE = {'B': 1, 'KB': 1e3, 'MB': 1e6, 'GB': 1e9}

    def __init__(self) -> None:
        self._sampling_cnter = SamplingCounter()
        self._mem_monitor = AsyncMemoryMonitor()

        self._model_data_cuda_list = []
        self._overall_cuda_list = []

        self._model_data_cpu_list = []
        self._overall_cpu_list = []

        self._non_model_data_cuda_list = []
        self._non_model_data_cpu_list = []
        self._sampling_time = []

        self._start_flag = False

    @classmethod
    def _unit_scale(cls, unit: str) -> float:
        """Return the divisor that converts bytes into ``unit`` ('B', 'KB', 'MB' or 'GB').

        Raises:
            TypeError: if ``unit`` is not a supported unit.
        """
        try:
            return cls._UNIT_SCALE[unit]
        except KeyError:
            raise TypeError(f'unsupported unit {unit!r}, expected one of B/KB/MB/GB') from None

    @staticmethod
    def _by_device(device_type: str, cuda_list: List, cpu_list: List) -> List:
        """Pick the per-device list matching ``device_type``.

        Raises:
            TypeError: if ``device_type`` is neither 'cuda' nor 'cpu'.
        """
        if device_type == 'cuda':
            return cuda_list
        if device_type == 'cpu':
            return cpu_list
        raise TypeError(f'unsupported device type {device_type!r}, expected "cuda" or "cpu"')

    def overall_mem_stats(self, device_type: str) -> List:
        """Return the recorded overall memory usage samples for ``device_type`` (the live list)."""
        return self._by_device(device_type, self._overall_cuda_list, self._overall_cpu_list)

    def model_data_list(self, device_type: str, unit: str = 'B') -> List[float]:
        """Return the model-data usage samples for ``device_type`` expressed in ``unit``."""
        scale = self._unit_scale(unit)
        samples = self._by_device(device_type, self._model_data_cuda_list, self._model_data_cpu_list)
        return [elem / scale for elem in samples]

    def non_model_data_list(self, device_type: str, unit: str = 'B') -> List[float]:
        """Return the non-model-data usage samples (overall minus model data) in ``unit``."""
        scale = self._unit_scale(unit)
        samples = self._by_device(device_type, self._non_model_data_cuda_list, self._non_model_data_cpu_list)
        return [elem / scale for elem in samples]

    def current_non_model_data(self, device_type: str) -> float:
        """Get the non-model data (in bytes) of the current sampling moment."""
        samples = self._by_device(device_type, self._non_model_data_cuda_list, self._non_model_data_cpu_list)
        # Same value as non_model_data_list(device_type)[idx], without copying the whole list.
        return samples[self._sampling_cnter.current()] / self._unit_scale('B')

    def next_non_model_data(self, device_type: str) -> float:
        """Get the non-model data (in bytes) of the next sampling moment."""
        samples = self._by_device(device_type, self._non_model_data_cuda_list, self._non_model_data_cpu_list)
        return samples[self._sampling_cnter.next()] / self._unit_scale('B')

    @property
    def sampling_time(self) -> List[float]:
        """Sampling timestamps relative to the first sample (seconds, from time.time())."""
        if not self._sampling_time:
            return []
        start = self._sampling_time[0]
        return [t - start for t in self._sampling_time]

    def start_collection(self):
        self._start_flag = True
        self._mem_monitor.start()

    def finish_collection(self):
        self._start_flag = False

    def _record_sample(self, model_cuda, overall_cuda, model_cpu, overall_cpu) -> None:
        """Append one sample per device and derive its non-model usage as ``overall - model``."""
        self._model_data_cuda_list.append(model_cuda)
        self._overall_cuda_list.append(overall_cuda)
        self._non_model_data_cuda_list.append(overall_cuda - model_cuda)

        self._model_data_cpu_list.append(model_cpu)
        self._overall_cpu_list.append(overall_cpu)
        self._non_model_data_cpu_list.append(overall_cpu - model_cpu)

    def sample_memstats(self) -> None:
        """
        Sampling memory statistics.
        Record the current model data CUDA memory usage as well as system CUDA memory usage.
        Advance the sampling cnter.
        """
        if not self._start_flag:
            return

        sampling_cnt = self._sampling_cnter.current()
        # Each sample appends exactly one entry per list, so counter and list length must agree.
        assert sampling_cnt == len(self._overall_cuda_list)

        model_cuda = GLOBAL_MODEL_DATA_TRACER.cuda_usage
        overall_cuda = self._mem_monitor.finish()
        model_cpu = GLOBAL_MODEL_DATA_TRACER.cpu_usage
        # FIXME(jiaruifang) cpu sys used should also return from self._mem_monitor()
        overall_cpu = colo_device_memory_used(torch.device('cpu'))
        self._record_sample(model_cuda, overall_cuda, model_cpu, overall_cpu)

        self._sampling_time.append(time.time())
        # Restart the async monitor so it keeps running until the next sample.
        self._mem_monitor.start()
        # TODO(ver217): refactor sampler
        self._sampling_cnter.advance()

    def reset_sampling_cnter(self) -> None:
        self._sampling_cnter.reset()
        self._mem_monitor.finish()

    def clear(self) -> None:
        """Drop every collected sample and stop collecting."""
        self._model_data_cuda_list = []
        self._overall_cuda_list = []

        self._model_data_cpu_list = []
        self._overall_cpu_list = []

        # Derived lists and timestamps must be cleared as well; otherwise stale samples
        # remain visible through non_model_data_list()/sampling_time after a clear().
        self._non_model_data_cuda_list = []
        self._non_model_data_cpu_list = []
        self._sampling_time = []

        self._start_flag = False
        self._sampling_cnter.reset()
        self._mem_monitor.finish()
|