You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
ColossalAI/colossalai/engine/ophooks/_memtracer_ophook.py

118 lines
4.1 KiB

import json
import pickle
from pathlib import Path
from colossalai.context.parallel_mode import ParallelMode
import torch
from colossalai.engine.ophooks import BaseOpHook
from colossalai.registry import OPHOOKS
from colossalai.logging import get_dist_logger
from colossalai.core import global_context as gpc
from typing import Union
import math
@OPHOOKS.register_module
class MemTracerOpHook(BaseOpHook):
    """
    Collect GPU memory usage information

    The hook drives an ``AsyncMemoryMonitor`` around the forward/backward
    execution of modules that are in training mode.  After ``warmup``
    iterations the monitor interval is re-tuned from the measured average
    iteration time; afterwards the monitor's ``state_dict`` is pickled to
    ``~/.cache/colossal/mem-<rank>.pkl`` every ``refreshrate`` iterations.

    Args:
        warmup (int): This parameter indicates how many iterations to truncate before profiling, defaults to 50.
        refreshrate (int): This parameter decides the frequency of write file, defaults to 10.
        data_prefix (string): The prefix of the stats data file, defaults to "memstats".
            NOTE(review): the value is stored on the instance, but the periodic
            dump in ``post_iter`` uses the fixed name ``mem-<rank>.pkl``.
    """

    def __init__(self, warmup: int = 50, refreshrate: int = 10, data_prefix: str = "memstats"):
        # Imported lazily inside the constructor, as in the original code
        # (presumably to avoid an import cycle - TODO confirm).
        from colossalai.gemini.memory_tracer import AsyncMemoryMonitor

        super().__init__()
        self.async_mem_monitor = AsyncMemoryMonitor()
        self._curiter = 0
        self._logger = get_dist_logger()
        self._count = 0
        self._warmup = warmup
        self._refreshrate = refreshrate
        self._data_prefix = data_prefix
        # in distributed environment
        if gpc.is_initialized(ParallelMode.GLOBAL):
            self._rank = gpc.get_global_rank()
        else:
            self._rank = 0

    def _isvalid(self, module) -> bool:
        """Return True when ``module`` is in training mode (only those are traced)."""
        assert isinstance(module, torch.nn.Module)
        return module.training

    def _resample(self):
        """Re-tune the monitor interval from the average warmup iteration time.

        The adjustment is skipped (the monitor keeps its current interval)
        when no meaningful average can be computed: ``warmup`` is not
        positive, no timestamps were recorded, or the measured time span is
        not positive.  Previously these cases raised ``ZeroDivisionError``,
        ``IndexError`` or ``ValueError`` in the middle of training.
        """
        time_stamps = self.async_mem_monitor.time_stamps
        if self.warmup <= 0 or len(time_stamps) == 0:
            self._logger.debug("skip adjusting the sampling power: no warmup timing data available")
            return

        # calculate the average iteration time
        total_time = time_stamps[-1] - time_stamps[0]
        avg_it_time = total_time / self.warmup
        self._logger.debug(f"total time for {self.warmup} iterations is {total_time}s")
        if avg_it_time <= 0:
            # log10 is undefined for non-positive values
            self._logger.debug("skip adjusting the sampling power: measured iteration time is not positive")
            return

        # adjust the sampling power
        power: int = round(-math.log10(avg_it_time)) + 1
        self._logger.debug(f"the power is {power}")
        self.async_mem_monitor.set_interval(power)

    @property
    def refreshrate(self) -> int:
        """Number of valid iterations between two dumps of the statistics file."""
        return self._refreshrate

    @property
    def warmup(self) -> int:
        """Number of iterations skipped before statistics are dumped."""
        return self._warmup

    @property
    def curiter(self) -> int:
        """Number of iterations finished so far (incremented in ``post_iter``)."""
        return self._curiter

    @property
    def valid_iter(self) -> int:
        """Number of iterations finished since the end of the warmup stage."""
        return self.curiter - self.warmup

    def pre_fwd_exec(self, module: torch.nn.Module, *args):
        """Before a forward pass: call ``finish()`` then ``start()`` on the monitor."""
        if self._isvalid(module):
            self.async_mem_monitor.finish()
            self.async_mem_monitor.start()

    def post_fwd_exec(self, module: torch.nn.Module, *args):
        """After a forward pass: call ``finish()`` on the monitor."""
        if self._isvalid(module):
            self.async_mem_monitor.finish()

    def pre_bwd_exec(self, module: torch.nn.Module, input, output):
        """Before a backward pass: call ``finish()`` then ``start()`` on the monitor."""
        if self._isvalid(module):
            self.async_mem_monitor.finish()
            self.async_mem_monitor.start()

    def post_bwd_exec(self, module: torch.nn.Module, input):
        """After a backward pass: call ``finish()`` on the monitor."""
        if self._isvalid(module):
            self.async_mem_monitor.finish()

    def pre_iter(self):
        """Nothing to do before an iteration."""
        pass

    def post_iter(self):
        """Finish the iteration: resample after warmup, then periodically dump statistics."""
        self.async_mem_monitor.finish()
        # in the warmup stage
        if self.curiter < self.warmup:
            pass
        # adjust the sampling rate
        elif self.curiter == self.warmup:
            # use adaptive sample rate
            self._resample()
        # record data to log file
        else:
            # every `refreshrate` times, refresh the file
            if self.valid_iter != 0 and self.valid_iter % self.refreshrate == 0:
                dump_path = Path.home().joinpath(f".cache/colossal/mem-{self._rank}.pkl")
                # output file info (report the path that is really written)
                self._logger.info(f"dump a memory statistics as pickle to {dump_path}")
                # the cache directory does not exist on a fresh machine
                dump_path.parent.mkdir(parents=True, exist_ok=True)
                with open(dump_path, "wb") as f:
                    pickle.dump(self.async_mem_monitor.state_dict, f)
                self._count += 1
                self._logger.debug(f"data file has been refreshed {self._count} times")
        # finish a iteration
        self._curiter += 1

    def save_results(self, data_file: Union[str, Path]):
        """Write the monitor's ``state_dict`` to ``data_file`` as JSON text.

        Args:
            data_file (Union[str, Path]): destination file; it is overwritten.
        """
        with open(data_file, "w") as f:
            f.write(json.dumps(self.async_mem_monitor.state_dict))