mirror of https://github.com/hpcaitech/ColossalAI
[zero] add sampling time for memstats collector (#610)
parent
9bee119104
commit
02b187c14f
|
@ -3,6 +3,7 @@ from colossalai.utils.memory_utils.utils import colo_device_memory_used
|
||||||
from colossalai.utils import get_current_device
|
from colossalai.utils import get_current_device
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
import time
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
|
|
||||||
|
@ -42,6 +43,8 @@ class MemStatsCollector:
|
||||||
self._model_data_cpu_list = []
|
self._model_data_cpu_list = []
|
||||||
self._overall_cpu_list = []
|
self._overall_cpu_list = []
|
||||||
|
|
||||||
|
self._sampling_time = []
|
||||||
|
|
||||||
self._start_flag = False
|
self._start_flag = False
|
||||||
|
|
||||||
def overall_mem_stats(self, device_type: str):
|
def overall_mem_stats(self, device_type: str):
|
||||||
|
@ -52,15 +55,15 @@ class MemStatsCollector:
|
||||||
else:
|
else:
|
||||||
raise TypeError
|
raise TypeError
|
||||||
|
|
||||||
@property
|
|
||||||
def model_data_cuda_list(self, device_type: str, unit: str = 'B') -> List[int]:
|
def model_data_cuda_list(self, device_type: str, unit: str = 'B') -> List[int]:
|
||||||
scale = 1
|
|
||||||
if unit == 'GB':
|
if unit == 'GB':
|
||||||
scale = 1e9
|
scale = 1e9
|
||||||
elif unit == 'MB':
|
elif unit == 'MB':
|
||||||
scale = 1e6
|
scale = 1e6
|
||||||
elif unit == 'KB':
|
elif unit == 'KB':
|
||||||
scale = 1e3
|
scale = 1e3
|
||||||
|
elif unit == 'B':
|
||||||
|
scale = 1
|
||||||
else:
|
else:
|
||||||
raise TypeError
|
raise TypeError
|
||||||
|
|
||||||
|
@ -74,13 +77,16 @@ class MemStatsCollector:
|
||||||
def non_model_data_cuda_list(self, device_type: str, unit: str = 'B') -> List[int]:
|
def non_model_data_cuda_list(self, device_type: str, unit: str = 'B') -> List[int]:
|
||||||
"""Non model data stats
|
"""Non model data stats
|
||||||
"""
|
"""
|
||||||
scale = 1
|
|
||||||
if unit == 'GB':
|
if unit == 'GB':
|
||||||
scale = 1e9
|
scale = 1e9
|
||||||
elif unit == 'MB':
|
elif unit == 'MB':
|
||||||
scale = 1e6
|
scale = 1e6
|
||||||
elif unit == 'KB':
|
elif unit == 'KB':
|
||||||
scale = 1e3
|
scale = 1e3
|
||||||
|
elif unit == 'B':
|
||||||
|
scale = 1
|
||||||
|
else:
|
||||||
|
raise TypeError
|
||||||
|
|
||||||
if device_type == 'cuda':
|
if device_type == 'cuda':
|
||||||
return [(v1 - v2) / scale for v1, v2 in zip(self._overall_cuda_list, self._model_data_cuda_list)]
|
return [(v1 - v2) / scale for v1, v2 in zip(self._overall_cuda_list, self._model_data_cuda_list)]
|
||||||
|
@ -89,6 +95,10 @@ class MemStatsCollector:
|
||||||
else:
|
else:
|
||||||
raise TypeError
|
raise TypeError
|
||||||
|
|
||||||
|
@property
|
||||||
|
def sampling_time(self):
|
||||||
|
return [t - self._sampling_time[0] for t in self._sampling_time]
|
||||||
|
|
||||||
def start_collection(self):
|
def start_collection(self):
|
||||||
self._start_flag = True
|
self._start_flag = True
|
||||||
|
|
||||||
|
@ -110,6 +120,8 @@ class MemStatsCollector:
|
||||||
self._model_data_cpu_list.append(GLOBAL_MODEL_DATA_TRACER.cpu_usage)
|
self._model_data_cpu_list.append(GLOBAL_MODEL_DATA_TRACER.cpu_usage)
|
||||||
self._overall_cpu_list.append(colo_device_memory_used(torch.device(f'cpu')))
|
self._overall_cpu_list.append(colo_device_memory_used(torch.device(f'cpu')))
|
||||||
|
|
||||||
|
self._sampling_time.append(time.time())
|
||||||
|
|
||||||
self._sampling_cnter.advance()
|
self._sampling_cnter.advance()
|
||||||
|
|
||||||
def reset_sampling_cnter(self) -> None:
|
def reset_sampling_cnter(self) -> None:
|
||||||
|
|
Loading…
Reference in New Issue