[utils] add synchronized cuda memory monitor (#740)

pull/742/head
HELSON 2022-04-13 10:50:54 +08:00 committed by GitHub
parent e6212f56cd
commit 340e59f968
4 changed files with 149 additions and 110 deletions


@@ -1,9 +1,8 @@
from cgitb import Hook
from colossalai.registry import HOOKS
from torch import Tensor
from colossalai.trainer.hooks import BaseHook
from colossalai.utils.memory_tracer import AsyncMemoryMonitor
from ._metric_hook import LearningRateMetric, MetricHook
@HOOKS.register_module
class MemTraceHook(BaseHook):
@@ -11,6 +10,7 @@ class MemTraceHook(BaseHook):
This hook is used to record memory usage info and pass it to trainer.states.
You can use it like any other trainer hook and fetch data from trainer.states['metrics'][mode].
"""
def __init__(
self,
priority: int = 0,


@@ -1,4 +1,4 @@
from .async_memtracer import AsyncMemoryMonitor
from .memory_monitor import AsyncMemoryMonitor, SyncCudaMemoryMonitor
from .memstats_collector import MemStatsCollector
__all__ = ['AsyncMemoryMonitor', 'MemStatsCollector']
__all__ = ['AsyncMemoryMonitor', 'SyncCudaMemoryMonitor', 'MemStatsCollector']
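
With the tracer module renamed from `async_memtracer` to `memory_monitor`, downstream code is expected to pull both monitors from the package itself rather than the old module path; a one-line sketch:

    # old path: from colossalai.utils.memory_tracer.async_memtracer import AsyncMemoryMonitor
    from colossalai.utils.memory_tracer import AsyncMemoryMonitor, SyncCudaMemoryMonitor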


@@ -1,6 +1,7 @@
from abc import abstractmethod
from concurrent.futures import ThreadPoolExecutor
from time import sleep, time
import pickle
import json
import torch
@@ -8,7 +9,42 @@ from colossalai.utils.memory import colo_device_memory_used
from colossalai.utils import get_current_device
class AsyncMemoryMonitor:
class MemoryMonitor:
"""Base class for all types of memory monitor.
All monitors should have a list called `time_stamps` and a list called `mem_stats`.
"""
def __init__(self):
self.time_stamps = []
self.mem_stats = []
def __len__(self):
return len(self.mem_stats)
@abstractmethod
def start(self):
pass
@abstractmethod
def finish(self):
pass
def state_dict(self):
return {
"time_stamps": self.time_stamps,
"mem_stats": self.mem_stats,
}
def save(self, filename):
with open(filename, "w") as f:
json.dump(self.state_dict(), f)
def clear(self):
self.mem_stats.clear()
self.time_stamps.clear()
class AsyncMemoryMonitor(MemoryMonitor):
"""
An asynchronous memory monitor running during computation. It samples the memory usage
of the current GPU at intervals of `1/(10**power)` sec.
@@ -31,15 +67,15 @@ class AsyncMemoryMonitor:
async_mem_monitor.finish()
async_mem_monitor.save('log.json')
Args:
power (int, optional): the power of the time interval. Defaults to 10.
.. _PatrickStar\: Parallel Training of Pre-trained Models via Chunk-based Memory Management:
.. _PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management:
https://arxiv.org/abs/2108.05818
"""
def __init__(self, power: int = 10):
super().__init__()
self.keep_measuring = False
current_device = get_current_device()
@@ -50,11 +86,6 @@ class AsyncMemoryMonitor:
self.executor = ThreadPoolExecutor(max_workers=1, initializer=_set_cuda_device)
self.monitor_thread = None
self.interval = 1 / (10**power)
self.time_stamps = []
self.mem_stats = []
def __len__(self):
return len(self.mem_stats)
def set_interval(self, power: int):
self.clear()
@@ -70,8 +101,10 @@ class AsyncMemoryMonitor:
def finish(self):
if self.keep_measuring is False:
return 0
self.keep_measuring = False
max_usage = self.monitor_thread.result()
self.monitor_thread = None
self.time_stamps.append(time())
self.mem_stats.append(max_usage)
@@ -87,17 +120,23 @@ class AsyncMemoryMonitor:
sleep(self.interval)
return max_usage
@property
def state_dict(self):
return {
"time_stamps": self.time_stamps,
"mem_stats": self.mem_stats,
}
def save(self, filename):
with open(filename, "wb") as f:
pickle.dump(self.state_dict(), f)
class SyncCudaMemoryMonitor(MemoryMonitor):
"""
A synchronized CUDA memory monitor.
It only records the maximum allocated CUDA memory from the start point to the finish point.
"""
def clear(self):
self.mem_stats.clear()
self.time_stamps.clear()
def __init__(self, power: int = 10):
super().__init__()
def start(self):
torch.cuda.synchronize()
torch.cuda.reset_peak_memory_stats()
def finish(self):
torch.cuda.synchronize()
self.time_stamps.append(time())
max_usage = torch.cuda.max_memory_allocated()
self.mem_stats.append(max_usage)
return max_usage
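
A minimal usage sketch of the new SyncCudaMemoryMonitor, based only on the interface shown above; the model and tensor shapes are arbitrary placeholders and a CUDA device is assumed:

    import json
    import torch
    from colossalai.utils.memory_tracer import SyncCudaMemoryMonitor

    monitor = SyncCudaMemoryMonitor()
    model = torch.nn.Linear(1024, 1024).cuda()
    data = torch.randn(64, 1024, device='cuda')

    monitor.start()                 # synchronize, then reset the CUDA peak-memory counter
    model(data).sum().backward()
    peak_bytes = monitor.finish()   # synchronize and record torch.cuda.max_memory_allocated()

    print(f"peak allocated: {peak_bytes / 1024**2:.1f} MB")

    # The base-class save() dumps time_stamps/mem_stats as JSON, so it can be reloaded directly.
    monitor.save('mem_stats.json')
    with open('mem_stats.json') as f:
        print(json.load(f))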


@@ -1,6 +1,6 @@
from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
from colossalai.utils.memory import colo_device_memory_used
from colossalai.utils.memory_tracer.async_memtracer import AsyncMemoryMonitor
from colossalai.utils.memory_tracer import AsyncMemoryMonitor
import torch
import time
from typing import List
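
Because both monitors now derive from the shared MemoryMonitor base class, code that previously hard-wired AsyncMemoryMonitor (such as MemStatsCollector above) can accept either implementation through the common start()/finish() interface. A hypothetical helper illustrating this; the function name and workload are illustrative only, not part of this commit:

    import torch
    from colossalai.utils.memory_tracer import AsyncMemoryMonitor, SyncCudaMemoryMonitor
    from colossalai.utils.memory_tracer.memory_monitor import MemoryMonitor


    def measure_peak(monitor: MemoryMonitor, workload) -> int:
        """Run `workload` between start() and finish(), then return the peak usage reported."""
        monitor.start()
        workload()
        return monitor.finish()


    model = torch.nn.Linear(512, 512).cuda()


    def workload():
        model(torch.randn(32, 512, device='cuda')).sum().backward()


    # The synchronous monitor reads the CUDA allocator's peak; the async one samples in a thread.
    for monitor in (SyncCudaMemoryMonitor(), AsyncMemoryMonitor(power=3)):
        print(type(monitor).__name__, measure_peak(monitor, workload))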