2022-03-04 01:35:23 +00:00
|
|
|
from colossalai.context.parallel_mode import ParallelMode
|
2022-01-25 14:20:54 +00:00
|
|
|
import torch
|
2022-03-09 08:31:25 +00:00
|
|
|
from colossalai.engine.ophooks import BaseOpHook
|
2022-01-25 14:20:54 +00:00
|
|
|
from colossalai.registry import OPHOOKS
|
|
|
|
from colossalai.logging import get_dist_logger
|
2022-03-04 01:35:23 +00:00
|
|
|
from colossalai.core import global_context as gpc
|
2022-03-09 03:07:10 +00:00
|
|
|
|
2022-03-09 08:31:25 +00:00
|
|
|
from colossalai.utils.memory_tracer import AsyncMemoryMonitor
|
2022-01-25 14:20:54 +00:00
|
|
|
|
2022-03-09 08:31:25 +00:00
|
|
|
import math
|
2022-01-25 14:20:54 +00:00
|
|
|
|
|
|
|
|
|
|
|
@OPHOOKS.register_module
class MemTracerOpHook(BaseOpHook):
    """Collect GPU memory usage information through an ``AsyncMemoryMonitor``.

    The monitor is (re)started before the forward and backward op of every
    module in training mode and stopped afterwards, as well as at the end of
    each iteration. Once ``warmup`` iterations have run, the sampling interval
    is adapted to the measured average iteration time. After that, the
    collected statistics are written via ``AsyncMemoryMonitor.save`` to
    ``f"{data_prefix}-{rank}.pkl"`` every ``refreshrate`` iterations.

    :param warmup: Number of iterations to run before the sampling interval is
        adapted and before results start being written, defaults to 50
    :type warmup: int
    :param refreshrate: Write the stats data file every ``refreshrate``
        iterations after warmup, defaults to 10; must be positive
    :type refreshrate: int
    :param data_prefix: The prefix of the stats data file, defaults to "memstats"
    :type data_prefix: str
    :raises ValueError: if ``refreshrate`` is not a positive integer
    """

    def __init__(self, warmup: int = 50, refreshrate: int = 10, data_prefix: str = "memstats"):
        super().__init__()
        # `post_iter` computes `valid_iter % refreshrate`; reject bad values now
        # instead of crashing with ZeroDivisionError after the whole warmup phase.
        if refreshrate <= 0:
            raise ValueError(f"refreshrate must be a positive integer, got {refreshrate}")
        self.async_mem_monitor = AsyncMemoryMonitor()
        self._curiter = 0    # number of completed iterations
        self._logger = get_dist_logger()
        self._count = 0    # number of times the stats file has been written
        self._warmup = warmup
        self._refreshrate = refreshrate
        self._data_prefix = data_prefix
        # in distributed environment, tag the output file with the global rank
        if gpc.is_initialized(ParallelMode.GLOBAL):
            self._rank = gpc.get_global_rank()
        else:
            self._rank = 0

    def _isvalid(self, module) -> bool:
        """Return whether ``module`` should be traced (only modules in training mode)."""
        assert isinstance(module, torch.nn.Module)
        return module.training

    def _resample(self):
        """Adapt the monitor's sampling interval to the average iteration time.

        The average iteration time is estimated from the span between the
        first and last timestamps recorded by the monitor, divided by
        ``warmup``. The sampling power passed to ``set_interval`` is
        ``round(-log10(avg_it_time)) + 1``.
        NOTE(review): assuming ``set_interval(p)`` sets an interval of
        ``10 ** -p`` seconds, this samples about ten times per iteration -- confirm.

        When no usable timing data exists (``warmup`` is not positive, fewer
        than two timestamps were recorded, or the measured span is not
        positive), the current sampling interval is kept instead of raising.
        """
        time_stamps = self.async_mem_monitor.time_stamps
        if self.warmup <= 0 or len(time_stamps) < 2:
            self._logger.debug("not enough timing data collected during warmup, keep the current sampling interval")
            return

        # calculate the average iteration time
        total_time = time_stamps[-1] - time_stamps[0]
        self._logger.debug(f"total time for {self.warmup} iterations is {total_time}s")
        if total_time <= 0:
            # log10 is undefined for non-positive values
            self._logger.debug("measured warmup time is not positive, keep the current sampling interval")
            return
        avg_it_time = total_time / self.warmup

        # adjust the sampling power
        power: int = round(-math.log10(avg_it_time)) + 1
        self._logger.debug(f"the power is {power}")
        self.async_mem_monitor.set_interval(power)

    @property
    def refreshrate(self) -> int:
        """Number of post-warmup iterations between two writes of the stats file."""
        return self._refreshrate

    @property
    def warmup(self) -> int:
        """Number of warmup iterations before resampling and file output."""
        return self._warmup

    @property
    def curiter(self) -> int:
        """Number of iterations completed so far."""
        return self._curiter

    @property
    def valid_iter(self) -> int:
        """Number of iterations completed past the warmup (negative during warmup)."""
        return self.curiter - self.warmup

    def pre_fwd_exec(self, module: torch.nn.Module, *args):
        """Stop the running measurement and start a new one before a training-mode forward op."""
        if self._isvalid(module):
            self.async_mem_monitor.finish()
            self.async_mem_monitor.start()

    def post_fwd_exec(self, module: torch.nn.Module, *args):
        """Stop the measurement after a training-mode forward op."""
        if self._isvalid(module):
            self.async_mem_monitor.finish()

    def pre_bwd_exec(self, module: torch.nn.Module, input, output):
        """Stop the running measurement and start a new one before a training-mode backward op."""
        if self._isvalid(module):
            self.async_mem_monitor.finish()
            self.async_mem_monitor.start()

    def post_bwd_exec(self, module: torch.nn.Module, input):
        """Stop the measurement after a training-mode backward op."""
        if self._isvalid(module):
            self.async_mem_monitor.finish()

    def pre_iter(self):
        """No work is needed at the start of an iteration."""
        pass

    def post_iter(self):
        """Finish the iteration: stop measuring, then resample or write results as appropriate.

        During warmup nothing extra happens. At the iteration index equal to
        ``warmup`` the sampling interval is adapted. After that, the stats file
        is rewritten every ``refreshrate`` iterations.
        """
        self.async_mem_monitor.finish()
        if self.curiter < self.warmup:
            # in the warmup stage
            pass
        elif self.curiter == self.warmup:
            # adjust the sampling rate using an adaptive sample rate
            self._resample()
        else:
            # record data to the stats file; here valid_iter >= 1, so only the
            # modulo test is needed to refresh the file every `refreshrate` times
            if self.valid_iter % self.refreshrate == 0:
                # output file info
                self._logger.info(f"dump a memory statistics as pickle to {self._data_prefix}-{self._rank}.pkl")
                self.save_results()
                self._count += 1
                self._logger.debug(f"data file has been refreshed {self._count} times")
        # finish an iteration
        self._curiter += 1

    def save_results(self):
        """Write the monitor's collected statistics to ``f"{data_prefix}-{rank}.pkl"``."""
        datafile = f"{self._data_prefix}-{self._rank}.pkl"
        self.async_mem_monitor.save(datafile)