import json
import math
import pickle
from pathlib import Path
from typing import Union

import torch

from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.engine.ophooks import BaseOpHook
from colossalai.logging import get_dist_logger
from colossalai.registry import OPHOOKS


@OPHOOKS.register_module
class MemTracerOpHook(BaseOpHook):
    """
    Collect GPU memory usage information

    The hook drives an ``AsyncMemoryMonitor`` around the forward/backward
    execution of modules that are in training mode. After ``warmup``
    iterations the monitor's sampling interval is re-tuned from the measured
    average iteration time; afterwards, every ``refreshrate`` iterations the
    monitor's ``state_dict`` is pickled to
    ``~/.cache/colossal/mem-<rank>.pkl``.

    Args:
        warmup (int): This parameter indicates how many iterations to truncate before profiling, defaults to 50.
        refreshrate (int): This parameter decides the frequency of write file, defaults to 10.
        data_prefix (string): The prefix of the stats data file, defaults to "memstats".
            NOTE(review): the value is stored but the dump location is fixed
            (see above); it does not currently influence the file name.
    """

    def __init__(self, warmup: int = 50, refreshrate: int = 10, data_prefix: str = "memstats"):
        # Imported at construction time rather than at module import,
        # presumably to avoid an import cycle -- TODO confirm.
        from colossalai.gemini.memory_tracer import AsyncMemoryMonitor

        super().__init__()
        self.async_mem_monitor = AsyncMemoryMonitor()
        self._curiter = 0    # number of iterations finished so far
        self._logger = get_dist_logger()
        self._count = 0    # number of times the dump file has been refreshed
        self._warmup = warmup
        self._refreshrate = refreshrate
        self._data_prefix = data_prefix
        # in distributed environment
        if gpc.is_initialized(ParallelMode.GLOBAL):
            self._rank = gpc.get_global_rank()
        else:
            self._rank = 0

    def _isvalid(self, module) -> bool:
        """Return True when ``module`` is in training mode and should be traced."""
        assert isinstance(module, torch.nn.Module)
        return module.training

    def _resample(self) -> None:
        """Re-tune the monitor's sampling interval from the warmup timings.

        The average iteration time is estimated as the span between the first
        and last recorded time stamp divided by ``warmup``. Resampling is
        skipped (the monitor keeps its current interval) when that estimate
        cannot be computed: ``warmup`` is not positive, fewer than two time
        stamps were recorded, or the measured span is not positive.
        """
        stamps = self.async_mem_monitor.time_stamps
        if self.warmup <= 0 or len(stamps) < 2:
            self._logger.debug("skip resampling: not enough warmup data to estimate the iteration time")
            return

        # calculate the average iteration time
        total_time = stamps[-1] - stamps[0]
        avg_it_time = total_time / self.warmup
        self._logger.debug(f"total time for {self.warmup} iterations is {total_time}s")
        if avg_it_time <= 0:
            # math.log is undefined for non-positive values
            self._logger.debug("skip resampling: measured iteration time is not positive")
            return

        # adjust the sampling power
        power: int = round(-math.log(avg_it_time, 10)) + 1
        self._logger.debug(f"the power is {power}")
        self.async_mem_monitor.set_interval(power)

    @property
    def refreshrate(self) -> int:
        """Number of post-warmup iterations between two dumps."""
        return self._refreshrate

    @property
    def warmup(self) -> int:
        """Number of iterations skipped before data is recorded."""
        return self._warmup

    @property
    def curiter(self) -> int:
        """Number of iterations finished so far."""
        return self._curiter

    @property
    def valid_iter(self) -> int:
        """Number of iterations finished since the end of the warmup stage."""
        return self.curiter - self.warmup

    def pre_fwd_exec(self, module: torch.nn.Module, *args):
        """Before a forward op: call ``finish()`` then ``start()`` on the monitor."""
        if self._isvalid(module):
            self.async_mem_monitor.finish()
            self.async_mem_monitor.start()

    def post_fwd_exec(self, module: torch.nn.Module, *args):
        """After a forward op: call ``finish()`` on the monitor."""
        if self._isvalid(module):
            self.async_mem_monitor.finish()

    def pre_bwd_exec(self, module: torch.nn.Module, input, output):
        """Before a backward op: call ``finish()`` then ``start()`` on the monitor."""
        if self._isvalid(module):
            self.async_mem_monitor.finish()
            self.async_mem_monitor.start()

    def post_bwd_exec(self, module: torch.nn.Module, input):
        """After a backward op: call ``finish()`` on the monitor."""
        if self._isvalid(module):
            self.async_mem_monitor.finish()

    def pre_iter(self):
        """Nothing to do at the start of an iteration."""
        pass

    def post_iter(self):
        """Finish the iteration: resample after warmup, then dump periodically."""
        self.async_mem_monitor.finish()
        # in the warmup stage
        if self.curiter < self.warmup:
            pass
        # adjust the sampling rate
        elif self.curiter == self.warmup:
            # use adaptive sample rate
            self._resample()
        # record data to log file
        else:
            # every `refreshrate` times, refresh the file
            if self.valid_iter != 0 and self.valid_iter % self.refreshrate == 0:
                dump_path = Path.home().joinpath(f".cache/colossal/mem-{self._rank}.pkl")
                # the cache directory may not exist yet on a fresh machine
                dump_path.parent.mkdir(parents=True, exist_ok=True)
                # output file info (report the path that is really written)
                self._logger.info(f"dump a memory statistics as pickle to {dump_path}")
                with open(dump_path, "wb") as f:
                    pickle.dump(self.async_mem_monitor.state_dict, f)
                self._count += 1
                self._logger.debug(f"data file has been refreshed {self._count} times")
        # finish a iteration
        self._curiter += 1

    def save_results(self, data_file: Union[str, Path]):
        """Write the monitor's ``state_dict`` to ``data_file`` as JSON text.

        Args:
            data_file (Union[str, Path]): Destination file; it is overwritten.
        """
        with open(data_file, "w") as f:
            f.write(json.dumps(self.async_mem_monitor.state_dict))