[Gemini] update non model data calculation method (#2126)

pull/2127/head^2
Jiarui Fang 2 years ago committed by GitHub
parent 6c4c6a0409
commit 8fac837679

@@ -11,13 +11,19 @@ class MemStats(object):
"""
Store the non model data statistics used for Gemini and ZeroOptimizer.
"""
# (preop_step, List[param])
self._step_param_dict = dict()
# (param, List[preop_step])
self._param_step_dict = dict()
# (preop_step, non_model_data)
self._step_nmd_dict = dict()
self._param_runtime_order = OrderedParamGenerator()
self._preop_step = 0
self._prev_overall_cuda = -1
self._prev_md_cuda = -1
# old version
self.param_non_model_data_map: Dict[Any, List[int]] = {}
self._model_data_cuda_list = []
@@ -29,9 +35,15 @@ class MemStats(object):
self._non_model_data_cuda_list = []
self._non_model_data_cpu_list = []
def record_max_cuda_non_model_data(self):
if self._prev_overall_cuda != -1 and self._prev_md_cuda != -1:
self._step_nmd_dict[self._preop_step] = self._prev_overall_cuda - self._prev_md_cuda
def record_max_cuda_model_data(self, val):
self._prev_md_cuda = val
def record_max_cuda_overall_data(self, val):
self._prev_overall_cuda = val
def param_order(self):
if self._param_runtime_order.is_empty():
@@ -168,4 +180,8 @@ class MemStats(object):
self._param_runtime_order.clear()
self._step_param_dict.clear()
self._param_step_dict.clear()
self._step_nmd_dict.clear()
self._preop_step = 0
self._prev_overall_cuda = -1
self._prev_md_cuda = -1
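For context, the new bookkeeping derives each step's non-model data as the peak overall CUDA usage minus the peak CUDA model data recorded for the preceding op. A minimal sketch of the intended call pattern (the byte values are invented and the import path is assumed):

from colossalai.gemini.memory_tracer import MemStats  # import path assumed

stats = MemStats()
# hypothetical peak measurements for one preop step, in bytes
stats.record_max_cuda_model_data(4 * 1024**2)  # peak model data: 4 MB
stats.record_max_cuda_overall_data(10 * 1024**2)  # peak overall usage: 10 MB
# stores overall - model data (6 MB) under the current preop step,
# but only after both values have been recorded (i.e. neither is -1)
stats.record_max_cuda_non_model_data()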

@@ -64,7 +64,16 @@ class ParamMemTracerHook(ColoParamOpHook):
raise NotImplementedError("Only free cuda memory")
free_storage(p.data)
def _allocate_params_on_cuda(self, params: List[torch.nn.Parameter]):
"""
move params to cuda
Args:
params (List[torch.nn.Parameter]): target params
Raises:
NotImplementedError: raise error when param has cpu grad
"""
for p in params:
cur_dev = p.data.device.type
if cur_dev == "cpu":
@@ -78,6 +87,9 @@ class ParamMemTracerHook(ColoParamOpHook):
alloc_storage(p.data)
def sample_model_data(self, params):
"""
get cuda model data used by params
"""
data_volume = self._grad_stats.unreleased_grad_volume
for p in params:
cur_model_data_volume = p.data.numel() * p.data.element_size()
@@ -89,14 +101,21 @@ class ParamMemTracerHook(ColoParamOpHook):
self._grad_stats.unreleased_grad_volume += cur_model_data_volume
self._grad_stats.unreleased_grad_flag[p] = True
self._memstats.append_model_data('cuda', data_volume)
# record max cuda model data used for this Op
self._memstats.record_max_cuda_model_data(data_volume)
def pre_op(self, params):
# get overall cuda data.
max_cuda_vol_of_period = self.mem_monitor.finish()
# record max cuda overall data for prev Op.
self._memstats.record_max_cuda_overall_data(max_cuda_vol_of_period)
self._memstats.record_max_cuda_non_model_data()
max_cuda_model_data_val = self._memstats.last_model_data('cuda')
if max_cuda_model_data_val is not None:
self._memstats.append_non_model_data('cuda', max_cuda_vol_of_period - max_cuda_model_data_val)
self._allocate_params_on_cuda(params)
self.sample_model_data(params)
self.mem_monitor.start()
self._memstats.increase_preop_step(params)
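The reworked pre_op makes every op boundary a measurement cycle: mem_monitor.finish() returns the peak CUDA volume since the last start(), which is attributed to the previous op, and subtracting the model data recorded by sample_model_data yields that op's non-model data. A toy stand-in illustrating the cycle (ToyMonitor is hypothetical; only the start()/finish() protocol mirrors the hook's monitor):

class ToyMonitor:
    # tracks a peak value between start() and finish()
    def __init__(self):
        self._peak = 0
    def start(self):
        self._peak = 0
    def observe(self, volume):
        self._peak = max(self._peak, volume)
    def finish(self):
        return self._peak

monitor = ToyMonitor()
monitor.start()
monitor.observe(10 * 1024**2)  # op i peaks at 10 MB overall
# at the next op boundary (pre_op for op i + 1):
max_cuda_vol_of_period = monitor.finish()  # 10 MB, attributed to op i
model_data = 4 * 1024**2  # what sample_model_data recorded for op i
non_model_data = max_cuda_vol_of_period - model_data  # 6 MB
monitor.start()  # begin measuring op i + 1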

@@ -46,6 +46,8 @@ def run_gemini_use_rmt(placement_policy, keep_gather, model_name: str, use_grad_
memstats = runtime_mem_tracer.memstats()
runtime_tracer_non_model_data = runtime_mem_tracer._memstats._non_model_data_cuda_list
print('runtime tracer non model data points: ', len(runtime_tracer_non_model_data))
print('runtime tracer: ', runtime_tracer_non_model_data)
print([memstats.param_used_timestep(p) for p in model.parameters()])
world_size = torch.distributed.get_world_size()
config_dict, _ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
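To inspect the collected statistics the way this test does, something like the following works (attribute and method names are taken from the diff above; param_used_timestep is assumed to return the preop steps at which a parameter was used):

memstats = runtime_mem_tracer.memstats()
# per-step non-model CUDA data sampled by the runtime tracer, in bytes
nmd_points = runtime_mem_tracer._memstats._non_model_data_cuda_list
print('non model data points:', len(nmd_points))
print('peak non model data (MB):', max(nmd_points) / 1024**2)
# preop steps at which each parameter was used
for p in model.parameters():
    print(memstats.param_used_timestep(p))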
