mirror of https://github.com/hpcaitech/ColossalAI
[gemini] get the param visited order during runtime (#2108)
parent 61f31c3cf0
commit 70a8556946
@ -1,3 +1,4 @@
+from .param_runtime_order import ParamRuntimeOrder # isort:skip
 from .memory_stats import MemStats # isort:skip
 from .memory_monitor import AsyncMemoryMonitor, SyncCudaMemoryMonitor # isort:skip
 from .memstats_collector import MemStatsCollector # isort:skip
@ -6,5 +7,5 @@ from .static_memstats_collector import StaticMemStatsCollector # isort:skip
 
 __all__ = [
     'AsyncMemoryMonitor', 'SyncCudaMemoryMonitor', 'MemStatsCollector', 'ChunkMemStatsCollector',
-    'StaticMemStatsCollector', 'MemStats'
+    'StaticMemStatsCollector', 'MemStats', 'ParamRuntimeOrder'
 ]
@ -1,5 +1,7 @@
 from typing import Any, Dict, List
 
+from colossalai.gemini.memory_tracer import ParamRuntimeOrder
+
 
 class MemStats(object):
 
@ -19,6 +21,8 @@ class MemStats(object):
         self._non_model_data_cuda_list = []
         self._non_model_data_cpu_list = []
 
+        self._param_runtime_order = ParamRuntimeOrder()
+
     def append_overall_data(self, device_type: str, val: float):
         if device_type == 'cuda':
             self._overall_cuda_list.append(val)
@ -112,3 +116,5 @@ class MemStats(object):
 
         self._non_model_data_cpu_list = []
         self._non_model_data_cuda_list = []
+
+        self._param_runtime_order.clear()
@ -0,0 +1,25 @@
+import torch
+
+
+class ParamRuntimeOrder(object):
+    """ParamRuntimeOrder
+
+    Contains the order in which parameters are visited during runtime.
+    """
+
+    def __init__(self) -> None:
+        self.param_visited_order = []
+
+    def append(self, param: torch.nn.Parameter):
+        self.param_visited_order.append(param)
+
+    def generate(self):
+        visited_set = set()
+        for p in self.param_visited_order:
+            if p not in visited_set:
+                yield p
+                visited_set.add(p)
+        del visited_set
+
+    def clear(self):
+        self.param_visited_order = []
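A minimal usage sketch of the new ParamRuntimeOrder class (illustration only, not part of the diff): the nn.Linear layer below is a stand-in, and the example relies only on the append/generate/clear API added above. Repeated visits to the same parameter are deduplicated, and the first-visit order is preserved.

    import torch.nn as nn

    from colossalai.gemini.memory_tracer import ParamRuntimeOrder

    order = ParamRuntimeOrder()
    layer = nn.Linear(4, 4)

    # simulate a traced forward pass that touches the weight twice
    order.append(layer.weight)
    order.append(layer.bias)
    order.append(layer.weight)

    # generate() yields each parameter once, in the order it was first visited
    visited = list(order.generate())
    assert len(visited) == 2
    assert visited[0] is layer.weight and visited[1] is layer.bias

    # clear() resets the recorded order
    order.clear()
    assert list(order.generate()) == []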
@ -1,6 +1,6 @@
 import torch.nn
 
-from colossalai.gemini.memory_tracer import MemStats
+from colossalai.gemini.memory_tracer import MemStats, ParamRuntimeOrder
 from colossalai.gemini.ophooks.runtime_mem_tracer_hook import GradMemStats, GradMemTracerHook, ParamMemTracerHook
 from colossalai.nn.parallel.data_parallel import _cast_float
 from colossalai.tensor.param_op_hook import ColoParamOpHookManager
@ -35,6 +35,9 @@ class RuntimeMemTracer():
 
         self._cast_buffers_to_cuda_dtype()
 
+    def parameters_in_runtime_order(self):
+        return self._memstats._param_runtime_order.generate()
+
     def memstats(self):
         return self._memstats
 
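A hedged sketch of how a caller might consume the new accessors once a traced iteration has run. The tracer construction and the traced forward/backward are not part of this diff, so runtime_mem_tracer and model below are assumed to already exist, as in the test at the end of this commit.

    # runtime_mem_tracer: a RuntimeMemTracer that has already run a traced iteration
    # model: the wrapped torch.nn.Module

    # parameters_in_runtime_order() forwards to the ParamRuntimeOrder owned by MemStats
    for p in runtime_mem_tracer.parameters_in_runtime_order():
        print(type(p).__name__, tuple(p.shape))

    # the same data is reachable through the MemStats object itself
    stats = runtime_mem_tracer.memstats()
    ordered = list(stats._param_runtime_order.generate())

    # every visited parameter should be one of the module's registered parameters
    registered = {id(q) for q in model.parameters()}
    assert all(id(p) in registered for p in ordered)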
@ -99,6 +99,10 @@ class ParamMemTracerHook(ColoParamOpHook):
         self.sample_model_data(params)
         self.mem_monitor.start()
 
+        # record the order in which params are visited at runtime
+        for p in params:
+            self._memstats._param_runtime_order.append(p)
+
     def post_op(self, params):
         self._free_cuda_params(params)
 
@ -38,6 +38,13 @@ def test_runtime_mem_tracer():
     print("cuda_non_model_data_list", len(cuda_non_model_data_list))
     print(non_model_data_list)
 
+    cnt1 = 0
+    for p in runtime_mem_tracer.parameters_in_runtime_order():
+        cnt1 += 1
+    cnt2 = 0
+    for p in model.parameters():
+        cnt2 += 1
+    assert cnt2 == cnt1, f'visited param number {cnt1} vs real param number {cnt2}'
     del model
 
 