mirror of https://github.com/hpcaitech/ColossalAI
[profiler] add MemProfiler (#356)
* add memory trainer hook
* fix bug
* add memory trainer hook
* fix import bug
* fix import bug
* add trainer hook
* fix #370 git log bug
* modify `to_tensorboard` function to support better output
* remove useless output
* change the name of `MemProfiler`
* complete memory profiler
* replace error with warning
* finish trainer hook
* modify interface of MemProfiler
* modify `__init__.py` in profiler
* remove unnecessary pass statement
* add usage to doc string
* add usage to trainer hook
* new location to store temp data file
parent fb841dd5c5
commit 73d36618a6
@@ -1,6 +1,7 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-

+from asyncio.log import logger
 from typing import List
 from torch.nn import Module
 from torch.nn.modules.loss import _Loss
@@ -9,9 +10,9 @@ from torch.optim import Optimizer
 from colossalai.logging import get_dist_logger
 from torch import Tensor
 from colossalai.engine.ophooks import register_ophooks_recursively, BaseOpHook
-from typing import Optional
+from typing import Optional, Type
 from colossalai.engine.gradient_handler import BaseGradientHandler
-
+from colossalai.logging import get_dist_logger

 class Engine:
     """Basic engine class for training and evaluation. It runs a specific process method
@@ -64,6 +65,11 @@ class Engine:
         self._ophook_list = ophook_list
         register_ophooks_recursively(self._model, self._ophook_list)

+    @property
+    def ophooks(self):
+        """show current activated ophooks"""
+        return self._ophook_list
+
     @property
     def model(self):
         """Model attached to the engine"""
@@ -79,6 +85,21 @@ class Engine:
         """Criterion attached to the engine"""
         return self._criterion

+    def add_hook(self, ophook: Type[BaseOpHook]) -> None:
+        """add necessary hook"""
+        # whether this hook exist
+        for h in self._ophook_list:
+            if type(h) == type(ophook):
+                logger = get_dist_logger()
+                logger.warning(f"duplicate hooks, at least two instance of {type(ophook)}")
+        self._ophook_list.append(ophook)
+        register_ophooks_recursively(self._model, self._ophook_list)
+
+    def remove_hook(self, ophook: Type[BaseOpHook]) -> None:
+        """remove hook"""
+        logger = get_dist_logger()
+        logger.warning(f"removing hooks is currently not supported")
+
     def zero_grad(self):
         """Set the gradient of parameters to zero
         """
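
The hunk above adds runtime hook management to Engine. A minimal sketch of how the new add_hook API could be used, assuming an `engine` instance already produced by colossalai.initialize (the setup around it is illustrative, not part of this diff):

    from colossalai.engine.ophooks import MemTracerOpHook

    # `engine` is assumed to exist, e.g. returned by colossalai.initialize(...)
    mem_hook = MemTracerOpHook(warmup=50, refreshrate=10)

    # add_hook appends the hook and re-registers all ophooks on the model;
    # adding a second hook of the same type only emits a warning.
    engine.add_hook(mem_hook)

    # remove_hook currently only logs a warning; hooks cannot be detached yet.
    engine.remove_hook(mem_hook)
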
@@ -1,12 +1,15 @@
+import json
+import pickle
+from pathlib import Path
 from colossalai.context.parallel_mode import ParallelMode
 import torch
 from colossalai.engine.ophooks import BaseOpHook
 from colossalai.registry import OPHOOKS
 from colossalai.logging import get_dist_logger
 from colossalai.core import global_context as gpc
+from typing import Union
 from colossalai.utils.memory_tracer import AsyncMemoryMonitor
+import os
 import math


@@ -103,12 +106,14 @@ class MemTracerOpHook(BaseOpHook):
         if self.valid_iter != 0 and self.valid_iter % self.refreshrate == 0:
             # output file info
             self._logger.info(f"dump a memory statistics as pickle to {self._data_prefix}-{self._rank}.pkl")
-            self.save_results()
+            home_dir = Path.home()
+            with open (home_dir.joinpath(f".cache/colossal/mem-{self._rank}.pkl"), "wb") as f:
+                pickle.dump(self.async_mem_monitor.state_dict, f)
             self._count += 1
             self._logger.debug(f"data file has been refreshed {self._count} times")
         # finish a iteration
         self._curiter += 1

-    def save_results(self):
-        datafile = f"{self._data_prefix}-{self._rank}.pkl"
-        self.async_mem_monitor.save(datafile)
+    def save_results(self, data_file: Union[str, Path]):
+        with open(data_file, "w") as f:
+            f.write(json.dumps(self.async_mem_monitor.state_dict))
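
With this change the op hook periodically pickles async_mem_monitor.state_dict to ~/.cache/colossal/mem-<rank>.pkl, while save_results now writes the same dictionary as JSON to a caller-supplied path. A small standard-library sketch of reading either dump back for offline inspection (the rank and the JSON file name are placeholders):

    import json
    import pickle
    from pathlib import Path

    rank = 0  # placeholder: rank of the process whose dump we want to inspect

    # periodic pickle dump written inside the iteration loop
    with open(Path.home().joinpath(f".cache/colossal/mem-{rank}.pkl"), "rb") as f:
        stats = pickle.load(f)

    # JSON dump produced by save_results(data_file); the file name is a placeholder
    with open("mem_stats.json") as f:
        stats_json = json.load(f)

    # both dumps are expected to carry the monitor's 'time_stamps' and 'mem_stats' lists
    print(stats.keys(), stats_json.keys())
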
@@ -0,0 +1,44 @@
+from cgitb import Hook
+from colossalai.registry import HOOKS
+from torch import Tensor
+from colossalai.trainer.hooks import BaseHook
+from colossalai.utils.memory_tracer import AsyncMemoryMonitor
+from ._metric_hook import LearningRateMetric, MetricHook
+
+@HOOKS.register_module
+class MemTraceHook(BaseHook):
+    """Save memory stats and pass it to states
+    This hook is used to record memory usage info, and pass to trainer.states
+    You can use it as other trainer hook and fetch data from trainer.states['metrics][mode]
+    """
+    def __init__(
+        self,
+        priority: int = 0,
+    ) -> None:
+        super().__init__(priority=priority)
+        self._memory_monitor = AsyncMemoryMonitor()
+
+    def after_hook_is_attached(self, trainer):
+        # Initialize the data
+        trainer.states['metrics']['train'] = self._memory_monitor.state_dict
+        trainer.states['metrics']['test'] = self._memory_monitor.state_dict
+
+    def before_train_iter(self, trainer):
+        self._memory_monitor.start()
+        return super().before_train_iter(trainer)
+
+    def after_train_iter(self, trainer, output: Tensor, label: Tensor, loss: Tensor):
+        self._memory_monitor.finish()
+        trainer.states['metrics']['train'] = self._memory_monitor.state_dict
+        trainer.states['metrics']['test'] = self._memory_monitor.state_dict
+        return super().after_train_iter(trainer, output, label, loss)
+
+    def before_test_iter(self, trainer):
+        self._memory_monitor.start()
+        return super().before_test(trainer)
+
+    def after_test_iter(self, trainer, output: Tensor, label: Tensor, loss: Tensor):
+        self._memory_monitor.finish()
+        trainer.states['metrics']['train'] = self._memory_monitor.state_dict
+        trainer.states['metrics']['test'] = self._memory_monitor.state_dict
+        return super().after_test_iter(trainer, output, label, loss)
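
The new MemTraceHook samples memory around every train and test iteration and stores the monitor's state_dict under trainer.states['metrics']. A hedged sketch of attaching it to a trainer; the Trainer construction, the fit arguments, and the import path of the hook are assumptions based on the usual ColossalAI trainer workflow, not part of this diff:

    from colossalai.trainer import Trainer
    from colossalai.trainer.hooks import MemTraceHook  # export path assumed

    # assumed: `engine` and `train_dataloader` come from colossalai.initialize(...)
    trainer = Trainer(engine=engine)

    trainer.fit(
        train_dataloader=train_dataloader,
        epochs=1,
        hooks=[MemTraceHook(priority=0)],
    )

    # after training, the recorded stats can be fetched per mode
    print(trainer.states['metrics']['train'])
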
@@ -86,6 +86,7 @@ class AsyncMemoryMonitor:
             sleep(self.interval)
         return max_usage

+    @property
     def state_dict(self):
         return {
             "time_stamps": self.time_stamps,
@@ -94,7 +95,6 @@ class AsyncMemoryMonitor:

     def save(self, filename):
         with open(filename, "wb") as f:
-            print(self.state_dict())
             pickle.dump(self.state_dict(), f)

     def clear(self):
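
After this hunk, state_dict is a read-only property, so callers read it as an attribute rather than calling it. A minimal usage sketch of the monitor on its own, assuming a CUDA device is available (the measured workload is a placeholder):

    from colossalai.utils.memory_tracer import AsyncMemoryMonitor

    monitor = AsyncMemoryMonitor()

    monitor.start()
    # ... run the code to be measured here (placeholder) ...
    monitor.finish()

    # property access after this change: no parentheses
    stats = monitor.state_dict
    print(stats["time_stamps"])  # sampling timestamps
    print(stats["mem_stats"])    # sampled memory usage values
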
@@ -1,3 +1,6 @@
 from .comm_profiler import CommProfiler
 from .pcie_profiler import PcieProfiler
-from .prof_utils import ProfilerContext
+from .prof_utils import ProfilerContext, BaseProfiler
+from .mem_profiler import MemProfiler
+
+__all__ = ['BaseProfiler', 'CommProfiler', 'PcieProfiler', 'MemProfiler', 'ProfilerContext']
@@ -0,0 +1,50 @@
+from pathlib import Path
+from typing import Union
+from colossalai.engine import Engine
+from torch.utils.tensorboard import SummaryWriter
+from colossalai.engine.ophooks import MemTracerOpHook
+from colossalai.utils.profiler import BaseProfiler
+
+
+class MemProfiler(BaseProfiler):
+    """Wraper of MemOpHook, used to show GPU memory usage through each iteration
+
+    To use this profiler, you need to pass an `engine` instance. And the usage is same like
+    CommProfiler.
+
+    mm_prof = MemProfiler(engine)
+    with ProfilerContext([mm_prof]) as prof:
+        writer = SummaryWriter("mem")
+        engine.train()
+        ...
+    prof.to_file("./log")
+    prof.to_tensorboard(writer)
+
+    """
+
+    def __init__(self, engine: Engine, warmup: int = 50, refreshrate: int = 10) -> None:
+        super().__init__(profiler_name="MemoryProfiler", priority=0)
+        self._mem_tracer = MemTracerOpHook(warmup=warmup, refreshrate=refreshrate)
+        self._engine = engine
+
+    def enable(self) -> None:
+        self._engine.add_hook(self._mem_tracer)
+
+    def disable(self) -> None:
+        self._engine.remove_hook(self._mem_tracer)
+
+    def to_tensorboard(self, writer: SummaryWriter) -> None:
+        stats = self._mem_tracer.async_mem_monitor.state_dict['mem_stats']
+        for info, i in enumerate(stats):
+            writer.add_scalar(
+                "memory_usage/GPU",
+                info,
+                i
+            )
+
+    def to_file(self, data_file: Path) -> None:
+        self._mem_tracer.save_results(data_file)
+
+    def show(self) -> None:
+        stats = self._mem_tracer.async_mem_monitor.state_dict['mem_stats']
+        print(stats)
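
As a follow-up to the docstring example in MemProfiler, here is a hedged sketch of exporting the collected mem_stats list to TensorBoard directly with SummaryWriter.add_scalar(tag, value, step); the stats values and log directory are placeholders, and the loop simply illustrates the conventional (value, step) ordering:

    from torch.utils.tensorboard import SummaryWriter

    writer = SummaryWriter("mem")  # log directory is an arbitrary choice

    # placeholder values standing in for async_mem_monitor.state_dict['mem_stats']
    mem_stats = [1024, 2048, 1536]

    for step, usage in enumerate(mem_stats):
        # add_scalar expects (tag, scalar_value, global_step)
        writer.add_scalar("memory_usage/GPU", usage, step)

    writer.close()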