diff --git a/colossalai/gemini/memory_tracer/__init__.py b/colossalai/gemini/memory_tracer/__init__.py
index d12461353..8bbf1678e 100644
--- a/colossalai/gemini/memory_tracer/__init__.py
+++ b/colossalai/gemini/memory_tracer/__init__.py
@@ -3,8 +3,9 @@ from .memstats_collector import MemStatsCollector    # isort:skip
 from .model_data_memtracer import GLOBAL_MODEL_DATA_TRACER    # isort:skip
 from .chunk_memstats_collector import ChunkMemStatsCollector    # isort:skip
 from .static_memstats_collector import StaticMemStatsCollector    # isort:skip
+from .module_tracer_wrapper import MemtracerWrapper    # isort:skip
 
 __all__ = [
     'AsyncMemoryMonitor', 'SyncCudaMemoryMonitor', 'MemStatsCollector', 'ChunkMemStatsCollector',
-    'StaticMemStatsCollector', 'GLOBAL_MODEL_DATA_TRACER'
+    'StaticMemStatsCollector', 'GLOBAL_MODEL_DATA_TRACER', 'MemtracerWrapper'
 ]
diff --git a/colossalai/gemini/memory_tracer/memory_monitor.py b/colossalai/gemini/memory_tracer/memory_monitor.py
index 05d03d278..f8d99dbce 100644
--- a/colossalai/gemini/memory_tracer/memory_monitor.py
+++ b/colossalai/gemini/memory_tracer/memory_monitor.py
@@ -1,142 +1,147 @@
-from abc import abstractmethod
-from concurrent.futures import ThreadPoolExecutor
-from time import sleep, time
-import json
-
-import torch
-
-from colossalai.utils import colo_device_memory_used
-from colossalai.utils import get_current_device
-
-
-class MemoryMonitor:
-    """Base class for all types of memory monitor.
-    All monitors should have a list called `time_stamps` and a list called `mem_stats`.
-    """
-
-    def __init__(self):
-        self.time_stamps = []
-        self.mem_stats = []
-
-    def __len__(self):
-        return len(self.mem_stats)
-
-    @abstractmethod
-    def start(self):
-        pass
-
-    @abstractmethod
-    def finish(self):
-        pass
-
-    def state_dict(self):
-        return {
-            "time_stamps": self.time_stamps,
-            "mem_stats": self.mem_stats,
-        }
-
-    def save(self, filename):
-        with open(filename, "w") as f:
-            json.dump(self.state_dict(), f)
-
-    def clear(self):
-        self.mem_stats.clear()
-        self.time_stamps.clear()
-
-
-class AsyncMemoryMonitor(MemoryMonitor):
-    """
-    An Async Memory Monitor runing during computing. Sampling memory usage of the current GPU
-    at interval of `1/(10**power)` sec.
-
-    The idea comes from Runtime Memory Tracer of PatrickStar
-    `PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management`_
-
-    Usage::
-
-        async_mem_monitor = AsyncMemoryMonitor()
-        input = torch.randn(2, 20).cuda()
-        OP1 = torch.nn.Linear(20, 30).cuda()
-        OP2 = torch.nn.Linear(30, 40).cuda()
-
-        async_mem_monitor.start()
-        output = OP1(input)
-        async_mem_monitor.finish()
-        async_mem_monitor.start()
-        output = OP2(output)
-        async_mem_monitor.finish()
-        async_mem_monitor.save('log.pkl')
-
-    Args:
-        power (int, optional): the power of time interva. Defaults to 10.
-
-    .. _PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management:
-        https://arxiv.org/abs/2108.05818
-    """
-
-    def __init__(self, power: int = 10):
-        super().__init__()
-        self.keep_measuring = False
-
-        current_device = get_current_device()
-
-        def _set_cuda_device():
-            torch.cuda.set_device(current_device)
-
-        self.executor = ThreadPoolExecutor(max_workers=1, initializer=_set_cuda_device)
-        self.monitor_thread = None
-        self.interval = 1 / (10**power)
-
-    def set_interval(self, power: int):
-        self.clear()
-        self.interval = 1 / (10**power)
-
-    def is_measuring(self):
-        return self.keep_measuring
-
-    def start(self):
-        self.keep_measuring = True
-        self.monitor_thread = self.executor.submit(self._measure_usage)
-
-    def finish(self):
-        if self.keep_measuring is False:
-            return 0
-
-        self.keep_measuring = False
-        max_usage = self.monitor_thread.result()
-
-        self.monitor_thread = None
-        self.time_stamps.append(time())
-        self.mem_stats.append(max_usage)
-        return max_usage
-
-    def _measure_usage(self):
-        max_usage = 0
-        while self.keep_measuring:
-            max_usage = max(
-                max_usage,
-                colo_device_memory_used(get_current_device()),
-            )
-            sleep(self.interval)
-        return max_usage
-
-
-class SyncCudaMemoryMonitor(MemoryMonitor):
-    """
-    A synchronized cuda memory monitor.
-    It only record the maximum allocated cuda memory from start point to finish point.
-    """
-
-    def __init__(self, power: int = 10):
-        super().__init__()
-
-    def start(self):
-        torch.cuda.synchronize()
-        torch.cuda.reset_peak_memory_stats()
-
-    def finish(self):
-        torch.cuda.synchronize()
-        self.time_stamps.append(time())
-        max_usage = torch.cuda.max_memory_allocated()
-        self.mem_stats.append(max_usage)
-        return max_usage
+import json
+from abc import abstractmethod
+from concurrent.futures import ThreadPoolExecutor
+from time import sleep, time
+
+import torch
+
+from colossalai.utils import colo_device_memory_used, get_current_device
+
+
+class MemoryMonitor:
+    """Base class for all types of memory monitor.
+    All monitors should have a list called `time_stamps` and a list called `mem_stats`.
+    """
+
+    def __init__(self):
+        self.time_stamps = []
+        self.mem_stats = []
+
+    def __len__(self):
+        return len(self.mem_stats)
+
+    @abstractmethod
+    def start(self):
+        pass
+
+    @abstractmethod
+    def finish(self):
+        pass
+
+    def state_dict(self):
+        return {
+            "time_stamps": self.time_stamps,
+            "mem_stats": self.mem_stats,
+        }
+
+    def save(self, filename):
+        with open(filename, "w") as f:
+            json.dump(self.state_dict(), f)
+
+    def clear(self):
+        self.mem_stats.clear()
+        self.time_stamps.clear()
+
+
+class AsyncMemoryMonitor(MemoryMonitor):
+    """
+    An asynchronous memory monitor that runs during computation. It samples the memory usage of
+    the current GPU at an interval of `1/(10**power)` sec.
+
+    The idea comes from the Runtime Memory Tracer of PatrickStar
+    `PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management`_
+
+    Usage::
+
+        async_mem_monitor = AsyncMemoryMonitor()
+        input = torch.randn(2, 20).cuda()
+        OP1 = torch.nn.Linear(20, 30).cuda()
+        OP2 = torch.nn.Linear(30, 40).cuda()
+
+        async_mem_monitor.start()
+        output = OP1(input)
+        async_mem_monitor.finish()
+        async_mem_monitor.start()
+        output = OP2(output)
+        async_mem_monitor.finish()
+        async_mem_monitor.save('log.pkl')
+
+    Args:
+        power (int, optional): the power used to compute the sampling interval. Defaults to 10.
+
+    .. _PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management:
+        https://arxiv.org/abs/2108.05818
+    """
+
+    def __init__(self, power: int = 10):
+        super().__init__()
+        self.keep_measuring = False
+
+        current_device = get_current_device()
+
+        def _set_cuda_device():
+            torch.cuda.set_device(current_device)
+
+        self.executor = ThreadPoolExecutor(max_workers=1, initializer=_set_cuda_device)
+        self.monitor_thread = None
+        self.interval = 1 / (10**power)
+
+    def set_interval(self, power: int):
+        self.clear()
+        self.interval = 1 / (10**power)
+
+    def is_measuring(self):
+        return self.keep_measuring
+
+    def start(self):
+        self.keep_measuring = True
+        self.monitor_thread = self.executor.submit(self._measure_usage)
+
+    def finish(self):
+        if self.keep_measuring is False:
+            return 0
+
+        self.keep_measuring = False
+        max_usage = self.monitor_thread.result()
+
+        self.monitor_thread = None
+        self.time_stamps.append(time())
+        self.mem_stats.append(max_usage)
+        return max_usage
+
+    def _measure_usage(self):
+        max_usage = 0
+        while self.keep_measuring:
+            max_usage = max(
+                max_usage,
+                colo_device_memory_used(get_current_device()),
+            )
+            sleep(self.interval)
+        return max_usage
+
+
+class SyncCudaMemoryMonitor(MemoryMonitor):
+    """
+    A synchronized CUDA memory monitor.
+    It only records the maximum CUDA memory allocated between the start point and the finish point.
+    """
+
+    def __init__(self, power: int = 10):
+        super().__init__()
+
+    def start(self):
+        torch.cuda.synchronize()
+        torch.cuda.reset_peak_memory_stats()
+
+    def finish(self) -> int:
+        """
+        Return the max GPU memory used since the latest `start()`.
+
+        Returns:
+            int: the maximum GPU memory allocated, in bytes
+        """
+        torch.cuda.synchronize()
+        self.time_stamps.append(time())
+        max_usage = torch.cuda.max_memory_allocated()
+        self.mem_stats.append(max_usage)
+        return max_usage
diff --git a/colossalai/gemini/memory_tracer/module_tracer_wrapper.py b/colossalai/gemini/memory_tracer/module_tracer_wrapper.py
new file mode 100644
index 000000000..9967df627
--- /dev/null
+++ b/colossalai/gemini/memory_tracer/module_tracer_wrapper.py
@@ -0,0 +1,36 @@
+from colossalai.gemini.ophooks import register_ophooks_recursively
+from colossalai.gemini.ophooks.mem_trace_hook import MemTracerOpHook
+
+__all__ = ['MemtracerWrapper']
+
+
+class _Wrapper():
+
+    def __init__(self, model, ophook_list):
+        self._ophook_list = ophook_list
+        self._model = model
+
+    def __call__(self, *args, **kwargs):
+        return self._model(*args, **kwargs)
+
+    def forward(self, *args, **kwargs):
+        return self._model.forward(*args, **kwargs)
+
+    def backward(self, loss):
+        loss.backward()
+        for ophook in self._ophook_list:
+            ophook.post_iter()
+
+    def save_results(self, filename):
+        for ophook in self._ophook_list:
+            ophook.save_results(filename)
+
+    def show_mem_stats(self):
+        self._ophook_list[0].show_mem_stats()
+
+
+def MemtracerWrapper(model):
+    ophook_list = [MemTracerOpHook()]
+    register_ophooks_recursively(model, ophook_list)
+    engine = _Wrapper(model, ophook_list)
+    return engine
diff --git a/colossalai/gemini/ophooks/mem_trace_hook.py b/colossalai/gemini/ophooks/mem_trace_hook.py
new file mode 100644
index 000000000..efb9b5bfa
--- /dev/null
+++ b/colossalai/gemini/ophooks/mem_trace_hook.py
@@ -0,0 +1,86 @@
+import torch
+
+from colossalai.gemini.memory_tracer import SyncCudaMemoryMonitor
+from colossalai.gemini.ophooks import BaseOpHook
+
+
+class MemTracerOpHook(BaseOpHook):
+
+    def __init__(self):
+        super().__init__()
+        self.mem_monitor = SyncCudaMemoryMonitor()
+        self._cur_non_model_data_vol = 0
+        self._non_model_data_list = []
+        self._cur_model_data_vol = 0
+
+    def _move_module_to_dev(self, module, dev: str) -> int:
+        """_move_module_to_dev
+        Move the parameters (and gradients) of a module to the target device.
+        Args:
+            module (torch.nn.Module): a PyTorch module
+            dev (str): the target device, e.g. 'cuda' or 'cpu'
+        Returns:
+            int: the volume of data moved to the target device, in bytes
+        """
+        assert isinstance(dev, str), f"device should be a str, not {type(dev)}"
+        comm_volume = 0
+        for p in module.parameters():
+            if p.data.device.type != dev:
+                p.data = p.data.to(dev)
+                comm_volume += p.data.numel() * p.data.element_size()
+            if p.grad is not None:
+                if p.grad.device.type != dev:
+                    p.grad = p.grad.to(dev)
+                    comm_volume += p.grad.numel() * p.grad.element_size()
+
+        if dev == 'cuda':
+            self._cur_model_data_vol = comm_volume
+
+        return comm_volume
+
+    def pre_fwd_exec(self, module: torch.nn.Module, *args):
+        if module.training:
+            cuda_volume = self.mem_monitor.finish()
+            comm_volume = self._move_module_to_dev(module, 'cuda')
+            self.mem_monitor.start()
+            # print(f'FWD PRE {module.__class__.__name__} cuda used {(cuda_volume) / 1e6} MB')
+
+    def post_fwd_exec(self, module: torch.nn.Module, *args):
+        if module.training:
+            cuda_volume = self.mem_monitor.finish()
+            comm_volume = self._move_module_to_dev(module, 'cpu')
+            # print(f'FWD POST {module.__class__.__name__} cuda used {(cuda_volume) / 1e6} MB, non-model data used {(cuda_volume - comm_volume) / 1e6} MB')
+
+    def pre_bwd_exec(self, module: torch.nn.Module, input, output):
+        assert isinstance(module, torch.nn.Module)
+        if module.training:
+            cuda_volume = self.mem_monitor.finish()
+            self._move_module_to_dev(module, 'cuda')
+            self.mem_monitor.start()
+            # print(f'BWD PRE {module.__class__.__name__}')
+
+    def post_bwd_exec(self, module: torch.nn.Module, input):
+        # The bwd op generates grads, so comm_volume covers both grad and param data moved off cuda.
+        assert isinstance(module, torch.nn.Module)
+        if module.training:
+            cuda_volume = self.mem_monitor.finish()
+            comm_volume = self._move_module_to_dev(module, 'cpu')
+            # print(f'BWD POST {module.__class__.__name__} {cuda_volume / 1e6} MB, non-model data used {(cuda_volume - comm_volume) / 1e6} MB')
+
+    def pre_iter(self):
+        pass
+
+    def post_iter(self):
+        self.mem_monitor.finish()
+        # print(f'post_iter')
+
+    def save_results(self, filename):
+        self.mem_monitor.save(filename)
+
+    def show_mem_stats(self):
+        start_timestamp = min(self.mem_monitor.time_stamps)
+        self.mem_monitor.time_stamps = [elem - start_timestamp for elem in self.mem_monitor.time_stamps]
+        min_mem_used = min(self.mem_monitor.mem_stats)
+        self.mem_monitor.mem_stats = [elem - min_mem_used for elem in self.mem_monitor.mem_stats]
+        print(self.mem_monitor.time_stamps)
+        print(self.mem_monitor.mem_stats)
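
A minimal usage sketch of the new MemtracerWrapper (not part of the patch itself): the toy model, tensor shapes, and output file name are illustrative assumptions, based only on the API added in this diff.

import torch

from colossalai.gemini.memory_tracer import MemtracerWrapper

# Hypothetical toy model; the registered ophooks move each submodule between CPU and CUDA
# around its forward/backward and sample peak CUDA memory per op with SyncCudaMemoryMonitor.
model = torch.nn.Sequential(torch.nn.Linear(20, 32), torch.nn.ReLU(), torch.nn.Linear(32, 1))
engine = MemtracerWrapper(model)

data = torch.randn(8, 20, device='cuda')
loss = engine(data).sum()                # forward pass through the wrapped model
engine.backward(loss)                    # backward pass; also calls post_iter() on every ophook

engine.show_mem_stats()                  # prints time_stamps and mem_stats normalized to their minimums
engine.save_results('memstats.json')     # dumps the monitor's time_stamps and mem_stats as JSON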