diff --git a/colossalai/gemini/memory_tracer/__init__.py b/colossalai/gemini/memory_tracer/__init__.py index c7b7efad7..12f6b7950 100644 --- a/colossalai/gemini/memory_tracer/__init__.py +++ b/colossalai/gemini/memory_tracer/__init__.py @@ -1,3 +1,4 @@ +from .param_runtime_order import ParamRuntimeOrder # isort:skip from .memory_stats import MemStats # isort:skip from .memory_monitor import AsyncMemoryMonitor, SyncCudaMemoryMonitor # isort:skip from .memstats_collector import MemStatsCollector # isort:skip @@ -6,5 +7,5 @@ from .static_memstats_collector import StaticMemStatsCollector # isort:skip __all__ = [ 'AsyncMemoryMonitor', 'SyncCudaMemoryMonitor', 'MemStatsCollector', 'ChunkMemStatsCollector', - 'StaticMemStatsCollector', 'MemStats' + 'StaticMemStatsCollector', 'MemStats', 'ParamRuntimeOrder' ] diff --git a/colossalai/gemini/memory_tracer/memory_stats.py b/colossalai/gemini/memory_tracer/memory_stats.py index 496ec7c18..4412a580e 100644 --- a/colossalai/gemini/memory_tracer/memory_stats.py +++ b/colossalai/gemini/memory_tracer/memory_stats.py @@ -1,5 +1,7 @@ from typing import Any, Dict, List +from colossalai.gemini.memory_tracer import ParamRuntimeOrder + class MemStats(object): @@ -19,6 +21,8 @@ class MemStats(object): self._non_model_data_cuda_list = [] self._non_model_data_cpu_list = [] + self._param_runtime_order = ParamRuntimeOrder() + def append_overall_data(self, device_type: str, val: float): if device_type == 'cuda': self._overall_cuda_list.append(val) @@ -112,3 +116,5 @@ class MemStats(object): self._non_model_data_cpu_list = [] self._non_model_data_cuda_list = [] + + self._param_runtime_order.clear() diff --git a/colossalai/gemini/memory_tracer/param_runtime_order.py b/colossalai/gemini/memory_tracer/param_runtime_order.py new file mode 100644 index 000000000..ceb13bc24 --- /dev/null +++ b/colossalai/gemini/memory_tracer/param_runtime_order.py @@ -0,0 +1,25 @@ +import torch + + +class ParamRuntimeOrder(object): + """ParamRuntimeOrder + + Contain the order of parameters visited during runtime. + """ + + def __init__(self) -> None: + self.param_visited_order = [] + + def append(self, param: torch.nn.Parameter): + self.param_visited_order.append(param) + + def generate(self): + visited_set = set() + for p in self.param_visited_order: + if p not in visited_set: + yield p + visited_set.add(p) + del visited_set + + def clear(self): + self.param_visited_order = [] diff --git a/colossalai/gemini/memory_tracer/runtime_mem_tracer.py b/colossalai/gemini/memory_tracer/runtime_mem_tracer.py index 1090cf92c..4eacb49d0 100644 --- a/colossalai/gemini/memory_tracer/runtime_mem_tracer.py +++ b/colossalai/gemini/memory_tracer/runtime_mem_tracer.py @@ -1,6 +1,6 @@ import torch.nn -from colossalai.gemini.memory_tracer import MemStats +from colossalai.gemini.memory_tracer import MemStats, ParamRuntimeOrder from colossalai.gemini.ophooks.runtime_mem_tracer_hook import GradMemStats, GradMemTracerHook, ParamMemTracerHook from colossalai.nn.parallel.data_parallel import _cast_float from colossalai.tensor.param_op_hook import ColoParamOpHookManager @@ -35,6 +35,9 @@ class RuntimeMemTracer(): self._cast_buffers_to_cuda_dtype() + def parameters_in_runtime_order(self): + return self._memstats._param_runtime_order.generate() + def memstats(self): return self._memstats diff --git a/colossalai/gemini/ophooks/runtime_mem_tracer_hook.py b/colossalai/gemini/ophooks/runtime_mem_tracer_hook.py index 6430a471e..faba1e22a 100644 --- a/colossalai/gemini/ophooks/runtime_mem_tracer_hook.py +++ b/colossalai/gemini/ophooks/runtime_mem_tracer_hook.py @@ -99,6 +99,10 @@ class ParamMemTracerHook(ColoParamOpHook): self.sample_model_data(params) self.mem_monitor.start() + # register the order of visited. + for p in params: + self._memstats._param_runtime_order.append(p) + def post_op(self, params): self._free_cuda_params(params) diff --git a/tests/test_gemini/test_runtime_mem_tracer.py b/tests/test_gemini/test_runtime_mem_tracer.py index 34c200e05..294868458 100644 --- a/tests/test_gemini/test_runtime_mem_tracer.py +++ b/tests/test_gemini/test_runtime_mem_tracer.py @@ -38,6 +38,13 @@ def test_runtime_mem_tracer(): print("cuda_non_model_data_list", len(cuda_non_model_data_list)) print(non_model_data_list) + cnt1 = 0 + for p in runtime_mem_tracer.parameters_in_runtime_order(): + cnt1 += 1 + cnt2 = 0 + for p in model.parameters(): + cnt2 += 1 + assert cnt2 == cnt1, f'visited param number {cnt1} vs real param number {cnt2}' del model