mirror of https://github.com/hpcaitech/ColossalAI
[profiler] add adaptive sampling to memory profiler (#330)
* fix merge conflict modify unit test remove unnessesary log info reformat file * remove unused module * remove unnecessary sync function * change doc string style from Google to Sphinxpull/394/head
parent
1388671699
commit
3213554cc2
|
@ -8,13 +8,18 @@ from time import sleep, time
|
|||
import pickle
|
||||
from typing import Optional
|
||||
from colossalai.core import global_context as gpc
|
||||
import math
|
||||
|
||||
|
||||
def get_cuda_memory_used(device: Optional[torch.device]) -> int:
|
||||
"""
|
||||
Get the free memory info of device.
|
||||
"""Get the free memory info of device.
|
||||
Notice that for CPU, this function will return 1/N of the total free memory,
|
||||
where N is the world size.
|
||||
|
||||
:param device: device id
|
||||
:type device: torch.device
|
||||
:return: current memory usage, sized by MB
|
||||
:rtype: int
|
||||
"""
|
||||
ret: int = torch.cuda.memory_allocated(device)
|
||||
# get the peak memory to report correct data, so reset the counter for the next call
|
||||
|
@ -24,13 +29,16 @@ def get_cuda_memory_used(device: Optional[torch.device]) -> int:
|
|||
|
||||
|
||||
class AsyncMemoryMonitor:
|
||||
|
||||
def __init__(self, power=10):
|
||||
"""
|
||||
An Async Mem Monitor runing during computing.
|
||||
Sampling GPU memory usage of the current GPU dev
|
||||
"""
|
||||
An Async Mem Monitor runing during computing. Sampling GPU memory usage of the current GPU
|
||||
at interval of 1/(10**power) sec.
|
||||
|
||||
:param power: the power of time interval, defaults to 10
|
||||
:type power: int
|
||||
"""
|
||||
|
||||
def __init__(self, power: int = 10):
|
||||
|
||||
self.keep_measuring = False
|
||||
self.executor = ThreadPoolExecutor(max_workers=1)
|
||||
self.monitor_thread = None
|
||||
|
@ -42,6 +50,7 @@ class AsyncMemoryMonitor:
|
|||
return len(self.mem_stats)
|
||||
|
||||
def set_interval(self, power: int):
|
||||
self.clear()
|
||||
self.interval = 1 / (10**power)
|
||||
|
||||
def is_measuring(self):
|
||||
|
@ -89,23 +98,16 @@ class AsyncMemoryMonitor:
|
|||
|
||||
@OPHOOKS.register_module
|
||||
class MemTracerOpHook(BaseOpHook):
|
||||
'''
|
||||
"""
|
||||
Collect GPU memory usage information
|
||||
|
||||
Args:
|
||||
warmup (int): This parameter indicates how many iterations to truncate
|
||||
before profiling, e.g. set to 5 and the data will start from 6-th iteration
|
||||
refreshrate (int): This parameter decides the frequency of write file.
|
||||
datafile(string): the name of the stats data file
|
||||
Attributes:
|
||||
_warmup (int): warmup iterations
|
||||
_refreshrate(int): how many iterations we shall refresh the file
|
||||
_logger (colossalai.logging.logger): output log file
|
||||
_curiter (int): current iteration number
|
||||
_count (int): the number of times the data file was written
|
||||
_data_prefix (string): the prefix of the stats data file
|
||||
_rank (int): the rank of current node
|
||||
'''
|
||||
:param warmup: This parameter indicates how many iterations to truncate before profiling, defaults to 50
|
||||
:type warmup: int
|
||||
:param refreshrate: This parameter decides the frequency of write file, defaults to 10
|
||||
:type refreshrate: int
|
||||
:param data_prefix: The prefix of the stats data file, defaults to "memstats"
|
||||
:type data_prefix: string
|
||||
"""
|
||||
|
||||
def __init__(self, warmup: int = 50, refreshrate: int = 10, data_prefix: str = "memstats"):
|
||||
super().__init__()
|
||||
|
@ -126,6 +128,16 @@ class MemTracerOpHook(BaseOpHook):
|
|||
assert isinstance(module, torch.nn.Module)
|
||||
return module.training
|
||||
|
||||
def _resample(self):
|
||||
# calculate the average iteration time
|
||||
total_time = (self.async_mem_monitor.time_stamps[-1] - self.async_mem_monitor.time_stamps[0])
|
||||
avg_it_time = total_time / self.warmup
|
||||
self._logger.debug(f"total time for {self.warmup} iterations is {total_time}s")
|
||||
# adjust the sampling power
|
||||
power: int = round(-math.log(avg_it_time, 10)) + 1
|
||||
self._logger.debug(f"the power is {power}")
|
||||
self.async_mem_monitor.set_interval(power)
|
||||
|
||||
@property
|
||||
def refreshrate(self) -> int:
|
||||
return self._refreshrate
|
||||
|
@ -146,23 +158,19 @@ class MemTracerOpHook(BaseOpHook):
|
|||
if self._isvalid(module):
|
||||
self.async_mem_monitor.finish()
|
||||
self.async_mem_monitor.start()
|
||||
self._logger.debug(f'FWD PRE {module.__class__.__name__}')
|
||||
|
||||
def post_fwd_exec(self, module: torch.nn.Module, *args):
|
||||
if self._isvalid(module):
|
||||
self.async_mem_monitor.finish()
|
||||
self._logger.debug(f'FWD POST {module.__class__.__name__}')
|
||||
|
||||
def pre_bwd_exec(self, module: torch.nn.Module, input, output):
|
||||
if self._isvalid(module):
|
||||
self.async_mem_monitor.finish()
|
||||
self.async_mem_monitor.start()
|
||||
self._logger.debug(f'BWD PRE {module.__class__.__name__}')
|
||||
|
||||
def post_bwd_exec(self, module: torch.nn.Module, input):
|
||||
if self._isvalid(module):
|
||||
self.async_mem_monitor.finish()
|
||||
self._logger.debug(f'BWD POST {module.__class__.__name__}')
|
||||
|
||||
def pre_iter(self):
|
||||
pass
|
||||
|
@ -170,19 +178,21 @@ class MemTracerOpHook(BaseOpHook):
|
|||
def post_iter(self):
|
||||
self.async_mem_monitor.finish()
|
||||
# in the warmup stage
|
||||
if self._curiter < self.warmup:
|
||||
# TODO: record time and adaptively change sampling rate
|
||||
if self.curiter < self.warmup:
|
||||
pass
|
||||
elif self._curiter == self._warmup:
|
||||
self.async_mem_monitor.clear()
|
||||
# adjust the sampling rate
|
||||
elif self.curiter == self.warmup:
|
||||
# use adaptive sample rate
|
||||
self._resample()
|
||||
# record data to log file
|
||||
else:
|
||||
# every `refreshrate` times, refresh the file
|
||||
if self.valid_iter != 0 and self.valid_iter % self.refreshrate == 0:
|
||||
# output file info
|
||||
self._logger.info(f'dump a memory statistics as pickle to {self._dataprefix}-{self._rank}.pkl')
|
||||
self._logger.info(f"dump a memory statistics as pickle to {self._data_prefix}-{self._rank}.pkl")
|
||||
self.save_results()
|
||||
self._count += 1
|
||||
self._logger.debug(f'data file has been refreshed {self._count} times')
|
||||
self._logger.debug(f"data file has been refreshed {self._count} times")
|
||||
# finish a iteration
|
||||
self._curiter += 1
|
||||
|
||||
|
|
Loading…
Reference in New Issue