[utils] add synchronized cuda memory monitor (#740)

pull/742/head
HELSON 2022-04-13 10:50:54 +08:00 committed by GitHub
parent e6212f56cd
commit 340e59f968
4 changed files with 149 additions and 110 deletions

View File

@@ -1,9 +1,8 @@
-from cgitb import Hook
 from colossalai.registry import HOOKS
 from torch import Tensor
 from colossalai.trainer.hooks import BaseHook
 from colossalai.utils.memory_tracer import AsyncMemoryMonitor
 from ._metric_hook import LearningRateMetric, MetricHook

 @HOOKS.register_module
 class MemTraceHook(BaseHook):
@@ -11,6 +10,7 @@ class MemTraceHook(BaseHook):
     This hook is used to record memory usage info, and pass to trainer.states
     You can use it as other trainer hook and fetch data from trainer.states['metrics'][mode]
     """
+
     def __init__(
         self,
         priority: int = 0,
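
For context, a minimal sketch of how this hook might be wired into a trainer of this era. The `engine` and `train_dataloader` objects, and `MemTraceHook` being exported from `colossalai.trainer.hooks`, are assumptions for illustration; only the hook itself comes from this commit.

```python
# Illustrative sketch, not part of the commit: attach MemTraceHook to a
# ColossalAI Trainer and read the recorded stats back from trainer.states.
from colossalai.trainer import Trainer
from colossalai.trainer.hooks import MemTraceHook

trainer = Trainer(engine=engine)    # `engine` assumed built via colossalai.initialize
trainer.fit(train_dataloader=train_dataloader,    # assumed dataloader
            epochs=1,
            hooks=[MemTraceHook(priority=0)])

# Per the docstring above, memory stats land in trainer.states['metrics'][mode].
```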

View File

@@ -1,4 +1,4 @@
-from .async_memtracer import AsyncMemoryMonitor
+from .memory_monitor import AsyncMemoryMonitor, SyncCudaMemoryMonitor
 from .memstats_collector import MemStatsCollector

-__all__ = ['AsyncMemoryMonitor', 'MemStatsCollector']
+__all__ = ['AsyncMemoryMonitor', 'SyncCudaMemoryMonitor', 'MemStatsCollector']
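
With this re-export, both monitors resolve through the package root rather than the old private module path, as the import change in the last file below also reflects:

```python
# Both monitors are now importable from the memory_tracer package root.
from colossalai.utils.memory_tracer import AsyncMemoryMonitor, SyncCudaMemoryMonitor
```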

View File

@@ -1,6 +1,7 @@
+from abc import abstractmethod
 from concurrent.futures import ThreadPoolExecutor
 from time import sleep, time
-import pickle
+import json

 import torch
@@ -8,7 +9,42 @@
 from colossalai.utils.memory import colo_device_memory_used
 from colossalai.utils import get_current_device

-class AsyncMemoryMonitor:
+class MemoryMonitor:
+    """Base class for all types of memory monitor.
+    All monitors should have a list called `time_stamps` and a list called `mem_stats`.
+    """
+
+    def __init__(self):
+        self.time_stamps = []
+        self.mem_stats = []
+
+    def __len__(self):
+        return len(self.mem_stats)
+
+    @abstractmethod
+    def start(self):
+        pass
+
+    @abstractmethod
+    def finish(self):
+        pass
+
+    def state_dict(self):
+        return {
+            "time_stamps": self.time_stamps,
+            "mem_stats": self.mem_stats,
+        }
+
+    def save(self, filename):
+        with open(filename, "w") as f:
+            json.dump(self.state_dict(), f)
+
+    def clear(self):
+        self.mem_stats.clear()
+        self.time_stamps.clear()
+
+
+class AsyncMemoryMonitor(MemoryMonitor):
     """
     An Async Memory Monitor running during computing. Sampling memory usage of the current GPU
     at interval of `1/(10**power)` sec.
@@ -31,15 +67,15 @@ class AsyncMemoryMonitor:
         async_mem_monitor.finish()
         async_mem_monitor.save('log.pkl')

     Args:
         power (int, optional): the power of the time interval. Defaults to 10.

-    .. _PatrickStar\: Parallel Training of Pre-trained Models via Chunk-based Memory Management:
+    .. _PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management:
         https://arxiv.org/abs/2108.05818
     """

     def __init__(self, power: int = 10):
+        super().__init__()
         self.keep_measuring = False

         current_device = get_current_device()
@@ -50,11 +86,6 @@
         self.executor = ThreadPoolExecutor(max_workers=1, initializer=_set_cuda_device)
         self.monitor_thread = None
         self.interval = 1 / (10**power)
-        self.time_stamps = []
-        self.mem_stats = []
-
-    def __len__(self):
-        return len(self.mem_stats)

     def set_interval(self, power: int):
         self.clear()
@@ -70,8 +101,10 @@
     def finish(self):
         if self.keep_measuring is False:
             return 0
+
         self.keep_measuring = False
         max_usage = self.monitor_thread.result()
+
         self.monitor_thread = None
         self.time_stamps.append(time())
         self.mem_stats.append(max_usage)
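
Taken together, `start()` and `finish()` bracket a measured region. A hedged usage sketch of the async monitor, following the class docstring's own example (note that `save()` now writes JSON via the base class, so the docstring's `'log.pkl'` name is stale; the filename below is arbitrary):

```python
import torch
from colossalai.utils.memory_tracer import AsyncMemoryMonitor

# Sketch: the worker thread samples colo_device_memory_used every
# 1/(10**power) seconds between start() and finish(); finish() returns
# the peak value observed and appends it to mem_stats.
monitor = AsyncMemoryMonitor(power=3)    # sample every 1 ms

monitor.start()
x = torch.randn(2048, 2048, device='cuda')
y = x @ x                                # the workload being measured
max_usage = monitor.finish()

print(f"sampled peak: {max_usage / 1024**2:.1f} MiB")
monitor.save('log.json')                 # arbitrary name; writes JSON now
```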
@@ -87,17 +120,23 @@
             sleep(self.interval)
         return max_usage

-    @property
-    def state_dict(self):
-        return {
-            "time_stamps": self.time_stamps,
-            "mem_stats": self.mem_stats,
-        }
-
-    def save(self, filename):
-        with open(filename, "wb") as f:
-            pickle.dump(self.state_dict(), f)
-
-    def clear(self):
-        self.mem_stats.clear()
-        self.time_stamps.clear()
+
+class SyncCudaMemoryMonitor(MemoryMonitor):
+    """
+    A synchronized cuda memory monitor.
+    It only records the maximum allocated cuda memory from start point to finish point.
+    """
+
+    def __init__(self, power: int = 10):
+        super().__init__()
+
+    def start(self):
+        torch.cuda.synchronize()
+        torch.cuda.reset_peak_memory_stats()
+
+    def finish(self):
+        torch.cuda.synchronize()
+        self.time_stamps.append(time())
+        max_usage = torch.cuda.max_memory_allocated()
+        self.mem_stats.append(max_usage)
+        return max_usage
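
A matching sketch for the new class. Unlike the async monitor, it takes no samples: it resets the CUDA allocator's peak counter at `start()` and reads it back at `finish()`, so it reports the exact peak of allocated bytes (not total device memory used) at the cost of two synchronizations. The filename is hypothetical.

```python
import torch
from colossalai.utils.memory_tracer import SyncCudaMemoryMonitor

monitor = SyncCudaMemoryMonitor()

monitor.start()                          # synchronize + reset peak stats
x = torch.randn(2048, 2048, device='cuda')
y = x @ x
peak_bytes = monitor.finish()            # synchronize, record time + peak

print(f"peak allocated: {peak_bytes / 1024**2:.1f} MiB")
monitor.save('cuda_mem.json')            # hypothetical name; inherited save()
```

The trade-off between the two: `SyncCudaMemoryMonitor` is exact for allocator peaks but blocks at the region boundaries, while `AsyncMemoryMonitor` observes actual device usage (including allocator cache) yet can miss spikes shorter than its sampling interval.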

View File

@@ -1,6 +1,6 @@
 from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
 from colossalai.utils.memory import colo_device_memory_used
-from colossalai.utils.memory_tracer.async_memtracer import AsyncMemoryMonitor
+from colossalai.utils.memory_tracer import AsyncMemoryMonitor
 import torch
 import time
 from typing import List