2022-03-29 04:48:34 +00:00
|
|
|
from pathlib import Path
|
|
|
|
from typing import Union
|
|
|
|
from colossalai.engine import Engine
|
|
|
|
from torch.utils.tensorboard import SummaryWriter
|
2022-07-14 05:44:26 +00:00
|
|
|
from colossalai.gemini.ophooks import MemTracerOpHook
|
2022-04-24 09:03:59 +00:00
|
|
|
from colossalai.utils.profiler.legacy.prof_utils import BaseProfiler
|
2022-03-29 04:48:34 +00:00
|
|
|
|
|
|
|
|
|
|
|
class MemProfiler(BaseProfiler):
    """Wrapper of MemTracerOpHook, used to show GPU memory usage through each iteration.

    To use this profiler, you need to pass an `engine` instance. The usage is the same as
    CommProfiler.

    Usage::

        mm_prof = MemProfiler(engine)
        with ProfilerContext([mm_prof]) as prof:
            writer = SummaryWriter("mem")
            engine.train()
            ...
            prof.to_file("./log")
            prof.to_tensorboard(writer)

    """

    def __init__(self, engine: Engine, warmup: int = 50, refreshrate: int = 10) -> None:
        """Create the profiler and its underlying memory-tracing op hook.

        Args:
            engine: Engine the memory-tracing hook is added to (``enable``) and
                removed from (``disable``).
            warmup: Forwarded unchanged to ``MemTracerOpHook``.
            refreshrate: Forwarded unchanged to ``MemTracerOpHook``.
        """
        super().__init__(profiler_name="MemoryProfiler", priority=0)
        self._mem_tracer = MemTracerOpHook(warmup=warmup, refreshrate=refreshrate)
        self._engine = engine

    def enable(self) -> None:
        """Register the memory-tracing op hook on the engine."""
        self._engine.add_hook(self._mem_tracer)

    def disable(self) -> None:
        """Remove the memory-tracing op hook from the engine."""
        self._engine.remove_hook(self._mem_tracer)

    def to_tensorboard(self, writer: SummaryWriter) -> None:
        """Log every recorded memory sample to TensorBoard.

        Each entry of the tracer's ``mem_stats`` is written as a scalar under the
        tag ``"memory_usage/GPU"``, using its position in the list as the global step.

        Args:
            writer: TensorBoard writer that receives the scalars.
        """
        stats = self._mem_tracer.async_mem_monitor.state_dict['mem_stats']
        # enumerate() yields (index, value) and add_scalar takes
        # (tag, scalar_value, global_step): the reading is the value, the index the step.
        for step, mem_usage in enumerate(stats):
            writer.add_scalar("memory_usage/GPU", mem_usage, step)

    def to_file(self, data_file: Path) -> None:
        """Save the tracer's results to ``data_file`` via ``MemTracerOpHook.save_results``."""
        self._mem_tracer.save_results(data_file)

    def show(self) -> None:
        """Print the tracer's recorded ``mem_stats`` to stdout."""
        stats = self._mem_tracer.async_mem_monitor.state_dict['mem_stats']
        print(stats)
|