[zero] add sampling time for memstats collector (#610)

pull/612/head
LuGY 2022-04-01 14:03:00 +08:00 committed by GitHub
parent 9bee119104
commit 02b187c14f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 15 additions and 3 deletions

View File

@ -3,6 +3,7 @@ from colossalai.utils.memory_utils.utils import colo_device_memory_used
from colossalai.utils import get_current_device
import torch
import time
from typing import List
@ -42,6 +43,8 @@ class MemStatsCollector:
self._model_data_cpu_list = []
self._overall_cpu_list = []
self._sampling_time = []
self._start_flag = False
def overall_mem_stats(self, device_type: str):
@ -52,15 +55,15 @@ class MemStatsCollector:
else:
raise TypeError
@property
def model_data_cuda_list(self, device_type: str, unit: str = 'B') -> List[int]:
scale = 1
if unit == 'GB':
scale = 1e9
elif unit == 'MB':
scale = 1e6
elif unit == 'KB':
scale = 1e3
elif unit == 'B':
scale = 1
else:
raise TypeError
@ -74,13 +77,16 @@ class MemStatsCollector:
def non_model_data_cuda_list(self, device_type: str, unit: str = 'B') -> List[int]:
"""Non model data stats
"""
scale = 1
if unit == 'GB':
scale = 1e9
elif unit == 'MB':
scale = 1e6
elif unit == 'KB':
scale = 1e3
elif unit == 'B':
scale = 1
else:
raise TypeError
if device_type == 'cuda':
return [(v1 - v2) / scale for v1, v2 in zip(self._overall_cuda_list, self._model_data_cuda_list)]
@ -89,6 +95,10 @@ class MemStatsCollector:
else:
raise TypeError
@property
def sampling_time(self):
    """Return the recorded sampling timestamps relative to the first one.

    Each entry of ``self._sampling_time`` is shifted so that the first
    sample maps to zero; later samples become elapsed offsets from it.
    The sampling routine in this file appends ``time.time()`` values to
    that list, so the offsets are in seconds.

    Returns:
        A new list with one offset per recorded sample, or an empty list
        when nothing has been sampled yet.
    """
    stamps = self._sampling_time
    # Nothing recorded yet: there is no first sample to measure from.
    if not stamps:
        return []
    origin = stamps[0]
    return [stamp - origin for stamp in stamps]
def start_collection(self):
    """Mark memory-statistics collection as started.

    Sets the ``_start_flag`` attribute to ``True``. Returns ``None``.
    NOTE(review): the code that consults this flag is not visible in this
    excerpt; presumably sampling is skipped until it is set -- confirm.
    """
    self._start_flag = True
@ -110,6 +120,8 @@ class MemStatsCollector:
self._model_data_cpu_list.append(GLOBAL_MODEL_DATA_TRACER.cpu_usage)
self._overall_cpu_list.append(colo_device_memory_used(torch.device(f'cpu')))
self._sampling_time.append(time.time())
self._sampling_cnter.advance()
def reset_sampling_cnter(self) -> None: