From 8fac837679b93e3105e585b710a43521ada6b2a2 Mon Sep 17 00:00:00 2001
From: Jiarui Fang
Date: Tue, 13 Dec 2022 15:44:07 +0800
Subject: [PATCH] [Gemini] update non model data calculation method (#2126)

---
 .../gemini/memory_tracer/memory_stats.py      | 28 ++++++++++++++----
 .../gemini/ophooks/runtime_mem_tracer_hook.py | 29 +++++++++++++++----
 .../test_gemini/update/test_gemini_use_rmt.py |  2 ++
 3 files changed, 48 insertions(+), 11 deletions(-)

diff --git a/colossalai/gemini/memory_tracer/memory_stats.py b/colossalai/gemini/memory_tracer/memory_stats.py
index 5338fb50a..bc215ccb9 100644
--- a/colossalai/gemini/memory_tracer/memory_stats.py
+++ b/colossalai/gemini/memory_tracer/memory_stats.py
@@ -11,13 +11,19 @@ class MemStats(object):
         """
         Store the non model data statistics used for Gemini and ZeroOptimizer.
         """
-        # p -> list of non_model data volumn visied in order.
-
-        # (preop_moment, List[param])
+        # (preop_step, List[param])
         self._step_param_dict = dict()
+        # (param, List[preop_step])
         self._param_step_dict = dict()
+        # (preop_step, non_model_data)
+        self._step_nmd_dict = dict()
+        self._param_runtime_order = OrderedParamGenerator()
+
+        self._preop_step = 0
 
-        # (param, List[preop_moment])
+        self._prev_overall_cuda = -1
+        self._prev_md_cuda = -1
+        # old version
         self.param_non_model_data_map: Dict(Any, List[int]) = {}
 
         self._model_data_cuda_list = []
@@ -29,9 +35,15 @@
         self._non_model_data_cuda_list = []
         self._non_model_data_cpu_list = []
 
-        self._param_runtime_order = OrderedParamGenerator()
+    def record_max_cuda_non_model_data(self):
+        if self._prev_overall_cuda != -1 and self._prev_md_cuda != -1:
+            self._step_nmd_dict[self._preop_step] = self._prev_overall_cuda - self._prev_md_cuda
 
-        self._preop_step = 0
+    def record_max_cuda_model_data(self, val):
+        self._prev_md_cuda = val
+
+    def record_max_cuda_overall_data(self, val):
+        self._prev_overall_cuda = val
 
     def param_order(self):
         if self._param_runtime_order.is_empty():
@@ -168,4 +180,8 @@
         self._param_runtime_order.clear()
         self._step_param_dict.clear()
         self._param_step_dict.clear()
+        self._step_nmd_dict.clear()
         self._preop_step = 0
+
+        self._prev_overall_cuda = -1
+        self._prev_md_cuda = -1
diff --git a/colossalai/gemini/ophooks/runtime_mem_tracer_hook.py b/colossalai/gemini/ophooks/runtime_mem_tracer_hook.py
index a5e47000b..1ff259762 100644
--- a/colossalai/gemini/ophooks/runtime_mem_tracer_hook.py
+++ b/colossalai/gemini/ophooks/runtime_mem_tracer_hook.py
@@ -64,7 +64,16 @@
                 raise NotImplementedError("Only free cuda memory")
             free_storage(p.data)
 
-    def _allocate_params_on_cuda(self, params):
+    def _allocate_params_on_cuda(self, params: List[torch.nn.Parameter]):
+        """
+        Move the data of the given params to CUDA.
+
+        Args:
+            params (List[torch.nn.Parameter]): the target params
+
+        Raises:
+            NotImplementedError: raised when a param has a CPU grad
+        """
         for p in params:
             cur_dev = p.data.device.type
             if cur_dev == "cpu":
@@ -78,6 +87,9 @@
                 alloc_storage(p.data)
 
     def sample_model_data(self, params):
+        """
+        Sample the CUDA model data volume used by the given params.
+        """
         data_volume = self._grad_stats.unreleased_grad_volume
         for p in params:
             cur_model_data_volume = p.data.numel() * p.data.element_size()
@@ -89,14 +101,21 @@
                 self._grad_stats.unreleased_grad_volume += cur_model_data_volume
                 self._grad_stats.unreleased_grad_flag[p] = True
         self._memstats.append_model_data('cuda', data_volume)
+        # record the max cuda model data used for this Op
+        self._memstats.record_max_cuda_model_data(data_volume)
 
     def pre_op(self, params):
+        # get the max overall cuda data of the previous period.
+        max_cuda_vol_of_period = self.mem_monitor.finish()
+        # record the max cuda overall data for the previous Op.
+        self._memstats.record_max_cuda_overall_data(max_cuda_vol_of_period)
+        self._memstats.record_max_cuda_non_model_data()
+        max_cuda_model_data_val = self._memstats.last_model_data('cuda')
+        if max_cuda_model_data_val is not None:
+            self._memstats.append_non_model_data('cuda', max_cuda_vol_of_period - max_cuda_model_data_val)
         self._allocate_params_on_cuda(params)
         self.sample_model_data(params)
+
         self.mem_monitor.start()
         self._memstats.increase_preop_step(params)
 
diff --git a/tests/test_gemini/update/test_gemini_use_rmt.py b/tests/test_gemini/update/test_gemini_use_rmt.py
index 926b61ef4..518c22fdb 100644
--- a/tests/test_gemini/update/test_gemini_use_rmt.py
+++ b/tests/test_gemini/update/test_gemini_use_rmt.py
@@ -46,6 +46,8 @@ def run_gemini_use_rmt(placement_policy, keep_gather, model_name: str, use_grad_
     memstats = runtime_mem_tracer.memstats()
     runtime_tracer_non_model_data = runtime_mem_tracer._memstats._non_model_data_cuda_list
     print('runtime tracer non model data points: ', len(runtime_tracer_non_model_data))
+    print('runtime tracer: ', runtime_tracer_non_model_data)
+    print([memstats.param_used_timestep(p) for p in model.parameters()])
 
     world_size = torch.distributed.get_world_size()
     config_dict, _ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
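
For reviewers who want to sanity-check the new calculation, below is a minimal standalone sketch of how the three record_* calls cooperate across pre-op steps. MiniMemStats and the byte values are hypothetical stand-ins for MemStats and the mem_monitor readings; only the call order mirrors ParamMemTracerHook.pre_op, where the non model data of a step is derived as the max overall CUDA usage of the period minus the model data sampled at the previous step.

# Minimal standalone sketch of the bookkeeping this patch adds to MemStats.
# MiniMemStats and the byte values below are hypothetical stand-ins; only the
# call order mirrors ParamMemTracerHook.pre_op.

class MiniMemStats:

    def __init__(self) -> None:
        # (preop_step, non_model_data), as in the new _step_nmd_dict
        self._step_nmd_dict = dict()
        self._preop_step = 0
        self._prev_overall_cuda = -1    # -1 means "no sample recorded yet"
        self._prev_md_cuda = -1

    def record_max_cuda_overall_data(self, val: int) -> None:
        self._prev_overall_cuda = val

    def record_max_cuda_model_data(self, val: int) -> None:
        self._prev_md_cuda = val

    def record_max_cuda_non_model_data(self) -> None:
        # non model data = max overall cuda - max model data, but only once
        # both samples exist, i.e. from the second pre-op step onwards
        if self._prev_overall_cuda != -1 and self._prev_md_cuda != -1:
            self._step_nmd_dict[self._preop_step] = self._prev_overall_cuda - self._prev_md_cuda

    def increase_preop_step(self) -> None:
        self._preop_step += 1


# Hypothetical (overall, model data) peaks in bytes for three pre-op steps.
stats = MiniMemStats()
for overall, model in [(100, 80), (230, 150), (410, 300)]:
    stats.record_max_cuda_overall_data(overall)    # <- mem_monitor.finish()
    stats.record_max_cuda_non_model_data()         # uses model data of prev step
    stats.record_max_cuda_model_data(model)        # <- sample_model_data(params)
    stats.increase_preop_step()

print(stats._step_nmd_dict)    # {1: 150, 2: 260}; step 0 has no prior sample

Because both samples are initialized to -1, nothing is recorded at the first pre-op step; the non model data series starts one step late, which is exactly the guard in record_max_cuda_non_model_data.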