mirror of https://github.com/hpcaitech/ColossalAI
[Gemini] independent runtime tracer (#1974)
parent
0da1d00399
commit
0529fcde06
|
@ -3,8 +3,9 @@ from .memstats_collector import MemStatsCollector # isort:skip
|
|||
from .model_data_memtracer import GLOBAL_MODEL_DATA_TRACER # isort:skip
|
||||
from .chunk_memstats_collector import ChunkMemStatsCollector # isort:skip
|
||||
from .static_memstats_collector import StaticMemStatsCollector # isort:skip
|
||||
from .module_tracer_wrapper import MemtracerWrapper # isort:skip
|
||||
|
||||
__all__ = [
|
||||
'AsyncMemoryMonitor', 'SyncCudaMemoryMonitor', 'MemStatsCollector', 'ChunkMemStatsCollector',
|
||||
'StaticMemStatsCollector', 'GLOBAL_MODEL_DATA_TRACER'
|
||||
'StaticMemStatsCollector', 'GLOBAL_MODEL_DATA_TRACER', 'MemtracerWrapper'
|
||||
]
|
||||
|
|
|
@ -1,12 +1,11 @@
|
|||
import json
|
||||
from abc import abstractmethod
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from time import sleep, time
|
||||
import json
|
||||
|
||||
import torch
|
||||
|
||||
from colossalai.utils import colo_device_memory_used
|
||||
from colossalai.utils import get_current_device
|
||||
from colossalai.utils import colo_device_memory_used, get_current_device
|
||||
|
||||
|
||||
class MemoryMonitor:
|
||||
|
@ -134,7 +133,13 @@ class SyncCudaMemoryMonitor(MemoryMonitor):
|
|||
torch.cuda.synchronize()
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
|
||||
def finish(self):
|
||||
def finish(self) -> int:
|
||||
"""
|
||||
return max gpu memory used since latest `start()`.
|
||||
|
||||
Returns:
|
||||
int: max GPU memory
|
||||
"""
|
||||
torch.cuda.synchronize()
|
||||
self.time_stamps.append(time())
|
||||
max_usage = torch.cuda.max_memory_allocated()
|
||||
|
|
|
@ -0,0 +1,36 @@
|
|||
from colossalai.gemini.ophooks import register_ophooks_recursively
|
||||
from colossalai.gemini.ophooks.mem_trace_hook import MemTracerOpHook
|
||||
|
||||
__all__ = ['MemtracerWrapper']
|
||||
|
||||
|
||||
class _Wrapper():
|
||||
|
||||
def __init__(self, model, ophook_list):
|
||||
self._ophook_list = ophook_list
|
||||
self._model = model
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return self._model(*args, **kwargs)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
return self._model.forward(*args, **kwargs)
|
||||
|
||||
def backward(self, loss):
|
||||
loss.backward()
|
||||
for ophook in self._ophook_list:
|
||||
ophook.post_iter()
|
||||
|
||||
def save_results(self, filename):
|
||||
for ophook in self._ophook_list:
|
||||
ophook.save_results(filename)
|
||||
|
||||
def show_mem_stats(self):
|
||||
self._ophook_list[0].show_mem_stats()
|
||||
|
||||
|
||||
def MemtracerWrapper(model):
|
||||
ophook_list = [MemTracerOpHook()]
|
||||
register_ophooks_recursively(model, ophook_list)
|
||||
engine = _Wrapper(model, ophook_list)
|
||||
return engine
|
|
@ -0,0 +1,86 @@
|
|||
import torch
|
||||
|
||||
from colossalai.gemini.memory_tracer import SyncCudaMemoryMonitor
|
||||
from colossalai.gemini.ophooks import BaseOpHook
|
||||
|
||||
|
||||
class MemTracerOpHook(BaseOpHook):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.mem_monitor = SyncCudaMemoryMonitor()
|
||||
self._cur_non_model_data_vol = 0
|
||||
self._non_model_data_list = []
|
||||
self._cur_model_data_vol = 0
|
||||
|
||||
def _move_module_to_dev(self, module, dev: str) -> int:
|
||||
"""_move_module_to_dev
|
||||
move module to cuda
|
||||
Args:
|
||||
module (torch.nn.Module): a PyTorch module
|
||||
dev (torch.device): the target device
|
||||
Returns:
|
||||
int: the data volume of this module on the cuda
|
||||
"""
|
||||
assert isinstance(dev, str), f"device should be a str not torch.device"
|
||||
comm_volume = 0
|
||||
for p in module.parameters():
|
||||
if p.data.device.type != dev:
|
||||
p.data = p.data.to(dev)
|
||||
comm_volume += p.data.numel() * p.data.element_size()
|
||||
if p.grad is not None:
|
||||
if p.grad.device.type != dev:
|
||||
p.grad = p.grad.to(dev)
|
||||
comm_volume += p.grad.numel() * p.grad.element_size()
|
||||
|
||||
if dev == 'cuda':
|
||||
self._cur_model_data_vol = comm_volume
|
||||
|
||||
return comm_volume
|
||||
|
||||
def pre_fwd_exec(self, module: torch.nn.Module, *args):
|
||||
if module.training:
|
||||
cuda_volume = self.mem_monitor.finish()
|
||||
comm_volume = self._move_module_to_dev(module, 'cuda')
|
||||
self.mem_monitor.start()
|
||||
# print(f'FWD PRE {module.__class__.__name__} cuda used {(cuda_volume) / 1e6} MB')
|
||||
|
||||
def post_fwd_exec(self, module: torch.nn.Module, *args):
|
||||
if module.training:
|
||||
cuda_volume = self.mem_monitor.finish()
|
||||
comm_volume = self._move_module_to_dev(module, 'cpu')
|
||||
# print(f'FWD POST {module.__class__.__name__} cuda used {(cuda_volume) / 1e6} MB, non-model data used {(cuda_volume - comm_volume) / 1e6} MB')
|
||||
|
||||
def pre_bwd_exec(self, module: torch.nn.Module, input, output):
|
||||
assert isinstance(module, torch.nn.Module)
|
||||
if module.training:
|
||||
cuda_volume = self.mem_monitor.finish()
|
||||
self._move_module_to_dev(module, 'cuda')
|
||||
self.mem_monitor.start()
|
||||
# print(f'BWD PRE {module.__class__.__name__}')
|
||||
|
||||
def post_bwd_exec(self, module: torch.nn.Module, input):
|
||||
# bwd Op will generate grad. comm_volume is grad + data volume on cuda.
|
||||
assert isinstance(module, torch.nn.Module)
|
||||
if module.training:
|
||||
cuda_volume = self.mem_monitor.finish()
|
||||
comm_volume = self._move_module_to_dev(module, 'cpu')
|
||||
# print(f'BWD POST {module.__class__.__name__} {cuda_volume / 1e6} MB, non-model data used {(cuda_volume - comm_volume) / 1e6} MB')
|
||||
|
||||
def pre_iter(self):
|
||||
pass
|
||||
|
||||
def post_iter(self):
|
||||
self.mem_monitor.finish()
|
||||
# print(f'post_iter')
|
||||
|
||||
def save_results(self, filename):
|
||||
self.mem_monitor.save(filename)
|
||||
|
||||
def show_mem_stats(self):
|
||||
start_timestamp = min(self.mem_monitor.time_stamps)
|
||||
self.mem_monitor.time_stamps = [elem - start_timestamp for elem in self.mem_monitor.time_stamps]
|
||||
min_mem_used = min(self.mem_monitor.mem_stats)
|
||||
self.mem_monitor.mem_stats = [elem - min_mem_used for elem in self.mem_monitor.mem_stats]
|
||||
print(self.mem_monitor.time_stamps)
|
||||
print(self.mem_monitor.mem_stats)
|
Loading…
Reference in New Issue