mirror of https://github.com/hpcaitech/ColossalAI
[gemini] get the param visited order during runtime (#2108)
parent
61f31c3cf0
commit
70a8556946
|
@ -1,3 +1,4 @@
|
||||||
|
from .param_runtime_order import ParamRuntimeOrder # isort:skip
|
||||||
from .memory_stats import MemStats # isort:skip
|
from .memory_stats import MemStats # isort:skip
|
||||||
from .memory_monitor import AsyncMemoryMonitor, SyncCudaMemoryMonitor # isort:skip
|
from .memory_monitor import AsyncMemoryMonitor, SyncCudaMemoryMonitor # isort:skip
|
||||||
from .memstats_collector import MemStatsCollector # isort:skip
|
from .memstats_collector import MemStatsCollector # isort:skip
|
||||||
|
@ -6,5 +7,5 @@ from .static_memstats_collector import StaticMemStatsCollector # isort:skip
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'AsyncMemoryMonitor', 'SyncCudaMemoryMonitor', 'MemStatsCollector', 'ChunkMemStatsCollector',
|
'AsyncMemoryMonitor', 'SyncCudaMemoryMonitor', 'MemStatsCollector', 'ChunkMemStatsCollector',
|
||||||
'StaticMemStatsCollector', 'MemStats'
|
'StaticMemStatsCollector', 'MemStats', 'ParamRuntimeOrder'
|
||||||
]
|
]
|
||||||
|
|
|
@ -1,5 +1,7 @@
|
||||||
from typing import Any, Dict, List
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
|
from colossalai.gemini.memory_tracer import ParamRuntimeOrder
|
||||||
|
|
||||||
|
|
||||||
class MemStats(object):
|
class MemStats(object):
|
||||||
|
|
||||||
|
@ -19,6 +21,8 @@ class MemStats(object):
|
||||||
self._non_model_data_cuda_list = []
|
self._non_model_data_cuda_list = []
|
||||||
self._non_model_data_cpu_list = []
|
self._non_model_data_cpu_list = []
|
||||||
|
|
||||||
|
self._param_runtime_order = ParamRuntimeOrder()
|
||||||
|
|
||||||
def append_overall_data(self, device_type: str, val: float):
|
def append_overall_data(self, device_type: str, val: float):
|
||||||
if device_type == 'cuda':
|
if device_type == 'cuda':
|
||||||
self._overall_cuda_list.append(val)
|
self._overall_cuda_list.append(val)
|
||||||
|
@ -112,3 +116,5 @@ class MemStats(object):
|
||||||
|
|
||||||
self._non_model_data_cpu_list = []
|
self._non_model_data_cpu_list = []
|
||||||
self._non_model_data_cuda_list = []
|
self._non_model_data_cuda_list = []
|
||||||
|
|
||||||
|
self._param_runtime_order.clear()
|
||||||
|
|
|
@ -0,0 +1,25 @@
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
class ParamRuntimeOrder(object):
    """ParamRuntimeOrder

    Records the order in which parameters are visited during runtime.

    ``append`` is called (possibly repeatedly for the same parameter) as
    operators fire; ``generate`` then yields each parameter exactly once,
    in first-visit order.
    """

    def __init__(self) -> None:
        # Raw visit log; may contain duplicates when the same parameter
        # is used by several operators.
        self.param_visited_order: list = []

    def append(self, param: torch.nn.Parameter) -> None:
        """Record one runtime visit of ``param``."""
        self.param_visited_order.append(param)

    def generate(self):
        """Yield each visited parameter once, in first-visit order.

        Uses ``dict.fromkeys`` for a single-pass, order-preserving
        deduplication (dicts keep insertion order since Python 3.7),
        replacing the manual seen-set loop and its redundant ``del``.
        """
        yield from dict.fromkeys(self.param_visited_order)

    def clear(self) -> None:
        """Drop the recorded visit order."""
        self.param_visited_order = []
|
|
@ -1,6 +1,6 @@
|
||||||
import torch.nn
|
import torch.nn
|
||||||
|
|
||||||
from colossalai.gemini.memory_tracer import MemStats
|
from colossalai.gemini.memory_tracer import MemStats, ParamRuntimeOrder
|
||||||
from colossalai.gemini.ophooks.runtime_mem_tracer_hook import GradMemStats, GradMemTracerHook, ParamMemTracerHook
|
from colossalai.gemini.ophooks.runtime_mem_tracer_hook import GradMemStats, GradMemTracerHook, ParamMemTracerHook
|
||||||
from colossalai.nn.parallel.data_parallel import _cast_float
|
from colossalai.nn.parallel.data_parallel import _cast_float
|
||||||
from colossalai.tensor.param_op_hook import ColoParamOpHookManager
|
from colossalai.tensor.param_op_hook import ColoParamOpHookManager
|
||||||
|
@ -35,6 +35,9 @@ class RuntimeMemTracer():
|
||||||
|
|
||||||
self._cast_buffers_to_cuda_dtype()
|
self._cast_buffers_to_cuda_dtype()
|
||||||
|
|
||||||
|
def parameters_in_runtime_order(self):
|
||||||
|
return self._memstats._param_runtime_order.generate()
|
||||||
|
|
||||||
def memstats(self):
|
def memstats(self):
|
||||||
return self._memstats
|
return self._memstats
|
||||||
|
|
||||||
|
|
|
@ -99,6 +99,10 @@ class ParamMemTracerHook(ColoParamOpHook):
|
||||||
self.sample_model_data(params)
|
self.sample_model_data(params)
|
||||||
self.mem_monitor.start()
|
self.mem_monitor.start()
|
||||||
|
|
||||||
|
# register the order of visited.
|
||||||
|
for p in params:
|
||||||
|
self._memstats._param_runtime_order.append(p)
|
||||||
|
|
||||||
def post_op(self, params):
|
def post_op(self, params):
|
||||||
self._free_cuda_params(params)
|
self._free_cuda_params(params)
|
||||||
|
|
||||||
|
|
|
@ -38,6 +38,13 @@ def test_runtime_mem_tracer():
|
||||||
print("cuda_non_model_data_list", len(cuda_non_model_data_list))
|
print("cuda_non_model_data_list", len(cuda_non_model_data_list))
|
||||||
print(non_model_data_list)
|
print(non_model_data_list)
|
||||||
|
|
||||||
|
cnt1 = 0
|
||||||
|
for p in runtime_mem_tracer.parameters_in_runtime_order():
|
||||||
|
cnt1 += 1
|
||||||
|
cnt2 = 0
|
||||||
|
for p in model.parameters():
|
||||||
|
cnt2 += 1
|
||||||
|
assert cnt2 == cnt1, f'visited param number {cnt1} vs real param number {cnt2}'
|
||||||
del model
|
del model
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue