mirror of https://github.com/hpcaitech/ColossalAI
[profiler] primary memory tracer
parent
dfc3fafe89
commit
d344689274
|
@ -26,6 +26,8 @@ class Engine:
|
|||
:type gradient_handlers: list
|
||||
:param clip_grad_norm: The norm of gradient clipping
|
||||
:type clip_grad_norm: float, optional
|
||||
:param ophook_list: List of ophook
|
||||
:type ophook_list: list
|
||||
:param verbose: whether to display log info
|
||||
:type verbose: bool
|
||||
"""
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
from re import S
|
||||
from colossalai.context.parallel_mode import ParallelMode
|
||||
import torch
|
||||
from . import BaseOpHook
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
@ -5,18 +7,20 @@ from colossalai.registry import OPHOOKS
|
|||
from colossalai.logging import get_dist_logger
|
||||
from time import sleep, time
|
||||
import pickle
|
||||
from typing import Union, Optional
|
||||
from colossalai.core import global_context as gpc
|
||||
|
||||
|
||||
def get_cuda_memory_used(device):
|
||||
def get_cuda_memory_used(device: Optional[torch.device]) -> int:
|
||||
"""
|
||||
Get the free memory info of device.
|
||||
Notice that for CPU, this function will return 1/N of the total free memory,
|
||||
where N is the world size.
|
||||
"""
|
||||
ret = torch.cuda.memory_allocated()
|
||||
ret: int = torch.cuda.memory_allocated(device)
|
||||
# get the peak memory to report correct data, so reset the counter for the next call
|
||||
if hasattr(torch.cuda, "reset_peak_memory_stats"): # pytorch 1.4+
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
torch.cuda.reset_peak_memory_stats(device)
|
||||
return ret
|
||||
|
||||
|
||||
|
@ -34,6 +38,9 @@ class AsyncMemoryMonitor:
|
|||
self.time_stamps = []
|
||||
self.mem_stats = []
|
||||
|
||||
def __len__(self):
|
||||
return len(self.mem_stats)
|
||||
|
||||
def set_interval(self, power: int):
|
||||
self.interval = 1 / (10**power)
|
||||
|
||||
|
@ -74,22 +81,65 @@ class AsyncMemoryMonitor:
|
|||
def save(self, filename):
|
||||
with open(filename, "wb") as f:
|
||||
pickle.dump(self.state_dict(), f)
|
||||
|
||||
def clear(self):
|
||||
self.mem_stats.clear()
|
||||
self.time_stamps.clear()
|
||||
|
||||
|
||||
@OPHOOKS.register_module
|
||||
class MemTracerOpHook(BaseOpHook):
|
||||
def __init__(self, niter=5):
|
||||
'''
|
||||
Collect GPU memory usage information
|
||||
|
||||
Args:
|
||||
warmup (int): This parameter indicates how many iterations to truncate
|
||||
before profiling, e.g. set to 5 and the data will start from 6-th iteration
|
||||
refreshrate (int): This parameter decides the frequency of write file.
|
||||
datafile(string): the name of the stats data file
|
||||
Attributes:
|
||||
_warmup (int): warmup iterations
|
||||
_refreshrate(int): how many iterations we shall refresh the file
|
||||
_logger (colossalai.logging.logger): output log file
|
||||
_curiter (int): current iteration number
|
||||
_count (int): the number of times the data file was written
|
||||
_data_prefix (string): the prefix of the stats data file
|
||||
_rank (int): the rank of current node
|
||||
'''
|
||||
def __init__(self, warmup: int = 50, refreshrate: int = 10, data_prefix: str = "memstats"):
|
||||
super().__init__()
|
||||
self.async_mem_monitor = AsyncMemoryMonitor()
|
||||
self._niter = niter
|
||||
self._curiter = 0
|
||||
self._logger = get_dist_logger()
|
||||
self._count = 0
|
||||
self._warmup = warmup
|
||||
self._refreshrate = refreshrate
|
||||
self._data_prefix = data_prefix
|
||||
# in distributed environment
|
||||
if gpc.is_initialized(ParallelMode.GLOBAL):
|
||||
self._rank = gpc.get_global_rank()
|
||||
else:
|
||||
self._rank = 0
|
||||
|
||||
def _isvalid(self, module):
|
||||
return module.training and self._curiter < self._niter
|
||||
def _isvalid(self, module) -> bool:
|
||||
assert isinstance(module, torch.nn.Module)
|
||||
return module.training
|
||||
|
||||
def niter(self):
|
||||
return self._niter
|
||||
@property
|
||||
def refreshrate(self) -> int:
|
||||
return self._refreshrate
|
||||
|
||||
@property
|
||||
def warmup(self) -> int:
|
||||
return self._warmup
|
||||
|
||||
@property
|
||||
def curiter(self) -> int:
|
||||
return self._curiter
|
||||
|
||||
@property
|
||||
def valid_iter(self) -> int:
|
||||
return self.curiter - self.warmup
|
||||
|
||||
def pre_fwd_exec(self, module: torch.nn.Module, *args):
|
||||
if self._isvalid(module):
|
||||
|
@ -103,14 +153,12 @@ class MemTracerOpHook(BaseOpHook):
|
|||
self._logger.debug(f'FWD POST {module.__class__.__name__}')
|
||||
|
||||
def pre_bwd_exec(self, module: torch.nn.Module, input, output):
|
||||
assert isinstance(module, torch.nn.Module)
|
||||
if self._isvalid(module):
|
||||
self.async_mem_monitor.finish()
|
||||
self.async_mem_monitor.start()
|
||||
self._logger.debug(f'BWD PRE {module.__class__.__name__}')
|
||||
|
||||
def post_bwd_exec(self, module: torch.nn.Module, input):
|
||||
assert isinstance(module, torch.nn.Module)
|
||||
if self._isvalid(module):
|
||||
self.async_mem_monitor.finish()
|
||||
self._logger.debug(f'BWD POST {module.__class__.__name__}')
|
||||
|
@ -120,11 +168,24 @@ class MemTracerOpHook(BaseOpHook):
|
|||
|
||||
def post_iter(self):
|
||||
self.async_mem_monitor.finish()
|
||||
if self._curiter == self._niter:
|
||||
self._logger.info(
|
||||
f'dump a memory statistics as pickle to ./memstats.pkl')
|
||||
self.save_results("memstats.pkl")
|
||||
# in the warmup stage
|
||||
if self._curiter < self.warmup:
|
||||
# TODO: record time and adaptively change sampling rate
|
||||
pass
|
||||
elif self._curiter == self._warmup:
|
||||
self.async_mem_monitor.clear()
|
||||
else:
|
||||
# every `refreshrate` times, refresh the file
|
||||
if self.valid_iter != 0 and self.valid_iter % self.refreshrate == 0:
|
||||
# output file info
|
||||
self._logger.info(
|
||||
f'dump a memory statistics as pickle to {self._dataprefix}-{self._rank}.pkl')
|
||||
self.save_results()
|
||||
self._count += 1
|
||||
self._logger.debug(f'data file has been refreshed {self._count} times')
|
||||
# finish a iteration
|
||||
self._curiter += 1
|
||||
|
||||
def save_results(self, filename):
|
||||
self.async_mem_monitor.save(filename)
|
||||
def save_results(self):
|
||||
datafile = f"{self._data_prefix}-{self._rank}.pkl"
|
||||
self.async_mem_monitor.save(datafile)
|
||||
|
|
|
@ -19,6 +19,11 @@ class Timer:
|
|||
def has_history(self):
|
||||
return len(self._history) != 0
|
||||
|
||||
@property
|
||||
def current_time(self) -> float:
|
||||
synchronize()
|
||||
return time.time()
|
||||
|
||||
def start(self):
|
||||
"""Fisrtly synchronize cuda, reset the clock and then start the timer.
|
||||
"""
|
||||
|
@ -27,6 +32,11 @@ class Timer:
|
|||
self._start_time = time.time()
|
||||
self._started = True
|
||||
|
||||
def lap(self):
|
||||
"""lap time and return elapsed time
|
||||
"""
|
||||
return self.current_time - self._start_time
|
||||
|
||||
def stop(self, keep_in_history: bool = False):
|
||||
"""Stop the timer and record the start-stop time interval.
|
||||
|
||||
|
|
|
@ -22,6 +22,7 @@ def test_load_config():
|
|||
|
||||
@pytest.mark.cpu
|
||||
def test_load_ophooks():
|
||||
dict = {'type': 'MemTracerOpHook', 'niter': 2}
|
||||
dict = {'type': 'MemTracerOpHook', 'warmup': 10, 'refreshrate': 20}
|
||||
ophook = build_ophooks(dict)
|
||||
assert ophook.niter() == 2
|
||||
assert ophook.refreshrate == 20
|
||||
assert ophook.warmup == 10
|
||||
|
|
Loading…
Reference in New Issue