From 73d36618a6263b3e0ff3178a4e7dfaaec60655b0 Mon Sep 17 00:00:00 2001
From: Jie Zhu
Date: Tue, 29 Mar 2022 12:48:34 +0800
Subject: [PATCH] [profiler] add MemProfiler (#356)

* add memory trainer hook

* fix bug

* add memory trainer hook

* fix import bug

* fix import bug

* add trainer hook

* fix #370 git log bug

* modify `to_tensorboard` function to support better output

* remove useless output

* change the name of `MemProfiler`

* complete memory profiler

* replace error with warning

* finish trainer hook

* modify interface of MemProfiler

* modify `__init__.py` in profiler

* remove unnecessary pass statement

* add usage to doc string

* add usage to trainer hook

* new location to store temp data file
---
 README-zh-Hans.md                            |  2 +-
 README.md                                    |  2 +-
 colossalai/engine/_base_engine.py            | 27 ++++++++--
 .../engine/ophooks/_memtracer_ophook.py      | 17 ++++---
 colossalai/trainer/hooks/_mem_tracer_hook.py | 44 ++++++++++++++++
 .../utils/memory_tracer/async_memtracer.py   |  2 +-
 colossalai/utils/profiler/__init__.py        |  5 +-
 colossalai/utils/profiler/mem_profiler.py    | 50 +++++++++++++++++++
 8 files changed, 136 insertions(+), 13 deletions(-)
 create mode 100644 colossalai/trainer/hooks/_mem_tracer_hook.py
 create mode 100644 colossalai/utils/profiler/mem_profiler.py

diff --git a/README-zh-Hans.md b/README-zh-Hans.md
index a3a5a81ac..7ccdeaa19 100644
--- a/README-zh-Hans.md
+++ b/README-zh-Hans.md
@@ -267,4 +267,4 @@ class MLP_2D(nn.Module):
 }
 ```
 
-<p align="right">(<a href="#top">返回顶端</a>)</p>
+<p align="right">(<a href="#top">返回顶端</a>)</p>
\ No newline at end of file
diff --git a/README.md b/README.md
index 8bab615d0..400116d1d 100644
--- a/README.md
+++ b/README.md
@@ -270,4 +270,4 @@ class MLP_2D(nn.Module):
 }
 ```
 
-<p align="right">(<a href="#top">back to top</a>)</p>
+<p align="right">(<a href="#top">back to top</a>)</p>
\ No newline at end of file
diff --git a/colossalai/engine/_base_engine.py b/colossalai/engine/_base_engine.py
index 699268cec..f8b2de86e 100644
--- a/colossalai/engine/_base_engine.py
+++ b/colossalai/engine/_base_engine.py
@@ -1,6 +1,7 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-
 
+from asyncio.log import logger
 from typing import List
 from torch.nn import Module
 from torch.nn.modules.loss import _Loss
@@ -9,9 +10,9 @@
 from torch.optim import Optimizer
 from colossalai.logging import get_dist_logger
 from torch import Tensor
 from colossalai.engine.ophooks import register_ophooks_recursively, BaseOpHook
-from typing import Optional
+from typing import Optional, Type
 from colossalai.engine.gradient_handler import BaseGradientHandler
-
+from colossalai.logging import get_dist_logger
 class Engine:
     """Basic engine class for training and evaluation. It runs a specific process method
@@ -64,6 +65,11 @@ class Engine:
         self._ophook_list = ophook_list
         register_ophooks_recursively(self._model, self._ophook_list)
 
+    @property
+    def ophooks(self):
+        """Show the currently activated ophooks."""
+        return self._ophook_list
+
     @property
     def model(self):
         """Model attached to the engine"""
@@ -79,6 +85,21 @@ class Engine:
         """Criterion attached to the engine"""
         return self._criterion
 
+    def add_hook(self, ophook: Type[BaseOpHook]) -> None:
+        """Add an ophook to the engine."""
+        # check whether this hook already exists
+        for h in self._ophook_list:
+            if type(h) == type(ophook):
+                logger = get_dist_logger()
+                logger.warning(f"duplicate hooks, at least two instances of {type(ophook)}")
+        self._ophook_list.append(ophook)
+        register_ophooks_recursively(self._model, self._ophook_list)
+
+    def remove_hook(self, ophook: Type[BaseOpHook]) -> None:
+        """Remove an ophook from the engine."""
+        logger = get_dist_logger()
+        logger.warning("removing hooks is currently not supported")
+
     def zero_grad(self):
         """Set the gradient of parameters to zero
         """
@@ -150,4 +171,4 @@ class Engine:
         """Sets the model to evaluation mode.
""" self.training = False - self._model.eval() + self._model.eval() \ No newline at end of file diff --git a/colossalai/engine/ophooks/_memtracer_ophook.py b/colossalai/engine/ophooks/_memtracer_ophook.py index c7b20c340..530f501a1 100644 --- a/colossalai/engine/ophooks/_memtracer_ophook.py +++ b/colossalai/engine/ophooks/_memtracer_ophook.py @@ -1,12 +1,15 @@ +import json +import pickle +from pathlib import Path from colossalai.context.parallel_mode import ParallelMode import torch from colossalai.engine.ophooks import BaseOpHook from colossalai.registry import OPHOOKS from colossalai.logging import get_dist_logger from colossalai.core import global_context as gpc - +from typing import Union from colossalai.utils.memory_tracer import AsyncMemoryMonitor - +import os import math @@ -103,12 +106,14 @@ class MemTracerOpHook(BaseOpHook): if self.valid_iter != 0 and self.valid_iter % self.refreshrate == 0: # output file info self._logger.info(f"dump a memory statistics as pickle to {self._data_prefix}-{self._rank}.pkl") - self.save_results() + home_dir = Path.home() + with open (home_dir.joinpath(f".cache/colossal/mem-{self._rank}.pkl"), "wb") as f: + pickle.dump(self.async_mem_monitor.state_dict, f) self._count += 1 self._logger.debug(f"data file has been refreshed {self._count} times") # finish a iteration self._curiter += 1 - def save_results(self): - datafile = f"{self._data_prefix}-{self._rank}.pkl" - self.async_mem_monitor.save(datafile) \ No newline at end of file + def save_results(self, data_file: Union[str, Path]): + with open(data_file, "w") as f: + f.write(json.dumps(self.async_mem_monitor.state_dict)) \ No newline at end of file diff --git a/colossalai/trainer/hooks/_mem_tracer_hook.py b/colossalai/trainer/hooks/_mem_tracer_hook.py new file mode 100644 index 000000000..4f86e156e --- /dev/null +++ b/colossalai/trainer/hooks/_mem_tracer_hook.py @@ -0,0 +1,44 @@ +from cgitb import Hook +from colossalai.registry import HOOKS +from torch import Tensor +from colossalai.trainer.hooks import BaseHook +from colossalai.utils.memory_tracer import AsyncMemoryMonitor +from ._metric_hook import LearningRateMetric, MetricHook + +@HOOKS.register_module +class MemTraceHook(BaseHook): + """Save memory stats and pass it to states + This hook is used to record memory usage info, and pass to trainer.states + You can use it as other trainer hook and fetch data from trainer.states['metrics][mode] + """ + def __init__( + self, + priority: int = 0, + ) -> None: + super().__init__(priority=priority) + self._memory_monitor = AsyncMemoryMonitor() + + def after_hook_is_attached(self, trainer): + # Initialize the data + trainer.states['metrics']['train'] = self._memory_monitor.state_dict + trainer.states['metrics']['test'] = self._memory_monitor.state_dict + + def before_train_iter(self, trainer): + self._memory_monitor.start() + return super().before_train_iter(trainer) + + def after_train_iter(self, trainer, output: Tensor, label: Tensor, loss: Tensor): + self._memory_monitor.finish() + trainer.states['metrics']['train'] = self._memory_monitor.state_dict + trainer.states['metrics']['test'] = self._memory_monitor.state_dict + return super().after_train_iter(trainer, output, label, loss) + + def before_test_iter(self, trainer): + self._memory_monitor.start() + return super().before_test(trainer) + + def after_test_iter(self, trainer, output: Tensor, label: Tensor, loss: Tensor): + self._memory_monitor.finish() + trainer.states['metrics']['train'] = self._memory_monitor.state_dict + 
+        trainer.states['metrics']['test'] = self._memory_monitor.state_dict
+        return super().after_test_iter(trainer, output, label, loss)
\ No newline at end of file
diff --git a/colossalai/utils/memory_tracer/async_memtracer.py b/colossalai/utils/memory_tracer/async_memtracer.py
index 74dd278c8..69295a715 100644
--- a/colossalai/utils/memory_tracer/async_memtracer.py
+++ b/colossalai/utils/memory_tracer/async_memtracer.py
@@ -86,6 +86,7 @@
             sleep(self.interval)
         return max_usage
 
+    @property
     def state_dict(self):
         return {
             "time_stamps": self.time_stamps,
@@ -94,7 +95,6 @@
 
     def save(self, filename):
         with open(filename, "wb") as f:
-            print(self.state_dict())
-            pickle.dump(self.state_dict(), f)
+            pickle.dump(self.state_dict, f)
 
     def clear(self):
diff --git a/colossalai/utils/profiler/__init__.py b/colossalai/utils/profiler/__init__.py
index 0223e732f..810a43394 100644
--- a/colossalai/utils/profiler/__init__.py
+++ b/colossalai/utils/profiler/__init__.py
@@ -1,3 +1,6 @@
 from .comm_profiler import CommProfiler
 from .pcie_profiler import PcieProfiler
-from .prof_utils import ProfilerContext
+from .prof_utils import ProfilerContext, BaseProfiler
+from .mem_profiler import MemProfiler
+
+__all__ = ['BaseProfiler', 'CommProfiler', 'PcieProfiler', 'MemProfiler', 'ProfilerContext']
\ No newline at end of file
diff --git a/colossalai/utils/profiler/mem_profiler.py b/colossalai/utils/profiler/mem_profiler.py
new file mode 100644
index 000000000..b60e714c4
--- /dev/null
+++ b/colossalai/utils/profiler/mem_profiler.py
@@ -0,0 +1,50 @@
+from pathlib import Path
+from typing import Union
+from colossalai.engine import Engine
+from torch.utils.tensorboard import SummaryWriter
+from colossalai.engine.ophooks import MemTracerOpHook
+from colossalai.utils.profiler import BaseProfiler
+
+
+class MemProfiler(BaseProfiler):
+    """Wrapper of MemTracerOpHook, used to show GPU memory usage through each iteration
+
+    To use this profiler, you need to pass an `engine` instance, and the usage is the same
+    as CommProfiler:
+
+        mm_prof = MemProfiler(engine)
+        with ProfilerContext([mm_prof]) as prof:
+            writer = SummaryWriter("mem")
+            engine.train()
+            ...
+            prof.to_file("./log")
+            prof.to_tensorboard(writer)
+
+    """
+
+    def __init__(self, engine: Engine, warmup: int = 50, refreshrate: int = 10) -> None:
+        super().__init__(profiler_name="MemoryProfiler", priority=0)
+        self._mem_tracer = MemTracerOpHook(warmup=warmup, refreshrate=refreshrate)
+        self._engine = engine
+
+    def enable(self) -> None:
+        self._engine.add_hook(self._mem_tracer)
+
+    def disable(self) -> None:
+        self._engine.remove_hook(self._mem_tracer)
+
+    def to_tensorboard(self, writer: SummaryWriter) -> None:
+        stats = self._mem_tracer.async_mem_monitor.state_dict['mem_stats']
+        for i, info in enumerate(stats):
+            writer.add_scalar(
+                "memory_usage/GPU",
+                info,
+                i
+            )
+
+    def to_file(self, data_file: Path) -> None:
+        self._mem_tracer.save_results(data_file)
+
+    def show(self) -> None:
+        stats = self._mem_tracer.async_mem_monitor.state_dict['mem_stats']
+        print(stats)
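
Usage note: below is a minimal sketch of how the new MemProfiler could be wired into a ColossalAI training loop, following the pattern in the MemProfiler docstring above. The helper name `profile_one_epoch`, the `engine` and `train_dataloader` arguments, and the per-step engine calls are assumptions for illustration, not code from this patch.

```python
# A minimal sketch (not part of this patch): it assumes `engine` was produced by
# colossalai.initialize() and that `train_dataloader` yields (data, label) batches.
from torch.utils.tensorboard import SummaryWriter

from colossalai.engine import Engine
from colossalai.utils.profiler import MemProfiler, ProfilerContext


def profile_one_epoch(engine: Engine, train_dataloader) -> None:
    # MemProfiler.enable() attaches MemTracerOpHook to the engine via engine.add_hook()
    mm_prof = MemProfiler(engine)
    with ProfilerContext([mm_prof]) as prof:
        writer = SummaryWriter("mem")
        engine.train()
        for data, label in train_dataloader:
            engine.zero_grad()
            output = engine(data)
            loss = engine.criterion(output, label)
            engine.backward(loss)
            engine.step()
        # dump the collected memory stats, as in the MemProfiler docstring
        prof.to_file("./log")
        prof.to_tensorboard(writer)  # scalars appear under "memory_usage/GPU"
```

For the trainer path, the MemTraceHook added by this patch plays the analogous role and writes its stats into trainer.states['metrics'].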