mirror of https://github.com/hpcaitech/ColossalAI
[utils] add synchronized cuda memory monitor (#740)
parent
e6212f56cd
commit
340e59f968
|
@ -1,9 +1,8 @@
|
||||||
from cgitb import Hook
|
|
||||||
from colossalai.registry import HOOKS
|
from colossalai.registry import HOOKS
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from colossalai.trainer.hooks import BaseHook
|
from colossalai.trainer.hooks import BaseHook
|
||||||
from colossalai.utils.memory_tracer import AsyncMemoryMonitor
|
from colossalai.utils.memory_tracer import AsyncMemoryMonitor
|
||||||
from ._metric_hook import LearningRateMetric, MetricHook
|
|
||||||
|
|
||||||
@HOOKS.register_module
|
@HOOKS.register_module
|
||||||
class MemTraceHook(BaseHook):
|
class MemTraceHook(BaseHook):
|
||||||
|
@ -11,6 +10,7 @@ class MemTraceHook(BaseHook):
|
||||||
This hook is used to record memory usage info, and pass to trainer.states
|
This hook is used to record memory usage info, and pass to trainer.states
|
||||||
You can use it as other trainer hook and fetch data from trainer.states['metrics][mode]
|
You can use it as other trainer hook and fetch data from trainer.states['metrics][mode]
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
priority: int = 0,
|
priority: int = 0,
|
||||||
|
@ -36,9 +36,9 @@ class MemTraceHook(BaseHook):
|
||||||
def before_test_iter(self, trainer):
|
def before_test_iter(self, trainer):
|
||||||
self._memory_monitor.start()
|
self._memory_monitor.start()
|
||||||
return super().before_test(trainer)
|
return super().before_test(trainer)
|
||||||
|
|
||||||
def after_test_iter(self, trainer, output: Tensor, label: Tensor, loss: Tensor):
|
def after_test_iter(self, trainer, output: Tensor, label: Tensor, loss: Tensor):
|
||||||
self._memory_monitor.finish()
|
self._memory_monitor.finish()
|
||||||
trainer.states['metrics']['train'] = self._memory_monitor.state_dict
|
trainer.states['metrics']['train'] = self._memory_monitor.state_dict
|
||||||
trainer.states['metrics']['test'] = self._memory_monitor.state_dict
|
trainer.states['metrics']['test'] = self._memory_monitor.state_dict
|
||||||
return super().after_test_iter(trainer, output, label, loss)
|
return super().after_test_iter(trainer, output, label, loss)
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
from .async_memtracer import AsyncMemoryMonitor
|
from .memory_monitor import AsyncMemoryMonitor, SyncCudaMemoryMonitor
|
||||||
from .memstats_collector import MemStatsCollector
|
from .memstats_collector import MemStatsCollector
|
||||||
|
|
||||||
__all__ = ['AsyncMemoryMonitor', 'MemStatsCollector']
|
__all__ = ['AsyncMemoryMonitor', 'SyncCudaMemoryMonitor', 'MemStatsCollector']
|
||||||
|
|
|
@ -1,103 +1,142 @@
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from abc import abstractmethod
|
||||||
from time import sleep, time
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
import pickle
|
from time import sleep, time
|
||||||
|
import json
|
||||||
import torch
|
|
||||||
|
import torch
|
||||||
from colossalai.utils.memory import colo_device_memory_used
|
|
||||||
from colossalai.utils import get_current_device
|
from colossalai.utils.memory import colo_device_memory_used
|
||||||
|
from colossalai.utils import get_current_device
|
||||||
|
|
||||||
class AsyncMemoryMonitor:
|
|
||||||
"""
|
class MemoryMonitor:
|
||||||
An Async Memory Monitor runing during computing. Sampling memory usage of the current GPU
|
"""Base class for all types of memory monitor.
|
||||||
at interval of `1/(10**power)` sec.
|
All monitors should have a list called `time_stamps` and a list called `mem_stats`.
|
||||||
|
"""
|
||||||
The idea comes from Runtime Memory Tracer of PatrickStar
|
|
||||||
`PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management`_
|
def __init__(self):
|
||||||
|
self.time_stamps = []
|
||||||
Usage::
|
self.mem_stats = []
|
||||||
|
|
||||||
async_mem_monitor = AsyncMemoryMonitor()
|
def __len__(self):
|
||||||
input = torch.randn(2, 20).cuda()
|
return len(self.mem_stats)
|
||||||
OP1 = torch.nn.Linear(20, 30).cuda()
|
|
||||||
OP2 = torch.nn.Linear(30, 40).cuda()
|
@abstractmethod
|
||||||
|
def start(self):
|
||||||
async_mem_monitor.start()
|
pass
|
||||||
output = OP1(input)
|
|
||||||
async_mem_monitor.finish()
|
@abstractmethod
|
||||||
async_mem_monitor.start()
|
def finish(self):
|
||||||
output = OP2(output)
|
pass
|
||||||
async_mem_monitor.finish()
|
|
||||||
async_mem_monitor.save('log.pkl')
|
def state_dict(self):
|
||||||
|
return {
|
||||||
|
"time_stamps": self.time_stamps,
|
||||||
Args:
|
"mem_stats": self.mem_stats,
|
||||||
power (int, optional): the power of time interva. Defaults to 10.
|
}
|
||||||
|
|
||||||
.. _PatrickStar\: Parallel Training of Pre-trained Models via Chunk-based Memory Management:
|
def save(self, filename):
|
||||||
https://arxiv.org/abs/2108.05818
|
with open(filename, "w") as f:
|
||||||
"""
|
json.dump(self.state_dict(), f)
|
||||||
|
|
||||||
def __init__(self, power: int = 10):
|
def clear(self):
|
||||||
self.keep_measuring = False
|
self.mem_stats.clear()
|
||||||
|
self.time_stamps.clear()
|
||||||
current_device = get_current_device()
|
|
||||||
|
|
||||||
def _set_cuda_device():
|
class AsyncMemoryMonitor(MemoryMonitor):
|
||||||
torch.cuda.set_device(current_device)
|
"""
|
||||||
|
An Async Memory Monitor runing during computing. Sampling memory usage of the current GPU
|
||||||
self.executor = ThreadPoolExecutor(max_workers=1, initializer=_set_cuda_device)
|
at interval of `1/(10**power)` sec.
|
||||||
self.monitor_thread = None
|
|
||||||
self.interval = 1 / (10**power)
|
The idea comes from Runtime Memory Tracer of PatrickStar
|
||||||
self.time_stamps = []
|
`PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management`_
|
||||||
self.mem_stats = []
|
|
||||||
|
Usage::
|
||||||
def __len__(self):
|
|
||||||
return len(self.mem_stats)
|
async_mem_monitor = AsyncMemoryMonitor()
|
||||||
|
input = torch.randn(2, 20).cuda()
|
||||||
def set_interval(self, power: int):
|
OP1 = torch.nn.Linear(20, 30).cuda()
|
||||||
self.clear()
|
OP2 = torch.nn.Linear(30, 40).cuda()
|
||||||
self.interval = 1 / (10**power)
|
|
||||||
|
async_mem_monitor.start()
|
||||||
def is_measuring(self):
|
output = OP1(input)
|
||||||
return self.keep_measuring
|
async_mem_monitor.finish()
|
||||||
|
async_mem_monitor.start()
|
||||||
def start(self):
|
output = OP2(output)
|
||||||
self.keep_measuring = True
|
async_mem_monitor.finish()
|
||||||
self.monitor_thread = self.executor.submit(self._measure_usage)
|
async_mem_monitor.save('log.pkl')
|
||||||
|
|
||||||
def finish(self):
|
Args:
|
||||||
if self.keep_measuring is False:
|
power (int, optional): the power of time interva. Defaults to 10.
|
||||||
return 0
|
|
||||||
self.keep_measuring = False
|
.. _PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management:
|
||||||
max_usage = self.monitor_thread.result()
|
https://arxiv.org/abs/2108.05818
|
||||||
self.monitor_thread = None
|
"""
|
||||||
self.time_stamps.append(time())
|
|
||||||
self.mem_stats.append(max_usage)
|
def __init__(self, power: int = 10):
|
||||||
return max_usage
|
super().__init__()
|
||||||
|
self.keep_measuring = False
|
||||||
def _measure_usage(self):
|
|
||||||
max_usage = 0
|
current_device = get_current_device()
|
||||||
while self.keep_measuring:
|
|
||||||
max_usage = max(
|
def _set_cuda_device():
|
||||||
max_usage,
|
torch.cuda.set_device(current_device)
|
||||||
colo_device_memory_used(get_current_device()),
|
|
||||||
)
|
self.executor = ThreadPoolExecutor(max_workers=1, initializer=_set_cuda_device)
|
||||||
sleep(self.interval)
|
self.monitor_thread = None
|
||||||
return max_usage
|
self.interval = 1 / (10**power)
|
||||||
|
|
||||||
@property
|
def set_interval(self, power: int):
|
||||||
def state_dict(self):
|
self.clear()
|
||||||
return {
|
self.interval = 1 / (10**power)
|
||||||
"time_stamps": self.time_stamps,
|
|
||||||
"mem_stats": self.mem_stats,
|
def is_measuring(self):
|
||||||
}
|
return self.keep_measuring
|
||||||
|
|
||||||
def save(self, filename):
|
def start(self):
|
||||||
with open(filename, "wb") as f:
|
self.keep_measuring = True
|
||||||
pickle.dump(self.state_dict(), f)
|
self.monitor_thread = self.executor.submit(self._measure_usage)
|
||||||
|
|
||||||
def clear(self):
|
def finish(self):
|
||||||
self.mem_stats.clear()
|
if self.keep_measuring is False:
|
||||||
self.time_stamps.clear()
|
return 0
|
||||||
|
|
||||||
|
self.keep_measuring = False
|
||||||
|
max_usage = self.monitor_thread.result()
|
||||||
|
|
||||||
|
self.monitor_thread = None
|
||||||
|
self.time_stamps.append(time())
|
||||||
|
self.mem_stats.append(max_usage)
|
||||||
|
return max_usage
|
||||||
|
|
||||||
|
def _measure_usage(self):
|
||||||
|
max_usage = 0
|
||||||
|
while self.keep_measuring:
|
||||||
|
max_usage = max(
|
||||||
|
max_usage,
|
||||||
|
colo_device_memory_used(get_current_device()),
|
||||||
|
)
|
||||||
|
sleep(self.interval)
|
||||||
|
return max_usage
|
||||||
|
|
||||||
|
|
||||||
|
class SyncCudaMemoryMonitor(MemoryMonitor):
|
||||||
|
"""
|
||||||
|
A synchronized cuda memory monitor.
|
||||||
|
It only record the maximum allocated cuda memory from start point to finish point.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, power: int = 10):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def start(self):
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
torch.cuda.reset_peak_memory_stats()
|
||||||
|
|
||||||
|
def finish(self):
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
self.time_stamps.append(time())
|
||||||
|
max_usage = torch.cuda.max_memory_allocated()
|
||||||
|
self.mem_stats.append(max_usage)
|
||||||
|
return max_usage
|
|
@ -1,6 +1,6 @@
|
||||||
from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
|
from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
|
||||||
from colossalai.utils.memory import colo_device_memory_used
|
from colossalai.utils.memory import colo_device_memory_used
|
||||||
from colossalai.utils.memory_tracer.async_memtracer import AsyncMemoryMonitor
|
from colossalai.utils.memory_tracer import AsyncMemoryMonitor
|
||||||
import torch
|
import torch
|
||||||
import time
|
import time
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
Loading…
Reference in New Issue