From 340e59f96818355c860e6feb54d32df8346f53ca Mon Sep 17 00:00:00 2001
From: HELSON
Date: Wed, 13 Apr 2022 10:50:54 +0800
Subject: [PATCH] [utils] add synchronized cuda memory monitor (#740)

---
 colossalai/trainer/hooks/_mem_tracer_hook.py  |   8 +-
 colossalai/utils/memory_tracer/__init__.py    |   4 +-
 .../{async_memtracer.py => memory_monitor.py} | 245 ++++++++++--------
 .../utils/memory_tracer/memstats_collector.py |   2 +-
 4 files changed, 149 insertions(+), 110 deletions(-)
 rename colossalai/utils/memory_tracer/{async_memtracer.py => memory_monitor.py} (67%)

diff --git a/colossalai/trainer/hooks/_mem_tracer_hook.py b/colossalai/trainer/hooks/_mem_tracer_hook.py
index 4f86e156e..9ce3239dd 100644
--- a/colossalai/trainer/hooks/_mem_tracer_hook.py
+++ b/colossalai/trainer/hooks/_mem_tracer_hook.py
@@ -1,9 +1,8 @@
-from cgitb import Hook
 from colossalai.registry import HOOKS
 from torch import Tensor
 from colossalai.trainer.hooks import BaseHook
 from colossalai.utils.memory_tracer import AsyncMemoryMonitor
-from ._metric_hook import LearningRateMetric, MetricHook
+
 
 @HOOKS.register_module
 class MemTraceHook(BaseHook):
@@ -11,6 +10,7 @@ class MemTraceHook(BaseHook):
     This hook is used to record memory usage info, and pass to trainer.states
     You can use it as other trainer hook and fetch data from trainer.states['metrics'][mode]
     """
+
     def __init__(
         self,
         priority: int = 0,
@@ -36,9 +36,9 @@ class MemTraceHook(BaseHook):
     def before_test_iter(self, trainer):
         self._memory_monitor.start()
         return super().before_test(trainer)
-    
+
     def after_test_iter(self, trainer, output: Tensor, label: Tensor, loss: Tensor):
         self._memory_monitor.finish()
         trainer.states['metrics']['train'] = self._memory_monitor.state_dict
         trainer.states['metrics']['test'] = self._memory_monitor.state_dict
-        return super().after_test_iter(trainer, output, label, loss)
\ No newline at end of file
+        return super().after_test_iter(trainer, output, label, loss)
diff --git a/colossalai/utils/memory_tracer/__init__.py b/colossalai/utils/memory_tracer/__init__.py
index 3f4cd66b8..f57e40e56 100644
--- a/colossalai/utils/memory_tracer/__init__.py
+++ b/colossalai/utils/memory_tracer/__init__.py
@@ -1,4 +1,4 @@
-from .async_memtracer import AsyncMemoryMonitor
+from .memory_monitor import AsyncMemoryMonitor, SyncCudaMemoryMonitor
 from .memstats_collector import MemStatsCollector
 
-__all__ = ['AsyncMemoryMonitor', 'MemStatsCollector']
+__all__ = ['AsyncMemoryMonitor', 'SyncCudaMemoryMonitor', 'MemStatsCollector']
diff --git a/colossalai/utils/memory_tracer/async_memtracer.py b/colossalai/utils/memory_tracer/memory_monitor.py
similarity index 67%
rename from colossalai/utils/memory_tracer/async_memtracer.py
rename to colossalai/utils/memory_tracer/memory_monitor.py
index 4442a9e51..00a6a3176 100644
--- a/colossalai/utils/memory_tracer/async_memtracer.py
+++ b/colossalai/utils/memory_tracer/memory_monitor.py
@@ -1,103 +1,142 @@
-from concurrent.futures import ThreadPoolExecutor
-from time import sleep, time
-import pickle
-
-import torch
-
-from colossalai.utils.memory import colo_device_memory_used
-from colossalai.utils import get_current_device
-
-
-class AsyncMemoryMonitor:
-    """
-    An Async Memory Monitor runing during computing. Sampling memory usage of the current GPU
-    at interval of `1/(10**power)` sec.
-
-    The idea comes from Runtime Memory Tracer of PatrickStar
-    `PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management`_
-
-    Usage::
-
-        async_mem_monitor = AsyncMemoryMonitor()
-        input = torch.randn(2, 20).cuda()
-        OP1 = torch.nn.Linear(20, 30).cuda()
-        OP2 = torch.nn.Linear(30, 40).cuda()
-
-        async_mem_monitor.start()
-        output = OP1(input)
-        async_mem_monitor.finish()
-        async_mem_monitor.start()
-        output = OP2(output)
-        async_mem_monitor.finish()
-        async_mem_monitor.save('log.pkl')
-
-
-    Args:
-        power (int, optional): the power of time interva. Defaults to 10.
-
-    .. _PatrickStar\: Parallel Training of Pre-trained Models via Chunk-based Memory Management:
-        https://arxiv.org/abs/2108.05818
-    """
-
-    def __init__(self, power: int = 10):
-        self.keep_measuring = False
-
-        current_device = get_current_device()
-
-        def _set_cuda_device():
-            torch.cuda.set_device(current_device)
-
-        self.executor = ThreadPoolExecutor(max_workers=1, initializer=_set_cuda_device)
-        self.monitor_thread = None
-        self.interval = 1 / (10**power)
-        self.time_stamps = []
-        self.mem_stats = []
-
-    def __len__(self):
-        return len(self.mem_stats)
-
-    def set_interval(self, power: int):
-        self.clear()
-        self.interval = 1 / (10**power)
-
-    def is_measuring(self):
-        return self.keep_measuring
-
-    def start(self):
-        self.keep_measuring = True
-        self.monitor_thread = self.executor.submit(self._measure_usage)
-
-    def finish(self):
-        if self.keep_measuring is False:
-            return 0
-        self.keep_measuring = False
-        max_usage = self.monitor_thread.result()
-        self.monitor_thread = None
-        self.time_stamps.append(time())
-        self.mem_stats.append(max_usage)
-        return max_usage
-
-    def _measure_usage(self):
-        max_usage = 0
-        while self.keep_measuring:
-            max_usage = max(
-                max_usage,
-                colo_device_memory_used(get_current_device()),
-            )
-            sleep(self.interval)
-        return max_usage
-
-    @property
-    def state_dict(self):
-        return {
-            "time_stamps": self.time_stamps,
-            "mem_stats": self.mem_stats,
-        }
-
-    def save(self, filename):
-        with open(filename, "wb") as f:
-            pickle.dump(self.state_dict(), f)
-
-    def clear(self):
-        self.mem_stats.clear()
-        self.time_stamps.clear()
+from abc import abstractmethod
+from concurrent.futures import ThreadPoolExecutor
+from time import sleep, time
+import json
+
+import torch
+
+from colossalai.utils.memory import colo_device_memory_used
+from colossalai.utils import get_current_device
+
+
+class MemoryMonitor:
+    """Base class for all types of memory monitors.
+    All monitors should have a list called `time_stamps` and a list called `mem_stats`.
+    """
+
+    def __init__(self):
+        self.time_stamps = []
+        self.mem_stats = []
+
+    def __len__(self):
+        return len(self.mem_stats)
+
+    @abstractmethod
+    def start(self):
+        pass
+
+    @abstractmethod
+    def finish(self):
+        pass
+
+    def state_dict(self):
+        return {
+            "time_stamps": self.time_stamps,
+            "mem_stats": self.mem_stats,
+        }
+
+    def save(self, filename):
+        with open(filename, "w") as f:
+            json.dump(self.state_dict(), f)
+
+    def clear(self):
+        self.mem_stats.clear()
+        self.time_stamps.clear()
+
+
+class AsyncMemoryMonitor(MemoryMonitor):
+    """
+    An async memory monitor running during computation, sampling the memory usage of the current GPU
+    at an interval of `1/(10**power)` sec.
+
+    The idea comes from Runtime Memory Tracer of PatrickStar
+    `PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management`_
+
+    Usage::
+
+        async_mem_monitor = AsyncMemoryMonitor()
+        input = torch.randn(2, 20).cuda()
+        OP1 = torch.nn.Linear(20, 30).cuda()
+        OP2 = torch.nn.Linear(30, 40).cuda()
+
+        async_mem_monitor.start()
+        output = OP1(input)
+        async_mem_monitor.finish()
+        async_mem_monitor.start()
+        output = OP2(output)
+        async_mem_monitor.finish()
+        async_mem_monitor.save('log.json')
+
+    Args:
+        power (int, optional): the power of the time interval. Defaults to 10.
+
+    .. _PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management:
+        https://arxiv.org/abs/2108.05818
+    """
+
+    def __init__(self, power: int = 10):
+        super().__init__()
+        self.keep_measuring = False
+
+        current_device = get_current_device()
+
+        def _set_cuda_device():
+            torch.cuda.set_device(current_device)
+
+        self.executor = ThreadPoolExecutor(max_workers=1, initializer=_set_cuda_device)
+        self.monitor_thread = None
+        self.interval = 1 / (10**power)
+
+    def set_interval(self, power: int):
+        self.clear()
+        self.interval = 1 / (10**power)
+
+    def is_measuring(self):
+        return self.keep_measuring
+
+    def start(self):
+        self.keep_measuring = True
+        self.monitor_thread = self.executor.submit(self._measure_usage)
+
+    def finish(self):
+        if self.keep_measuring is False:
+            return 0
+
+        self.keep_measuring = False
+        max_usage = self.monitor_thread.result()
+
+        self.monitor_thread = None
+        self.time_stamps.append(time())
+        self.mem_stats.append(max_usage)
+        return max_usage
+
+    def _measure_usage(self):
+        max_usage = 0
+        while self.keep_measuring:
+            max_usage = max(
+                max_usage,
+                colo_device_memory_used(get_current_device()),
+            )
+            sleep(self.interval)
+        return max_usage
+
+
+class SyncCudaMemoryMonitor(MemoryMonitor):
+    """
+    A synchronized CUDA memory monitor.
+    It only records the maximum CUDA memory allocated between the start point and the finish point.
+    """
+
+    def __init__(self, power: int = 10):
+        super().__init__()
+
+    def start(self):
+        torch.cuda.synchronize()
+        torch.cuda.reset_peak_memory_stats()
+
+    def finish(self):
+        torch.cuda.synchronize()
+        self.time_stamps.append(time())
+        max_usage = torch.cuda.max_memory_allocated()
+        self.mem_stats.append(max_usage)
+        return max_usage
diff --git a/colossalai/utils/memory_tracer/memstats_collector.py b/colossalai/utils/memory_tracer/memstats_collector.py
index 5da971ab5..2aa32b829 100644
--- a/colossalai/utils/memory_tracer/memstats_collector.py
+++ b/colossalai/utils/memory_tracer/memstats_collector.py
@@ -1,6 +1,6 @@
 from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
 from colossalai.utils.memory import colo_device_memory_used
-from colossalai.utils.memory_tracer.async_memtracer import AsyncMemoryMonitor
+from colossalai.utils.memory_tracer import AsyncMemoryMonitor
 import torch
 import time
 from typing import List
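
A minimal usage sketch of the new `SyncCudaMemoryMonitor`, added here for reviewers rather than taken from the patch. It assumes a CUDA-capable device and uses only the API introduced above (`start()`, `finish()`, and the inherited `save()`); the model, tensor shape, and output filename are illustrative::

    import torch

    from colossalai.utils.memory_tracer import SyncCudaMemoryMonitor

    monitor = SyncCudaMemoryMonitor()
    model = torch.nn.Linear(20, 30).cuda()
    data = torch.randn(2, 20).cuda()

    # start() synchronizes, then resets the CUDA peak-memory counter.
    monitor.start()
    output = model(data)
    # finish() synchronizes and records torch.cuda.max_memory_allocated() in bytes.
    max_usage = monitor.finish()

    print(f"peak allocated: {max_usage / 1024**2:.2f} MiB")
    # save() is inherited from MemoryMonitor and dumps time_stamps/mem_stats as JSON.
    monitor.save('mem_stats.json')

Because `start()` and `finish()` bracket the measured region with `torch.cuda.synchronize()`, the reported peak covers exactly the kernels launched in between, at the cost of two synchronization points; `AsyncMemoryMonitor` avoids the synchronization but can miss short-lived peaks that fall between its samples.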