You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
ColossalAI/colossalai/gemini/memory_tracer/memory_monitor.py

143 lines
3.9 KiB

from abc import abstractmethod
from concurrent.futures import ThreadPoolExecutor
from time import sleep, time
import json
import torch
from colossalai.utils import colo_device_memory_used
from colossalai.utils import get_current_device
class MemoryMonitor:
    """Common base for memory monitors.

    Every monitor keeps two parallel lists: ``time_stamps`` (when each
    measurement was recorded) and ``mem_stats`` (the measured values).
    Subclasses supply ``start`` and ``finish``.
    """

    def __init__(self):
        self.time_stamps, self.mem_stats = [], []

    def __len__(self):
        # Number of measurements recorded so far.
        return len(self.mem_stats)

    # NOTE: this class does not derive from ``abc.ABC``, so the decorators below
    # are not enforced at instantiation time; the base methods simply return None.
    @abstractmethod
    def start(self):
        """Open a measurement window (implemented by subclasses)."""

    @abstractmethod
    def finish(self):
        """Close a measurement window (implemented by subclasses)."""

    def state_dict(self):
        """Return the recorded lists keyed by name (the live lists, not copies)."""
        return dict(time_stamps=self.time_stamps, mem_stats=self.mem_stats)

    def save(self, filename):
        """Write ``state_dict()`` to ``filename`` as JSON."""
        payload = self.state_dict()
        with open(filename, "w") as handle:
            json.dump(payload, handle)

    def clear(self):
        """Empty both lists in place (references held elsewhere see the change)."""
        for series in (self.mem_stats, self.time_stamps):
            series.clear()
class AsyncMemoryMonitor(MemoryMonitor):
    """
    An asynchronous memory monitor that runs alongside computation. A background
    thread samples the memory usage of the current device every `1/(10**power)` sec
    and keeps the maximum seen between `start()` and `finish()`.

    The idea comes from the Runtime Memory Tracer of PatrickStar
    `PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management`_

    Usage::

        async_mem_monitor = AsyncMemoryMonitor()
        input = torch.randn(2, 20).cuda()
        OP1 = torch.nn.Linear(20, 30).cuda()
        OP2 = torch.nn.Linear(30, 40).cuda()

        async_mem_monitor.start()
        output = OP1(input)
        async_mem_monitor.finish()
        async_mem_monitor.start()
        output = OP2(output)
        async_mem_monitor.finish()
        async_mem_monitor.save('log.json')

    Args:
        power (int, optional): exponent of the sampling interval; the sampler sleeps
            `1/(10**power)` seconds between samples. Defaults to 10.

    .. _PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management:
        https://arxiv.org/abs/2108.05818
    """

    def __init__(self, power: int = 10):
        super().__init__()
        # Flag polled by the sampling loop; clearing it asks the loop to stop.
        self.keep_measuring = False

        current_device = get_current_device()

        def _set_cuda_device():
            # Pin the worker thread to the device that was current when the monitor
            # was created (presumably because the CUDA current device is per-thread).
            torch.cuda.set_device(current_device)

        # A single worker, so at most one sampling loop executes at a time.
        self.executor = ThreadPoolExecutor(max_workers=1, initializer=_set_cuda_device)
        # Future of the in-flight sampling loop, or None when idle.
        self.monitor_thread = None
        self.interval = 1 / (10**power)

    def set_interval(self, power: int):
        """Set the sampling interval to `1/(10**power)` sec and discard recorded stats."""
        self.clear()
        self.interval = 1 / (10**power)

    def is_measuring(self):
        """Return True between `start()` and the matching `finish()`."""
        return self.keep_measuring

    def start(self):
        """Begin sampling on the background thread.

        Calling `start()` while already measuring is a no-op. Submitting a second
        loop would queue it behind the running one (single worker); it would only run
        after `finish()` cleared the flag, and `finish()` would then return that
        queued loop's result instead of the real peak.
        """
        if self.keep_measuring:
            return
        self.keep_measuring = True
        self.monitor_thread = self.executor.submit(self._measure_usage)

    def finish(self):
        """Stop sampling, record the peak, and return it.

        Returns:
            The maximum value reported by `colo_device_memory_used` during the window
            (also appended to `mem_stats`, with a `time()` stamp in `time_stamps`),
            or 0 if the monitor was not measuring.
        """
        if not self.keep_measuring:
            return 0
        self.keep_measuring = False
        try:
            # Blocks until the sampling loop observes the cleared flag and returns;
            # re-raises any exception raised inside the loop.
            max_usage = self.monitor_thread.result()
        finally:
            # Never leave a stale future behind, even if the loop failed.
            self.monitor_thread = None
        self.time_stamps.append(time())
        self.mem_stats.append(max_usage)
        return max_usage

    def _measure_usage(self):
        """Sampling loop executed on the worker thread; returns the peak sample.

        At least one sample is always taken, so a `finish()` that arrives before
        the worker is scheduled still yields a real reading rather than 0.
        """
        # The worker's device is fixed by the executor initializer, so look it up once.
        device = get_current_device()
        max_usage = colo_device_memory_used(device)
        while self.keep_measuring:
            sleep(self.interval)
            max_usage = max(max_usage, colo_device_memory_used(device))
        return max_usage
class SyncCudaMemoryMonitor(MemoryMonitor):
    """
    A synchronous CUDA memory monitor.

    It records only the peak allocated CUDA memory between ``start`` and ``finish``,
    as reported by ``torch.cuda.max_memory_allocated`` after synchronizing the device.

    Args:
        power (int, optional): not used by this monitor (presumably kept so the
            constructor mirrors ``AsyncMemoryMonitor``). Defaults to 10.
    """

    def __init__(self, power: int = 10):
        super().__init__()

    def start(self):
        """Wait for queued CUDA work, then reset the peak-allocation counter."""
        torch.cuda.synchronize()
        torch.cuda.reset_peak_memory_stats()

    def finish(self):
        """Wait for queued CUDA work, then record and return the peak allocation since ``start``."""
        torch.cuda.synchronize()
        stamp = time()
        peak = torch.cuda.max_memory_allocated()
        self.time_stamps.append(stamp)
        self.mem_stats.append(peak)
        return peak