From 05bb28aacf71cde64bfc94e5ad7555d5607da77f Mon Sep 17 00:00:00 2001
From: Jiarui Fang
Date: Tue, 13 Dec 2022 12:50:24 +0800
Subject: [PATCH] [Gemini] mapping of preop timestep and param (#2124)

---
 .../gemini/memory_tracer/memory_stats.py      | 47 ++++++++++++++++++-
 .../gemini/ophooks/runtime_mem_tracer_hook.py |  5 +-
 .../test_gemini/update/test_gemini_use_rmt.py |  3 +-
 3 files changed, 49 insertions(+), 6 deletions(-)

diff --git a/colossalai/gemini/memory_tracer/memory_stats.py b/colossalai/gemini/memory_tracer/memory_stats.py
index a66829863..5338fb50a 100644
--- a/colossalai/gemini/memory_tracer/memory_stats.py
+++ b/colossalai/gemini/memory_tracer/memory_stats.py
@@ -1,4 +1,6 @@
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Optional
+
+import torch
 
 from colossalai.gemini.memory_tracer import OrderedParamGenerator
 
@@ -10,6 +12,12 @@ class MemStats(object):
         Store the non model data statistics used for Gemini and ZeroOptimizer.
         """
         # p -> list of non_model data volumn visied in order.
+
+        # preop step -> list of params used at that step
+        self._step_param_dict = dict()
+        # param -> list of preop steps at which it is used
+        self._param_step_dict = dict()
+
         self.param_non_model_data_map: Dict(Any, List[int]) = {}
 
         self._model_data_cuda_list = []
         self._overall_cuda_list = []
@@ -23,6 +31,8 @@ class MemStats(object):
 
         self._param_runtime_order = OrderedParamGenerator()
 
+        self._preop_step = 0
+
     def param_order(self):
         if self._param_runtime_order.is_empty():
             raise RuntimeError
@@ -113,6 +123,38 @@ class MemStats(object):
         else:
             raise TypeError
 
+    def increase_preop_step(self, param_list: List[torch.nn.Parameter]):
+        """
+        Advance the preop time step by one. The params in param_list are the ones
+        used between the current and the next time step.
+
+        Args:
+            param_list (List[torch.nn.Parameter]): a list of torch parameters.
+        """
+        for p in param_list:
+            if p not in self._param_step_dict:
+                self._param_step_dict[p] = [self._preop_step]
+            else:
+                self._param_step_dict[p].append(self._preop_step)
+            self._param_runtime_order.append(p)
+        self._step_param_dict[self._preop_step] = param_list
+        self._preop_step += 1
+
+    def param_used_timestep(self, param: torch.nn.Parameter) -> Optional[List[int]]:
+        """
+        Get the preop time steps at which the param is used.
+
+        Args:
+            param (torch.nn.Parameter): a torch parameter
+
+        Returns:
+            Optional[List[int]]: the preop time steps using the param, or None if it was never recorded.
+        """
+        if param not in self._param_step_dict:
+            return None
+        else:
+            return self._param_step_dict[param]
+
     def clear(self):
         self._model_data_cuda_list = []
         self._overall_cuda_list = []
@@ -124,3 +166,6 @@ class MemStats(object):
         self._non_model_data_cuda_list = []
 
         self._param_runtime_order.clear()
+        self._step_param_dict.clear()
+        self._param_step_dict.clear()
+        self._preop_step = 0
diff --git a/colossalai/gemini/ophooks/runtime_mem_tracer_hook.py b/colossalai/gemini/ophooks/runtime_mem_tracer_hook.py
index faba1e22a..a5e47000b 100644
--- a/colossalai/gemini/ophooks/runtime_mem_tracer_hook.py
+++ b/colossalai/gemini/ophooks/runtime_mem_tracer_hook.py
@@ -98,10 +98,7 @@ class ParamMemTracerHook(ColoParamOpHook):
         self._allocate_params_on_cuda(params)
         self.sample_model_data(params)
         self.mem_monitor.start()
-
-        # register the order of visited.
-        for p in params:
-            self._memstats._param_runtime_order.append(p)
+        self._memstats.increase_preop_step(params)
 
     def post_op(self, params):
         self._free_cuda_params(params)
diff --git a/tests/test_gemini/update/test_gemini_use_rmt.py b/tests/test_gemini/update/test_gemini_use_rmt.py
index 5a8f066ac..3e3247e39 100644
--- a/tests/test_gemini/update/test_gemini_use_rmt.py
+++ b/tests/test_gemini/update/test_gemini_use_rmt.py
@@ -45,7 +45,8 @@ def run_gemini_use_rmt(placement_policy, keep_gather, model_name: str, use_grad_
     run_fwd_bwd(runtime_mem_tracer, input_ids, label, criterion, runtime_mem_tracer)
     memstats = runtime_mem_tracer.memstats()
     runtime_tracer_non_model_data = runtime_mem_tracer._memstats._non_model_data_cuda_list
-    print('runtime tracer non model data points: ', len(runtime_tracer_non_model_data))
+    print('runtime tracer: ', runtime_tracer_non_model_data)
+    print([memstats.param_used_timestep(p) for p in model.parameters()])
     model = GeminiDDP(model, device='cuda', placement_policy=placement_policy, search_range_mb=1, memstats=memstats)
     zero_optim = GeminiAdamOptimizer(model, lr=1e-3, initial_scale=1)
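
Usage sketch (illustrative, not taken from the commit): the snippet below exercises only the two MemStats methods added above, increase_preop_step and param_used_timestep, to show the preop-step/param mapping they build. The toy parameters and the hand-written call sequence are made up for illustration, the import path simply mirrors the file path in the diff, and constructing MemStats with no arguments is an assumption based on the attribute initialization visible in the hunks.

import torch

from colossalai.gemini.memory_tracer.memory_stats import MemStats

# Two toy parameters standing in for model weights.
w1 = torch.nn.Parameter(torch.zeros(4))
w2 = torch.nn.Parameter(torch.zeros(4))

memstats = MemStats()

# Each call mimics one pre-op hook invocation: it records which params are
# about to be used, then advances the preop time step by one.
memstats.increase_preop_step([w1])        # step 0 uses w1
memstats.increase_preop_step([w1, w2])    # step 1 uses w1 and w2
memstats.increase_preop_step([w2])        # step 2 uses w2

print(memstats.param_used_timestep(w1))   # [0, 1]
print(memstats.param_used_timestep(w2))   # [1, 2]

In the tracer itself these calls are made once per ParamMemTracerHook.pre_op, as shown in the hook diff above, so the recorded steps follow the runtime visiting order of the parameters.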