mirror of https://github.com/hpcaitech/ColossalAI
[Gemini] update the non model data record method in runtime memory tracer (#2128)
parent deee317b0f
commit 2938edf446

@@ -133,9 +133,9 @@ class GeminiManager:
         if self._mem_stats_collector:
             self._mem_stats_collector.sample_overall_data()
 
-    def sample_model_data(self):
+    def record_model_data_volume(self):
         if self._mem_stats_collector:
-            self._mem_stats_collector.sample_model_data()
+            self._mem_stats_collector.record_model_data_volume()
 
     @property
     def chunk_manager(self):
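The hunk above is a pure rename at the call-site level. A minimal, hypothetical migration sketch for downstream code (the `gemini_manager` argument is illustrative; only the two method names come from the diff):

```python
def flush_model_data_stats(gemini_manager):
    # Before #2128 the model-data snapshot was taken with:
    #     gemini_manager.sample_model_data()
    # After #2128 the same hook point calls the renamed method:
    gemini_manager.record_model_data_volume()
```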
@@ -15,7 +15,7 @@ class ChunkMemStatsCollector(MemStatsCollector):
         self._chunk_manager = chunk_manager
 
     # override
-    def sample_model_data(self) -> None:
+    def record_model_data_volume(self) -> None:
         """Sampling model data statistics.
         """
         if self._start_flag and not self.use_outside_memstats:
@@ -15,7 +15,7 @@ class MemStats(object):
         self._step_param_dict = dict()
         # (param, List[preop_step])
         self._param_step_dict = dict()
-        # (preop_step, non_model_data)
+        # (preop_step, non_model_data) non model data used during preop_step ~ (preop_step+1)
         self._step_nmd_dict = dict()
         self._param_runtime_order = OrderedParamGenerator()
 
@@ -23,9 +23,8 @@ class MemStats(object):
 
         self._prev_overall_cuda = -1
         self._prev_md_cuda = -1
-        # old version
-        self.param_non_model_data_map: Dict(Any, List[int]) = {}
 
+        # old version
         self._model_data_cuda_list = []
         self._model_data_cpu_list = []
 
@@ -35,9 +34,12 @@ class MemStats(object):
         self._non_model_data_cuda_list = []
         self._non_model_data_cpu_list = []
 
-    def record_max_cuda_non_model_data(self):
+    def calc_max_cuda_non_model_data(self):
         if self._prev_overall_cuda != -1 and self._prev_md_cuda != -1:
-            self._step_nmd_dict[self._preop_step] = self._prev_overall_cuda - self._prev_md_cuda
+            max_cuda_non_model_data = self._prev_overall_cuda - self._prev_md_cuda
+            self._step_nmd_dict[self._preop_step - 1] = max_cuda_non_model_data
+            # compatibility of the old version.
+            self._non_model_data_cuda_list.append(max_cuda_non_model_data)
 
     def record_max_cuda_model_data(self, val):
         self._prev_md_cuda = val
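The renamed `calc_max_cuda_non_model_data` derives the non-model-data peak of the period that just ended as overall CUDA peak minus model-data peak, and files it under the previous pre-op step, since the finished period belongs to the op that already ran. A worked example with made-up byte counts (standalone Python mirroring the variables in the hunk above):

```python
# Illustrative numbers only: peaks recorded for the period that just ended.
prev_overall_cuda = 6 * 1024**2    # peak overall CUDA usage, 6 MiB
prev_md_cuda = 4 * 1024**2         # peak CUDA model data, 4 MiB
preop_step = 3                     # counter has already advanced past the finished period

step_nmd_dict = {}
non_model_data_cuda_list = []
if prev_overall_cuda != -1 and prev_md_cuda != -1:
    max_cuda_non_model_data = prev_overall_cuda - prev_md_cuda    # 2 MiB of non-model data
    # the finished period is attributed to the previous step (here: step 2)
    step_nmd_dict[preop_step - 1] = max_cuda_non_model_data
    # kept for compatibility with the old list-based API
    non_model_data_cuda_list.append(max_cuda_non_model_data)
```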
@@ -45,12 +47,45 @@ class MemStats(object):
     def record_max_cuda_overall_data(self, val):
         self._prev_overall_cuda = val
 
+    def increase_preop_step(self, param_list: List[torch.nn.Parameter]):
+        """
+        the time step is increased. param list is used between current and the next
+        time step.
+
+        Args:
+            param_list (List[torch.nn.Parameter]): a list of torch paramters.
+        """
+        for p in param_list:
+            if p not in self._param_step_dict:
+                self._param_step_dict[p] = [self._preop_step]
+            else:
+                self._param_step_dict[p].append(self._preop_step)
+            self._param_runtime_order.append(p)
+        self._step_param_dict[self._preop_step] = param_list
+        self._preop_step += 1
+
+    def param_used_step(self, param: torch.nn.Parameter) -> Optional[List[int]]:
+        """param_used_step
+        get the timestep list using the param
+
+        Args:
+            param (torch.nn.Parameter): a torch param
+
+        Returns:
+            Optional[List[int]]: a list of int indicates the time step of preop hook.
+        """
+        if param not in self._param_step_dict:
+            return None
+        else:
+            return self._param_step_dict[param]
+
     def param_order(self):
         if self._param_runtime_order.is_empty():
             raise RuntimeError
         else:
             return self._param_runtime_order
 
+    ## APIs to be depracated
     def append_overall_data(self, device_type: str, val: float):
         if device_type == 'cuda':
             self._overall_cuda_list.append(val)
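`increase_preop_step` and `param_used_step` (moved up and renamed in this hunk) give a per-parameter view of the trace: every parameter seen by a pre-op hook gets the current step appended to its list before the counter advances. A hedged usage sketch; the import path is an assumption that may differ between ColossalAI versions, and the parameters are toy tensors:

```python
import torch

# Assumed import path; adjust to your ColossalAI version if it differs.
from colossalai.gemini.memory_tracer import MemStats

memstats = MemStats()
w = torch.nn.Parameter(torch.zeros(4))
b = torch.nn.Parameter(torch.zeros(4))

memstats.increase_preop_step([w])    # pre-op step 0 uses w
memstats.increase_preop_step([b])    # pre-op step 1 uses b
memstats.increase_preop_step([w])    # pre-op step 2 uses w again

assert memstats.param_used_step(w) == [0, 2]
assert memstats.param_used_step(b) == [1]
assert memstats.param_used_step(torch.nn.Parameter(torch.zeros(1))) is None
```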
@@ -135,38 +170,6 @@ class MemStats(object):
         else:
             raise TypeError
 
-    def increase_preop_step(self, param_list: List[torch.nn.Parameter]):
-        """
-        the time step is increased. param list is used between current and the next
-        time step.
-
-        Args:
-            param_list (List[torch.nn.Parameter]): a list of torch paramters.
-        """
-        for p in param_list:
-            if p not in self._param_step_dict:
-                self._param_step_dict[p] = [self._preop_step]
-            else:
-                self._param_step_dict[p].append(self._preop_step)
-            self._param_runtime_order.append(p)
-        self._step_param_dict[self._preop_step] = param_list
-        self._preop_step += 1
-
-    def param_used_timestep(self, param: torch.nn.Parameter) -> Optional[List[int]]:
-        """param_used_timestep
-        get the timestep list using the param
-
-        Args:
-            param (torch.nn.Parameter): a torch param
-
-        Returns:
-            Optional[List[int]]: a list of int indicates the time step of preop hook.
-        """
-        if param not in self._param_step_dict:
-            return None
-        else:
-            return self._param_step_dict[param]
-
     def clear(self):
         self._model_data_cuda_list = []
         self._overall_cuda_list = []
@@ -69,7 +69,7 @@ class MemStatsCollector:
         self._start_flag = False
         self._mem_monitor.finish()
 
-    def sample_model_data(self) -> None:
+    def record_model_data_volume(self) -> None:
         """Sampling model data statistics.
         """
         if self._start_flag and not self.use_outside_memstats:
@@ -82,7 +82,9 @@ class RuntimeMemTracer():
 
     def _post_backward(self):
         cuda_volume = self.param_op_hook.mem_monitor.finish()
-        self._memstats.append_non_model_data('cuda', cuda_volume - self._memstats.last_model_data('cuda'))
+        self._memstats.record_max_cuda_overall_data(cuda_volume)
+        # calc the last Op non model data
+        self._memstats.calc_max_cuda_non_model_data()
         self.grad_hook.remove_grad_hook()
         self._restore_params()
 
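With this change `_post_backward` no longer computes the non-model-data figure inline; it records the overall CUDA peak of the final period and delegates the subtraction to `MemStats`. A minimal sketch of that two-call pattern (the `mem_monitor` and `memstats` objects are placeholders standing in for the tracer's own members):

```python
def close_backward_trace(mem_monitor, memstats):
    # Peak CUDA usage of the period since the last pre-op hook.
    cuda_volume = mem_monitor.finish()
    memstats.record_max_cuda_overall_data(cuda_volume)
    # MemStats derives (overall peak - model data peak) for that last period
    # and files it under the previous pre-op step.
    memstats.calc_max_cuda_non_model_data()
```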
@@ -86,7 +86,7 @@ class ParamMemTracerHook(ColoParamOpHook):
         elif cur_dev == "cuda":
             alloc_storage(p.data)
 
-    def sample_model_data(self, params):
+    def record_model_data_volume(self, params):
         """
         get cuda model data used by params
         """
@@ -100,21 +100,19 @@ class ParamMemTracerHook(ColoParamOpHook):
                 if not self._grad_stats.unreleased_grad_flag[p]:
                     self._grad_stats.unreleased_grad_volume += cur_model_data_volume
                     self._grad_stats.unreleased_grad_flag[p] = True
-        self._memstats.append_model_data('cuda', data_volume)
         # record max non model data used for this Op
         self._memstats.record_max_cuda_model_data(data_volume)
 
     def pre_op(self, params):
-        max_cuda_used_pre_op = self.mem_monitor.finish()
-        # record max cuda overall data for prev OP.
-        self._memstats.record_max_cuda_overall_data(max_cuda_used_pre_op)
-        # record max cuda non model data for prev OP.
-        self._memstats.record_max_cuda_non_model_data()
-        max_cuda_model_data_val = self._memstats.last_model_data('cuda')
-        if max_cuda_model_data_val is not None:
-            self._memstats.append_non_model_data('cuda', max_cuda_vol_of_period - max_cuda_model_data_val)
+        # get overall cuda data.
+        max_cuda_vol_of_period = self.mem_monitor.finish()
+        # record max cuda overall data for prev Op.
+        self._memstats.record_max_cuda_overall_data(max_cuda_vol_of_period)
+        # record max cuda non model data for prev OP.
+        self._memstats.calc_max_cuda_non_model_data()
         self._allocate_params_on_cuda(params)
-        self.sample_model_data(params)
+        # record max cuda model data for current OP
+        self.record_model_data_volume(params)
 
         self.mem_monitor.start()
         self._memstats.increase_preop_step(params)
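Reading the rewritten `pre_op` top to bottom, each pre-op hook first closes the measurement period of the previous op and then opens one for the current op. A compact, illustrative summary of that cycle (only the `MemStats` method names are taken from the diff; the helper names are placeholders):

```python
def trace_pre_op(mem_monitor, memstats, params, cuda_model_bytes):
    # 1. Close the period that ran since the previous pre-op hook.
    overall_peak = mem_monitor.finish()
    memstats.record_max_cuda_overall_data(overall_peak)
    memstats.calc_max_cuda_non_model_data()    # non-model data of the previous op

    # 2. Open the period for the current op.
    memstats.record_max_cuda_model_data(cuda_model_bytes(params))
    mem_monitor.start()
    memstats.increase_preop_step(params)
```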
@@ -32,7 +32,7 @@ class GeminiZeROHook(ColoParamOpHook):
         self._gemini_manager.adjust_layout(chunks)
         for chunk in chunks:
             self._chunk_manager.access_chunk(chunk)
-        self._gemini_manager.sample_model_data()
+        self._gemini_manager.record_model_data_volume()
 
     def post_op(self, params):
         params = [p for p in params if not getattr(p, '_ddp_to_ignore', False)]
@@ -67,7 +67,7 @@ class ZeroHook(BaseOpHook):
 
         # record model data statistics
         if self._memstarts_collector:
-            self._memstarts_collector.sample_model_data()
+            self._memstarts_collector.record_model_data_volume()
 
     def pre_fwd_exec(self, module: torch.nn.Module, *args):
         self.adjust_module_data(module)
@@ -47,7 +47,13 @@ def run_gemini_use_rmt(placement_policy, keep_gather, model_name: str, use_grad_
     runtime_tracer_non_model_data = runtime_mem_tracer._memstats._non_model_data_cuda_list
     print('runtime tracer non model data points: ', len(runtime_tracer_non_model_data))
     print('runtime tracer: ', runtime_tracer_non_model_data)
-    print([memstats.param_used_timestep(p) for p in model.parameters()])
+    print([memstats.param_used_step(p) for p in model.parameters()])
+
+    if model_name == 'repeated_computed_layers':
+        for idx, p in enumerate(model.parameters()):
+            step_list = memstats.param_used_step(p)
+            if idx < 4:
+                assert len(step_list) == 4
 
     if model_name == 'repeated_computed_layers':
         for idx, p in enumerate(model.parameters()):