mirror of https://github.com/hpcaitech/ColossalAI

[Gemini] update non model data calculation method (#2126)

parent 6c4c6a0409
commit 8fac837679
@@ -11,13 +11,19 @@ class MemStats(object):
         """
         Store the non model data statistics used for Gemini and ZeroOptimizer.
         """
-        # p -> list of non_model data volumn visied in order.
-
-        # (preop_moment, List[param])
+        # (preop_step, List[param])
         self._step_param_dict = dict()
+        # (param, List[preop_step])
         self._param_step_dict = dict()
-        # (param, List[preop_moment])
+        # (preop_step, non_model_data)
+        self._step_nmd_dict = dict()
+        self._param_runtime_order = OrderedParamGenerator()
+
+        self._preop_step = 0
 
+        self._prev_overall_cuda = -1
+        self._prev_md_cuda = -1
+        # old version
         self.param_non_model_data_map: Dict(Any, List[int]) = {}
 
         self._model_data_cuda_list = []
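The renamed comments pin down the bookkeeping: `_step_param_dict` maps a pre-op step index to the params touched at that step, `_param_step_dict` is its inverse, and the new `_step_nmd_dict` will hold the measured non model data per step. A minimal sketch of such forward/inverse maps (toy code, not from the repository; the string params are illustrative):

```python
from collections import defaultdict

step_param_dict = defaultdict(list)    # preop_step -> List[param]
param_step_dict = defaultdict(list)    # param -> List[preop_step]

def record_preop(step, params):
    # record that `params` were consumed by the operator run at `step`
    for p in params:
        step_param_dict[step].append(p)
        param_step_dict[p].append(step)

record_preop(0, ["fc1.weight", "fc1.bias"])
record_preop(1, ["fc2.weight"])
record_preop(2, ["fc1.weight"])        # fc1.weight is reused later

print(param_step_dict["fc1.weight"])   # [0, 2]
```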
@@ -29,9 +35,15 @@ class MemStats(object):
         self._non_model_data_cuda_list = []
         self._non_model_data_cpu_list = []
 
-        self._param_runtime_order = OrderedParamGenerator()
-
-        self._preop_step = 0
+    def record_max_cuda_non_model_data(self):
+        if self._prev_overall_cuda != -1 and self._prev_md_cuda != -1:
+            self._step_nmd_dict[self._preop_step] = self._prev_overall_cuda - self._prev_md_cuda
+
+    def record_max_cuda_model_data(self, val):
+        self._prev_md_cuda = val
+
+    def record_max_cuda_overall_data(self, val):
+        self._prev_overall_cuda = val
 
     def param_order(self):
         if self._param_runtime_order.is_empty():
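The three new methods split the measurement into two cached peaks plus one subtraction: the non model data of a step is the peak overall CUDA memory seen in that window minus the model data resident at the same time, and the -1 sentinels skip the very first window, for which no peaks have been sampled yet. A runnable distillation (the class and the byte counts are illustrative, not the library API):

```python
class TinyMemStats:
    """Mimics only the record_* trio of MemStats."""

    def __init__(self):
        self._step_nmd_dict = {}
        self._preop_step = 0
        self._prev_overall_cuda = -1   # -1 means "no sample yet"
        self._prev_md_cuda = -1

    def record_max_cuda_overall_data(self, val):
        self._prev_overall_cuda = val

    def record_max_cuda_model_data(self, val):
        self._prev_md_cuda = val

    def record_max_cuda_non_model_data(self):
        # non model data = overall peak - model data, once both exist
        if self._prev_overall_cuda != -1 and self._prev_md_cuda != -1:
            self._step_nmd_dict[self._preop_step] = \
                self._prev_overall_cuda - self._prev_md_cuda

stats = TinyMemStats()
stats.record_max_cuda_non_model_data()            # no-op: sentinels are -1
stats.record_max_cuda_overall_data(10 * 1024**2)  # 10 MiB peak overall
stats.record_max_cuda_model_data(6 * 1024**2)     # 6 MiB of model data
stats.record_max_cuda_non_model_data()
print(stats._step_nmd_dict)                       # {0: 4194304}, i.e. 4 MiB
```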
@@ -168,4 +180,8 @@ class MemStats(object):
         self._param_runtime_order.clear()
         self._step_param_dict.clear()
         self._param_step_dict.clear()
+        self._step_nmd_dict.clear()
         self._preop_step = 0
+
+        self._prev_overall_cuda = -1
+        self._prev_md_cuda = -1
@@ -64,7 +64,16 @@ class ParamMemTracerHook(ColoParamOpHook):
                 raise NotImplementedError("Only free cuda memory")
             free_storage(p.data)
 
-    def _allocate_params_on_cuda(self, params):
+    def _allocate_params_on_cuda(self, params: List[torch.nn.Parameter]):
+        """
+        move params to cuda
+
+        Args:
+            params (List[torch.nn.Parameter]): target params
+
+        Raises:
+            NotImplementedError: raise error when param has cpu grad
+        """
         for p in params:
             cur_dev = p.data.device.type
             if cur_dev == "cpu":
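Besides the new type hint and docstring, the method's job is unchanged: move CPU-resident params onto the GPU before the op runs, refusing params that still carry a CPU gradient. A torch-only approximation of that behavior (the real hook additionally manages raw storage via ColossalAI's alloc_storage/free_storage helpers):

```python
import torch

def allocate_params_on_cuda(params):
    # move CPU-resident params to the GPU, as the docstring above describes
    for p in params:
        if p.grad is not None and p.grad.device.type == "cpu":
            raise NotImplementedError("param has a cpu grad")
        if p.data.device.type == "cpu":
            p.data = p.data.to("cuda")

if torch.cuda.is_available():
    layer = torch.nn.Linear(4, 4)              # params start on the CPU
    allocate_params_on_cuda(list(layer.parameters()))
    print(next(layer.parameters()).device)     # cuda:0
```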
@@ -78,6 +87,9 @@ class ParamMemTracerHook(ColoParamOpHook):
                 alloc_storage(p.data)
 
     def sample_model_data(self, params):
+        """
+        get cuda model data used by params
+        """
         data_volume = self._grad_stats.unreleased_grad_volume
         for p in params:
             cur_model_data_volume = p.data.numel() * p.data.element_size()
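`sample_model_data` charges each param at numel() times element_size() bytes, on top of whatever gradient volume is still unreleased. For instance:

```python
import torch

# a 1024 x 1024 fp16 parameter occupies 1024 * 1024 * 2 bytes
p = torch.nn.Parameter(torch.zeros(1024, 1024, dtype=torch.float16))
print(p.data.numel() * p.data.element_size())   # 2097152, i.e. 2 MiB
```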
@@ -89,14 +101,21 @@ class ParamMemTracerHook(ColoParamOpHook):
                 self._grad_stats.unreleased_grad_volume += cur_model_data_volume
                 self._grad_stats.unreleased_grad_flag[p] = True
         self._memstats.append_model_data('cuda', data_volume)
+        # record max cuda model data used for this Op
+        self._memstats.record_max_cuda_model_data(data_volume)
 
     def pre_op(self, params):
-        cuda_volume = self.mem_monitor.finish()
-        last_model_data_val = self._memstats.last_model_data('cuda')
-        if last_model_data_val is not None:
-            self._memstats.append_non_model_data('cuda', cuda_volume - last_model_data_val)
+        # get overall cuda data.
+        max_cuda_vol_of_period = self.mem_monitor.finish()
+        # record max cuda overall data for prev Op.
+        self._memstats.record_max_cuda_overall_data(max_cuda_vol_of_period)
+        self._memstats.record_max_cuda_non_model_data()
+        max_cuda_model_data_val = self._memstats.last_model_data('cuda')
+        if max_cuda_model_data_val is not None:
+            self._memstats.append_non_model_data('cuda', max_cuda_vol_of_period - max_cuda_model_data_val)
         self._allocate_params_on_cuda(params)
         self.sample_model_data(params)
 
         self.mem_monitor.start()
         self._memstats.increase_preop_step(params)
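This hunk is the heart of the change: each `pre_op` closes the previous operator's measurement window, so the peak returned by `mem_monitor.finish()` has to be paired with the model data sampled one step earlier, which is what routing everything through the `_prev_overall_cuda`/`_prev_md_cuda` sentinels achieves. A self-contained simulation of that pairing (the per-window peaks and model-data volumes are made-up byte counts standing in for the monitor):

```python
fake_window_peaks = [0, 9_000_000, 14_000_000]       # peak overall CUDA per window
fake_model_data = [6_000_000, 7_000_000, 7_000_000]  # model data sampled at each pre_op

step_nmd = {}
prev_overall, prev_md = -1, -1

for step, peak in enumerate(fake_window_peaks):
    # pre_op for operator `step`:
    prev_overall = peak                    # record_max_cuda_overall_data
    if prev_overall != -1 and prev_md != -1:
        # record_max_cuda_non_model_data pairs the peak with the
        # model data of the window that just ended
        step_nmd[step] = prev_overall - prev_md
    prev_md = fake_model_data[step]        # sample_model_data -> record_max_cuda_model_data

print(step_nmd)   # {1: 3000000, 2: 7000000}
```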
@@ -46,6 +46,8 @@ def run_gemini_use_rmt(placement_policy, keep_gather, model_name: str, use_grad_
     memstats = runtime_mem_tracer.memstats()
     runtime_tracer_non_model_data = runtime_mem_tracer._memstats._non_model_data_cuda_list
     print('runtime tracer non model data points: ', len(runtime_tracer_non_model_data))
+    print('runtime tracer: ', runtime_tracer_non_model_data)
+    print([memstats.param_used_timestep(p) for p in model.parameters()])
 
     world_size = torch.distributed.get_world_size()
     config_dict, _ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)