mirror of https://github.com/hpcaitech/ColossalAI
[Gemini] use MemStats in Runtime Memory tracer (#2088)
parent
33f4412102
commit
25abae6d7f
|
@ -3,8 +3,9 @@ from .memstats_collector import MemStatsCollector # isort:skip
|
|||
from .model_data_memtracer import GLOBAL_MODEL_DATA_TRACER # isort:skip
|
||||
from .chunk_memstats_collector import ChunkMemStatsCollector # isort:skip
|
||||
from .static_memstats_collector import StaticMemStatsCollector # isort:skip
|
||||
from .memory_stats import MemStats
|
||||
|
||||
__all__ = [
|
||||
'AsyncMemoryMonitor', 'SyncCudaMemoryMonitor', 'MemStatsCollector', 'ChunkMemStatsCollector',
|
||||
'StaticMemStatsCollector', 'GLOBAL_MODEL_DATA_TRACER'
|
||||
'StaticMemStatsCollector', 'GLOBAL_MODEL_DATA_TRACER', 'MemStats'
|
||||
]
|
||||
|
|
|
@ -36,6 +36,8 @@ class MemStats(object):
|
|||
raise TypeError
|
||||
|
||||
def append_non_model_data(self, device_type: str):
|
||||
if len(self._overall_cuda_list) == 0 or len(self._model_data_cuda_list) == 0:
|
||||
return
|
||||
if device_type == 'cuda':
|
||||
self._non_model_data_cuda_list.append(self._overall_cuda_list[-1] - self._model_data_cuda_list[-1])
|
||||
elif device_type == 'cpu':
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
import torch.nn
|
||||
|
||||
from colossalai.gemini.memory_tracer import MemStats
|
||||
from colossalai.gemini.memory_tracer.model_data_memtracer import GLOBAL_CUDA_MEM_INFO
|
||||
from colossalai.gemini.ophooks.runtime_mem_tracer_hook import GradMemTracerHook, ParamMemTracerHook
|
||||
from colossalai.nn.parallel.data_parallel import _cast_float
|
||||
|
@ -24,7 +25,8 @@ class RuntimeMemTracer():
|
|||
super().__init__()
|
||||
self.module = module
|
||||
self.dtype = dtype
|
||||
self.param_op_hook = ParamMemTracerHook()
|
||||
self._memstats = MemStats()
|
||||
self.param_op_hook = ParamMemTracerHook(self._memstats)
|
||||
self.grad_hook = GradMemTracerHook(module)
|
||||
self.cpu_param_data_dict = {}
|
||||
|
||||
|
@ -74,14 +76,17 @@ class RuntimeMemTracer():
|
|||
|
||||
def _post_backward(self):
|
||||
cuda_volume = self.param_op_hook.mem_monitor.finish()
|
||||
last_model_data = GLOBAL_CUDA_MEM_INFO.model_data_list[-1]
|
||||
GLOBAL_CUDA_MEM_INFO.non_model_data_list.append(cuda_volume - last_model_data)
|
||||
self._memstats.append_model_data('cuda', cuda_volume)
|
||||
self._memstats.append_non_model_data('cuda')
|
||||
# last_model_data = GLOBAL_CUDA_MEM_INFO.model_data_list[-1]
|
||||
# GLOBAL_CUDA_MEM_INFO.non_model_data_list.append(cuda_volume - last_model_data)
|
||||
self.grad_hook.remove_grad_hook()
|
||||
self._restore_params()
|
||||
|
||||
def _clear_cuda_mem_info(self):
|
||||
GLOBAL_CUDA_MEM_INFO.model_data_list.clear()
|
||||
GLOBAL_CUDA_MEM_INFO.non_model_data_list.clear()
|
||||
# GLOBAL_CUDA_MEM_INFO.model_data_list.clear()
|
||||
# GLOBAL_CUDA_MEM_INFO.non_model_data_list.clear()
|
||||
self._memstats.clear()
|
||||
GLOBAL_CUDA_MEM_INFO.unreleased_grad_flag.clear()
|
||||
GLOBAL_CUDA_MEM_INFO.unreleased_grad_volume = 0
|
||||
|
||||
|
|
|
@ -41,9 +41,10 @@ class GradMemTracerHook():
|
|||
|
||||
class ParamMemTracerHook(ColoParamOpHook):
|
||||
|
||||
def __init__(self) -> None:
|
||||
def __init__(self, memstats) -> None:
|
||||
super().__init__()
|
||||
self._training_phase = TrainingPhase.FORWARD
|
||||
self._memstats = memstats
|
||||
self.mem_monitor = SyncCudaMemoryMonitor()
|
||||
|
||||
def _free_cuda_params(self, params):
|
||||
|
@ -76,12 +77,14 @@ class ParamMemTracerHook(ColoParamOpHook):
|
|||
if not GLOBAL_CUDA_MEM_INFO.unreleased_grad_flag[p]:
|
||||
GLOBAL_CUDA_MEM_INFO.unreleased_grad_volume += cur_model_data_volume
|
||||
GLOBAL_CUDA_MEM_INFO.unreleased_grad_flag[p] = True
|
||||
GLOBAL_CUDA_MEM_INFO.model_data_list.append(data_volume)
|
||||
# GLOBAL_CUDA_MEM_INFO.model_data_list.append(data_volume)
|
||||
self._memstats.append_model_data('cuda', data_volume)
|
||||
|
||||
def pre_op(self, params):
|
||||
cuda_volume = self.mem_monitor.finish()
|
||||
if len(GLOBAL_CUDA_MEM_INFO.model_data_list):
|
||||
GLOBAL_CUDA_MEM_INFO.non_model_data_list.append(cuda_volume - GLOBAL_CUDA_MEM_INFO.model_data_list[-1])
|
||||
self._memstats.append_model_data('cuda', cuda_volume)
|
||||
# if len(GLOBAL_CUDA_MEM_INFO.model_data_list):
|
||||
# GLOBAL_CUDA_MEM_INFO.non_model_data_list.append(cuda_volume - GLOBAL_CUDA_MEM_INFO.model_data_list[-1])
|
||||
self._allocate_params_on_cuda(params)
|
||||
self.sample_model_data(params)
|
||||
self.mem_monitor.start()
|
||||
|
|
|
@ -3,7 +3,6 @@ from copy import deepcopy
|
|||
import numpy as np
|
||||
import torch
|
||||
|
||||
from colossalai.gemini.memory_tracer.model_data_memtracer import GLOBAL_CUDA_MEM_INFO
|
||||
from colossalai.gemini.memory_tracer.runtime_mem_tracer import RuntimeMemTracer
|
||||
from colossalai.utils.model.colo_init_context import ColoInitContext
|
||||
from tests.components_to_test import run_fwd_bwd
|
||||
|
@ -34,9 +33,10 @@ def test_runtime_mem_tracer():
|
|||
for p1, p2 in zip(model_bk.parameters(), model.parameters()):
|
||||
torch.allclose(p1.to(torch.half), p2)
|
||||
|
||||
cuda_non_model_data_list = np.array(GLOBAL_CUDA_MEM_INFO.non_model_data_list) / 1024**2
|
||||
non_model_data_list = runtime_mem_tracer._memstats.non_model_data_list('cuda')
|
||||
cuda_non_model_data_list = np.array(non_model_data_list) / 1024**2
|
||||
print("cuda_non_model_data_list", len(cuda_non_model_data_list))
|
||||
print(GLOBAL_CUDA_MEM_INFO.non_model_data_list)
|
||||
print(non_model_data_list)
|
||||
|
||||
del model
|
||||
|
||||
|
|
Loading…
Reference in New Issue