[utils] add synchronized cuda memory monitor (#740)

pull/742/head
HELSON 2022-04-13 10:50:54 +08:00 committed by GitHub
parent e6212f56cd
commit 340e59f968
4 changed files with 149 additions and 110 deletions

View File

@@ -1,9 +1,8 @@
-from cgitb import Hook
 from colossalai.registry import HOOKS
 from torch import Tensor
 from colossalai.trainer.hooks import BaseHook
 from colossalai.utils.memory_tracer import AsyncMemoryMonitor
 from ._metric_hook import LearningRateMetric, MetricHook


 @HOOKS.register_module
 class MemTraceHook(BaseHook):
@@ -11,6 +10,7 @@ class MemTraceHook(BaseHook):
     This hook is used to record memory usage info, and pass to trainer.states
     You can use it as other trainer hook and fetch data from trainer.states['metrics'][mode]
     """
+
     def __init__(
         self,
         priority: int = 0,
@@ -36,9 +36,9 @@ class MemTraceHook(BaseHook):
     def before_test_iter(self, trainer):
         self._memory_monitor.start()
         return super().before_test(trainer)

     def after_test_iter(self, trainer, output: Tensor, label: Tensor, loss: Tensor):
         self._memory_monitor.finish()
         trainer.states['metrics']['train'] = self._memory_monitor.state_dict
         trainer.states['metrics']['test'] = self._memory_monitor.state_dict
         return super().after_test_iter(trainer, output, label, loss)
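
For orientation, the hook above is attached like any other trainer hook and surfaces its measurements through trainer.states['metrics']. A minimal usage sketch, assuming the usual Trainer.fit(..., hooks=[...]) entry point and that MemTraceHook is importable from colossalai.trainer.hooks; neither the trainer wiring nor the engine/dataloader setup is part of this diff:

from colossalai.trainer import Trainer, hooks

# engine, train_dataloader and test_dataloader are assumed to come from
# colossalai.initialize(); they are placeholders, not part of this change.
trainer = Trainer(engine=engine)
trainer.fit(
    train_dataloader=train_dataloader,
    test_dataloader=test_dataloader,
    epochs=1,
    hooks=[hooks.MemTraceHook(priority=0)],
)

# The monitor's time stamps and peak-memory samples are exposed here,
# as documented in the hook's docstring.
print(trainer.states['metrics']['test'])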

View File

@@ -1,4 +1,4 @@
-from .async_memtracer import AsyncMemoryMonitor
+from .memory_monitor import AsyncMemoryMonitor, SyncCudaMemoryMonitor
 from .memstats_collector import MemStatsCollector

-__all__ = ['AsyncMemoryMonitor', 'MemStatsCollector']
+__all__ = ['AsyncMemoryMonitor', 'SyncCudaMemoryMonitor', 'MemStatsCollector']
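
Both monitors are now re-exported from the package root, so downstream code imports them from one place. A quick sanity-check sketch of the new surface, using nothing beyond what this diff exports:

from colossalai.utils.memory_tracer import AsyncMemoryMonitor, SyncCudaMemoryMonitor

# Both monitors share the start/finish/state_dict interface of the new
# MemoryMonitor base class, so they can be swapped without touching callers.
monitor = SyncCudaMemoryMonitor()
sampler = AsyncMemoryMonitor(power=3)    # samples roughly every 1e-3 s
print(len(monitor), len(sampler))        # both start with empty mem_stats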

View File

@@ -1,103 +1,142 @@
+from abc import abstractmethod
 from concurrent.futures import ThreadPoolExecutor
 from time import sleep, time
-import pickle
+import json

 import torch

 from colossalai.utils.memory import colo_device_memory_used
 from colossalai.utils import get_current_device


-class AsyncMemoryMonitor:
+class MemoryMonitor:
+    """Base class for all types of memory monitor.
+    All monitors should have a list called `time_stamps` and a list called `mem_stats`.
+    """
+
+    def __init__(self):
+        self.time_stamps = []
+        self.mem_stats = []
+
+    def __len__(self):
+        return len(self.mem_stats)
+
+    @abstractmethod
+    def start(self):
+        pass
+
+    @abstractmethod
+    def finish(self):
+        pass
+
+    def state_dict(self):
+        return {
+            "time_stamps": self.time_stamps,
+            "mem_stats": self.mem_stats,
+        }
+
+    def save(self, filename):
+        with open(filename, "w") as f:
+            json.dump(self.state_dict(), f)
+
+    def clear(self):
+        self.mem_stats.clear()
+        self.time_stamps.clear()
+
+
+class AsyncMemoryMonitor(MemoryMonitor):
     """
     An Async Memory Monitor running during computing. It samples the memory usage of the current GPU
     at an interval of `1/(10**power)` sec.

     The idea comes from the Runtime Memory Tracer of PatrickStar
     `PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management`_

     Usage::

         async_mem_monitor = AsyncMemoryMonitor()
         input = torch.randn(2, 20).cuda()
         OP1 = torch.nn.Linear(20, 30).cuda()
         OP2 = torch.nn.Linear(30, 40).cuda()

         async_mem_monitor.start()
         output = OP1(input)
         async_mem_monitor.finish()
         async_mem_monitor.start()
         output = OP2(output)
         async_mem_monitor.finish()
         async_mem_monitor.save('log.pkl')

     Args:
         power (int, optional): the power of the time interval. Defaults to 10.

     .. _PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management:
         https://arxiv.org/abs/2108.05818
     """

     def __init__(self, power: int = 10):
+        super().__init__()
         self.keep_measuring = False

         current_device = get_current_device()

         def _set_cuda_device():
             torch.cuda.set_device(current_device)

         self.executor = ThreadPoolExecutor(max_workers=1, initializer=_set_cuda_device)
         self.monitor_thread = None
         self.interval = 1 / (10**power)
-        self.time_stamps = []
-        self.mem_stats = []
-
-    def __len__(self):
-        return len(self.mem_stats)

     def set_interval(self, power: int):
         self.clear()
         self.interval = 1 / (10**power)

     def is_measuring(self):
         return self.keep_measuring

     def start(self):
         self.keep_measuring = True
         self.monitor_thread = self.executor.submit(self._measure_usage)

     def finish(self):
         if self.keep_measuring is False:
             return 0
         self.keep_measuring = False
         max_usage = self.monitor_thread.result()
         self.monitor_thread = None
         self.time_stamps.append(time())
         self.mem_stats.append(max_usage)
         return max_usage

     def _measure_usage(self):
         max_usage = 0
         while self.keep_measuring:
             max_usage = max(
                 max_usage,
                 colo_device_memory_used(get_current_device()),
             )
             sleep(self.interval)
         return max_usage
-
-    @property
-    def state_dict(self):
-        return {
-            "time_stamps": self.time_stamps,
-            "mem_stats": self.mem_stats,
-        }
-
-    def save(self, filename):
-        with open(filename, "wb") as f:
-            pickle.dump(self.state_dict(), f)
-
-    def clear(self):
-        self.mem_stats.clear()
-        self.time_stamps.clear()
+
+
+class SyncCudaMemoryMonitor(MemoryMonitor):
+    """
+    A synchronized CUDA memory monitor.
+    It only records the maximum allocated CUDA memory from the start point to the finish point.
+    """
+
+    def __init__(self, power: int = 10):
+        super().__init__()
+
+    def start(self):
+        torch.cuda.synchronize()
+        torch.cuda.reset_peak_memory_stats()
+
+    def finish(self):
+        torch.cuda.synchronize()
+        self.time_stamps.append(time())
+        max_usage = torch.cuda.max_memory_allocated()
+        self.mem_stats.append(max_usage)
+        return max_usage
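
Analogous to the AsyncMemoryMonitor usage shown in the docstring above, a short sketch of how the new synchronized monitor might be driven; the layer and tensor shapes are placeholders, everything else follows the class as written:

import torch

from colossalai.utils.memory_tracer import SyncCudaMemoryMonitor

monitor = SyncCudaMemoryMonitor()
layer = torch.nn.Linear(1024, 1024).cuda()       # placeholder workload
data = torch.randn(64, 1024, device='cuda')

# start() synchronizes and resets the CUDA peak-memory counter;
# finish() synchronizes again and records the peak bytes allocated in between.
monitor.start()
out = layer(data)
peak_bytes = monitor.finish()

print(f"peak allocated: {peak_bytes / 1024**2:.1f} MB")
monitor.save('mem_stats.json')    # base-class save(): JSON of time_stamps and mem_stats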

View File

@@ -1,6 +1,6 @@
 from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
 from colossalai.utils.memory import colo_device_memory_used
-from colossalai.utils.memory_tracer.async_memtracer import AsyncMemoryMonitor
+from colossalai.utils.memory_tracer import AsyncMemoryMonitor
 import torch
 import time
 from typing import List