From 10e282642684675aee9bbd3e6847b9ac53a4e51c Mon Sep 17 00:00:00 2001
From: Jiarui Fang
Date: Wed, 9 Mar 2022 16:31:25 +0800
Subject: [PATCH] move async memory tracer to an individual directory (#345)

---
 .../engine/ophooks/_memtracer_ophook.py       |  94 +--------------
 colossalai/utils/memory_tracer/__init__.py    |   3 +
 .../utils/memory_tracer/async_memtracer.py    | 107 +++++++++++++++++
 .../memory_tracer/test_async_memtracer.py     |  16 +++
 4 files changed, 130 insertions(+), 90 deletions(-)
 create mode 100644 colossalai/utils/memory_tracer/__init__.py
 create mode 100644 colossalai/utils/memory_tracer/async_memtracer.py
 create mode 100644 colossalai/utils/memory_tracer/test_async_memtracer.py

diff --git a/colossalai/engine/ophooks/_memtracer_ophook.py b/colossalai/engine/ophooks/_memtracer_ophook.py
index 3ba8e536d..e77c41055 100644
--- a/colossalai/engine/ophooks/_memtracer_ophook.py
+++ b/colossalai/engine/ophooks/_memtracer_ophook.py
@@ -1,101 +1,15 @@
 from colossalai.context.parallel_mode import ParallelMode
 import torch
-from . import BaseOpHook
-from concurrent.futures import ThreadPoolExecutor
+from colossalai.engine.ophooks import BaseOpHook
 from colossalai.registry import OPHOOKS
 from colossalai.logging import get_dist_logger
-from time import sleep, time
-import pickle
-from typing import Optional
 from colossalai.core import global_context as gpc
+
+from colossalai.utils.memory_tracer import AsyncMemoryMonitor
+
 import math
 
 
-def get_cuda_memory_used(device: Optional[torch.device]) -> int:
-    """Get the free memory info of device.
-    Notice that for CPU, this function will return 1/N of the total free memory,
-    where N is the world size.
-
-    :param device: device id
-    :type device: torch.device
-    :return: current memory usage, sized by MB
-    :rtype: int
-    """
-    ret: int = torch.cuda.memory_allocated(device)
-    # get the peak memory to report correct data, so reset the counter for the next call
-    if hasattr(torch.cuda, "reset_peak_memory_stats"):    # pytorch 1.4+
-        torch.cuda.reset_peak_memory_stats(device)
-    return ret
-
-
-class AsyncMemoryMonitor:
-    """
-    An Async Mem Monitor runing during computing. Sampling GPU memory usage of the current GPU
-    at interval of 1/(10**power) sec.
-
-    :param power: the power of time interval, defaults to 10
-    :type power: int
-    """
-
-    def __init__(self, power: int = 10):
-
-        self.keep_measuring = False
-        self.executor = ThreadPoolExecutor(max_workers=1)
-        self.monitor_thread = None
-        self.interval = 1 / (10**power)
-        self.time_stamps = []
-        self.mem_stats = []
-
-    def __len__(self):
-        return len(self.mem_stats)
-
-    def set_interval(self, power: int):
-        self.clear()
-        self.interval = 1 / (10**power)
-
-    def is_measuring(self):
-        return self.keep_measuring
-
-    def start(self):
-        self.keep_measuring = True
-        self.monitor_thread = self.executor.submit(self._measure_usage)
-
-    def finish(self):
-        if self.keep_measuring is False:
-            return 0
-        self.keep_measuring = False
-        max_usage = self.monitor_thread.result()
-        self.monitor_thread = None
-        self.time_stamps.append(time())
-        self.mem_stats.append(max_usage)
-        return max_usage
-
-    def _measure_usage(self):
-        max_usage = 0
-        dev = torch.device(f"cuda:{torch.cuda.current_device()}")
-        while self.keep_measuring:
-            max_usage = max(
-                max_usage,
-                get_cuda_memory_used(dev),
-            )
-            sleep(self.interval)
-        return max_usage
-
-    def state_dict(self):
-        return {
-            "time_stamps": self.time_stamps,
-            "mem_stats": self.mem_stats,
-        }
-
-    def save(self, filename):
-        with open(filename, "wb") as f:
-            pickle.dump(self.state_dict(), f)
-
-    def clear(self):
-        self.mem_stats.clear()
-        self.time_stamps.clear()
-
-
 @OPHOOKS.register_module
 class MemTracerOpHook(BaseOpHook):
     """
diff --git a/colossalai/utils/memory_tracer/__init__.py b/colossalai/utils/memory_tracer/__init__.py
new file mode 100644
index 000000000..f40430d38
--- /dev/null
+++ b/colossalai/utils/memory_tracer/__init__.py
@@ -0,0 +1,3 @@
+from .async_memtracer import AsyncMemoryMonitor
+
+__all__ = ['AsyncMemoryMonitor']
diff --git a/colossalai/utils/memory_tracer/async_memtracer.py b/colossalai/utils/memory_tracer/async_memtracer.py
new file mode 100644
index 000000000..8f968acfb
--- /dev/null
+++ b/colossalai/utils/memory_tracer/async_memtracer.py
@@ -0,0 +1,107 @@
+from concurrent.futures import ThreadPoolExecutor
+from time import sleep, time
+import pickle
+
+from colossalai.utils import get_current_device
+import torch
+
+
+def _get_cuda_memory_used(device: torch.device) -> int:
+    """
+    Get the memory currently allocated on the given CUDA device.
+    :param device: device id
+    :type device: torch.device
+    :return: current memory usage, in bytes
+    :rtype: int
+    """
+
+    assert device.type == 'cuda'
+
+    ret: int = torch.cuda.memory_allocated(device)
+    # get the peak memory to report correct data, so reset the counter for the next call
+    if hasattr(torch.cuda, "reset_peak_memory_stats"):    # pytorch 1.4+
+        torch.cuda.reset_peak_memory_stats(device)
+    return ret
+
+
+class AsyncMemoryMonitor:
+    """
+    An asynchronous memory monitor that runs during computation, sampling the memory usage
+    of the current GPU at an interval of 1/(10**power) sec.
+
+    :param power: the power of the sampling time interval, defaults to 10
+    :type power: int
+
+    Usage:
+
+    ```python
+    async_mem_monitor = AsyncMemoryMonitor()
+    input = torch.randn(2, 20).cuda()
+    OP1 = torch.nn.Linear(20, 30).cuda()
+    OP2 = torch.nn.Linear(30, 40).cuda()
+
+    async_mem_monitor.start()
+    output = OP1(input)
+    async_mem_monitor.finish()
+    async_mem_monitor.start()
+    output = OP2(output)
+    async_mem_monitor.finish()
+    async_mem_monitor.save('log.pkl')
+    ```
+    """
+
+    def __init__(self, power: int = 10):
+        self.keep_measuring = False
+        self.executor = ThreadPoolExecutor(max_workers=1)
+        self.monitor_thread = None
+        self.interval = 1 / (10**power)
+        self.time_stamps = []
+        self.mem_stats = []
+
+    def __len__(self):
+        return len(self.mem_stats)
+
+    def set_interval(self, power: int):
+        self.clear()
+        self.interval = 1 / (10**power)
+
+    def is_measuring(self):
+        return self.keep_measuring
+
+    def start(self):
+        self.keep_measuring = True
+        self.monitor_thread = self.executor.submit(self._measure_usage)
+
+    def finish(self):
+        if self.keep_measuring is False:
+            return 0
+        self.keep_measuring = False
+        max_usage = self.monitor_thread.result()
+        self.monitor_thread = None
+        self.time_stamps.append(time())
+        self.mem_stats.append(max_usage)
+        return max_usage
+
+    def _measure_usage(self):
+        max_usage = 0
+        while self.keep_measuring:
+            max_usage = max(
+                max_usage,
+                _get_cuda_memory_used(torch.device(f'cuda:{get_current_device()}')),
+            )
+            sleep(self.interval)
+        return max_usage
+
+    def state_dict(self):
+        return {
+            "time_stamps": self.time_stamps,
+            "mem_stats": self.mem_stats,
+        }
+
+    def save(self, filename):
+        with open(filename, "wb") as f:
+            pickle.dump(self.state_dict(), f)
+
+    def clear(self):
+        self.mem_stats.clear()
+        self.time_stamps.clear()
diff --git a/colossalai/utils/memory_tracer/test_async_memtracer.py b/colossalai/utils/memory_tracer/test_async_memtracer.py
new file mode 100644
index 000000000..06c4052bd
--- /dev/null
+++ b/colossalai/utils/memory_tracer/test_async_memtracer.py
@@ -0,0 +1,16 @@
+from colossalai.utils.memory_tracer import AsyncMemoryMonitor
+import torch
+
+if __name__ == '__main__':
+    async_mem_monitor = AsyncMemoryMonitor()
+    input = torch.randn(2, 20).cuda()
+    OP1 = torch.nn.Linear(20, 30).cuda()
+    OP2 = torch.nn.Linear(30, 40).cuda()
+
+    async_mem_monitor.start()
+    output = OP1(input)
+    async_mem_monitor.finish()
+    async_mem_monitor.start()
+    output = OP2(output)
+    async_mem_monitor.finish()
+    async_mem_monitor.save('log.pkl')
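For reference, the log written by `save()` can be inspected offline. Below is a minimal sketch, not part of the patch, assuming a `log.pkl` produced by the usage example above; it relies only on the `state_dict()` layout shown in the diff (parallel lists `time_stamps` and `mem_stats`, with memory in bytes):

```python
# Minimal sketch: read back a log produced by AsyncMemoryMonitor.save('log.pkl').
# The pickled dict mirrors state_dict(): "time_stamps" holds the time() values
# recorded at each finish() call, "mem_stats" the peak bytes sampled in the
# corresponding start()/finish() window.
import pickle

with open('log.pkl', 'rb') as f:
    state = pickle.load(f)

for ts, mem in zip(state['time_stamps'], state['mem_stats']):
    print(f'{ts:.3f}  {mem / (1024 ** 2):.2f} MB')
```

Each entry thus pairs one monitored operator with its peak GPU memory usage, which is what `MemTracerOpHook` consumes now that the monitor lives in `colossalai.utils.memory_tracer`.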