[profiler] add adaptive sampling to memory profiler (#330)

* fix merge conflict

modify unit test

remove unnecessary log info

reformat file

* remove unused module

* remove unnecessary sync function

* change doc string style from Google to Sphinx
Jie Zhu 2022-03-09 11:07:10 +08:00 committed by Frank Lee
parent 1388671699
commit 3213554cc2
1 changed file with 42 additions and 32 deletions


@ -8,13 +8,18 @@ from time import sleep, time
import pickle
from typing import Optional
from colossalai.core import global_context as gpc
import math
def get_cuda_memory_used(device: Optional[torch.device]) -> int:
"""
Get the free memory info of device.
"""Get the free memory info of device.
Notice that for CPU, this function will return 1/N of the total free memory,
where N is the world size.
:param device: device id
:type device: torch.device
:return: current memory usage, sized by MB
:rtype: int
"""
ret: int = torch.cuda.memory_allocated(device)
# get the peak memory to report correct data, so reset the counter for the next call
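The hunk ends here; for reference, a minimal standalone sketch of the peak-then-reset pattern the comment above describes, using PyTorch's public memory-stats API (the helper name is illustrative, not part of this file):

import torch
from typing import Optional

def peak_cuda_memory_then_reset(device: Optional[torch.device] = None) -> int:
    # peak allocated bytes on `device` since the last reset
    peak = torch.cuda.max_memory_allocated(device)
    # reset the peak counter so the next call reports a fresh peak
    torch.cuda.reset_peak_memory_stats(device)
    return peak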
@ -24,13 +29,16 @@ def get_cuda_memory_used(device: Optional[torch.device]) -> int:
class AsyncMemoryMonitor:
"""
An Async Mem Monitor running during computing. Sampling GPU memory usage of the current GPU
at an interval of 1/(10**power) sec.
:param power: the power of time interval, defaults to 10
:type power: int
"""
def __init__(self, power: int = 10):
self.keep_measuring = False
self.executor = ThreadPoolExecutor(max_workers=1)
self.monitor_thread = None
@ -42,6 +50,7 @@ class AsyncMemoryMonitor:
return len(self.mem_stats)
def set_interval(self, power: int):
self.clear()
self.interval = 1 / (10**power)
def is_measuring(self):
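A minimal usage sketch of the monitor on its own, the way MemTracerOpHook drives it below (the workload is an arbitrary illustration; start() and finish() are the calls the hook makes):

import torch

monitor = AsyncMemoryMonitor(power=3)       # sample every 1 / (10**3) s = 1 ms
monitor.start()                             # begin sampling in the background thread
x = torch.randn(64, 1024, device="cuda")
y = torch.nn.Linear(1024, 1024).cuda()(x)   # any CUDA workload to be profiled
monitor.finish()                            # stop sampling; this window's peak is recorded in mem_stats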
@ -89,23 +98,16 @@ class AsyncMemoryMonitor:
@OPHOOKS.register_module
class MemTracerOpHook(BaseOpHook):
"""
Collect GPU memory usage information

:param warmup: This parameter indicates how many iterations to truncate before profiling, defaults to 50
:type warmup: int
:param refreshrate: This parameter decides how frequently the stats file is written, defaults to 10
:type refreshrate: int
:param data_prefix: The prefix of the stats data file, defaults to "memstats"
:type data_prefix: string
"""
def __init__(self, warmup: int = 50, refreshrate: int = 10, data_prefix: str = "memstats"):
super().__init__()
@ -126,6 +128,16 @@ class MemTracerOpHook(BaseOpHook):
assert isinstance(module, torch.nn.Module)
return module.training
def _resample(self):
# calculate the average iteration time
total_time = (self.async_mem_monitor.time_stamps[-1] - self.async_mem_monitor.time_stamps[0])
avg_it_time = total_time / self.warmup
self._logger.debug(f"total time for {self.warmup} iterations is {total_time}s")
# adjust the sampling power
power: int = round(-math.log(avg_it_time, 10)) + 1
self._logger.debug(f"the power is {power}")
self.async_mem_monitor.set_interval(power)
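For intuition, a worked instance of the formula above (the timings are made up):

import math

avg_it_time = 0.04                               # e.g. 50 warmup iterations took 2.0 s in total
power = round(-math.log(avg_it_time, 10)) + 1    # round(1.398) + 1 = 2
interval = 1 / (10 ** power)                     # 0.01 s, i.e. roughly 4 samples per 40 ms iteration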
@property
def refreshrate(self) -> int:
return self._refreshrate
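A hedged construction example for the hook, matching the defaults documented above (how the hook is attached to the engine is outside this diff):

# profile from iteration 51 onward, rewrite the stats file every 10 valid iterations,
# dumping to memstats-<rank>.pkl
hook = MemTracerOpHook(warmup=50, refreshrate=10, data_prefix="memstats")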
@ -146,23 +158,19 @@ class MemTracerOpHook(BaseOpHook):
if self._isvalid(module):
self.async_mem_monitor.finish()
self.async_mem_monitor.start()
self._logger.debug(f'FWD PRE {module.__class__.__name__}')
def post_fwd_exec(self, module: torch.nn.Module, *args):
if self._isvalid(module):
self.async_mem_monitor.finish()
self._logger.debug(f'FWD POST {module.__class__.__name__}')
def pre_bwd_exec(self, module: torch.nn.Module, input, output):
if self._isvalid(module):
self.async_mem_monitor.finish()
self.async_mem_monitor.start()
self._logger.debug(f'BWD PRE {module.__class__.__name__}')
def post_bwd_exec(self, module: torch.nn.Module, input):
if self._isvalid(module):
self.async_mem_monitor.finish()
self._logger.debug(f'BWD POST {module.__class__.__name__}')
def pre_iter(self):
pass
@ -170,19 +178,21 @@ class MemTracerOpHook(BaseOpHook):
def post_iter(self):
self.async_mem_monitor.finish()
# in the warmup stage
if self.curiter < self.warmup:
pass
# adjust the sampling rate
elif self.curiter == self.warmup:
# use adaptive sample rate
self._resample()
# record data to log file
else:
# every `refreshrate` times, refresh the file
if self.valid_iter != 0 and self.valid_iter % self.refreshrate == 0:
# output file info
self._logger.info(f"dump memory statistics as pickle to {self._data_prefix}-{self._rank}.pkl")
self.save_results()
self._count += 1
self._logger.debug(f"data file has been refreshed {self._count} times")
# finish an iteration
self._curiter += 1
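Since post_iter dumps the collected statistics with pickle, a minimal sketch for loading one of the files afterwards (the file name follows the f-string above; the structure of the loaded object is whatever save_results wrote):

import pickle

with open("memstats-0.pkl", "rb") as f:   # data_prefix "memstats", rank 0
    stats = pickle.load(f)
print(type(stats))                        # inspect what save_results stored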