[gemini] get the param visited order during runtime (#2108)

pull/2113/head
Jiarui Fang 2022-12-09 16:13:03 +08:00 committed by GitHub
parent 61f31c3cf0
commit 70a8556946
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 48 additions and 2 deletions

View File

@@ -1,3 +1,4 @@
from .param_runtime_order import ParamRuntimeOrder # isort:skip
from .memory_stats import MemStats # isort:skip
from .memory_monitor import AsyncMemoryMonitor, SyncCudaMemoryMonitor # isort:skip
from .memstats_collector import MemStatsCollector # isort:skip
@@ -6,5 +7,5 @@ from .static_memstats_collector import StaticMemStatsCollector # isort:skip
__all__ = [
'AsyncMemoryMonitor', 'SyncCudaMemoryMonitor', 'MemStatsCollector', 'ChunkMemStatsCollector',
'StaticMemStatsCollector', 'MemStats'
'StaticMemStatsCollector', 'MemStats', 'ParamRuntimeOrder'
]

View File

@@ -1,5 +1,7 @@
from typing import Any, Dict, List
from colossalai.gemini.memory_tracer import ParamRuntimeOrder
class MemStats(object):
@@ -19,6 +21,8 @@ class MemStats(object):
self._non_model_data_cuda_list = []
self._non_model_data_cpu_list = []
self._param_runtime_order = ParamRuntimeOrder()
def append_overall_data(self, device_type: str, val: float):
if device_type == 'cuda':
self._overall_cuda_list.append(val)
@@ -112,3 +116,5 @@ class MemStats(object):
self._non_model_data_cpu_list = []
self._non_model_data_cuda_list = []
self._param_runtime_order.clear()

View File

@@ -0,0 +1,25 @@
import torch
class ParamRuntimeOrder(object):
    """Record the order in which parameters are visited during runtime.

    Op hooks call ``append`` every time a parameter participates in an
    operation, so the log may contain duplicates. ``generate`` replays the
    log deduplicated, preserving first-visit order.
    """

    def __init__(self) -> None:
        # Raw visit log; a parameter appears once per op that used it.
        self.param_visited_order = []

    def append(self, param: torch.nn.Parameter):
        """Record one visit of ``param``."""
        self.param_visited_order.append(param)

    def generate(self):
        """Yield each visited parameter exactly once, in first-visit order.

        Returns:
            A generator over the deduplicated visit log.
        """
        # dict.fromkeys dedupes while preserving insertion order, replacing
        # the manual visited-set bookkeeping with a single C-level pass.
        yield from dict.fromkeys(self.param_visited_order)

    def clear(self):
        """Drop all recorded visits."""
        self.param_visited_order = []

View File

@@ -1,6 +1,6 @@
import torch.nn
from colossalai.gemini.memory_tracer import MemStats
from colossalai.gemini.memory_tracer import MemStats, ParamRuntimeOrder
from colossalai.gemini.ophooks.runtime_mem_tracer_hook import GradMemStats, GradMemTracerHook, ParamMemTracerHook
from colossalai.nn.parallel.data_parallel import _cast_float
from colossalai.tensor.param_op_hook import ColoParamOpHookManager
@@ -35,6 +35,9 @@ class RuntimeMemTracer():
self._cast_buffers_to_cuda_dtype()
def parameters_in_runtime_order(self):
    """Return a generator over parameters in the order they were first visited at runtime.

    Delegates to ``ParamRuntimeOrder.generate()``, which yields each recorded
    parameter once, deduplicated in first-visit order.
    """
    return self._memstats._param_runtime_order.generate()
def memstats(self):
    """Return the ``MemStats`` object held by this tracer."""
    return self._memstats

View File

@@ -99,6 +99,10 @@ class ParamMemTracerHook(ColoParamOpHook):
self.sample_model_data(params)
self.mem_monitor.start()
# register the order of visited.
for p in params:
self._memstats._param_runtime_order.append(p)
def post_op(self, params):
    """Hook run after an op finishes: release ``params`` via ``self._free_cuda_params``."""
    self._free_cuda_params(params)

View File

@@ -38,6 +38,13 @@ def test_runtime_mem_tracer():
print("cuda_non_model_data_list", len(cuda_non_model_data_list))
print(non_model_data_list)
cnt1 = 0
for p in runtime_mem_tracer.parameters_in_runtime_order():
cnt1 += 1
cnt2 = 0
for p in model.parameters():
cnt2 += 1
assert cnt2 == cnt1, f'visited param number {cnt1} vs real param number {cnt2}'
del model