[Gemini] mapping of preop timestep and param (#2124)

pull/2125/head
Jiarui Fang 2022-12-13 12:50:24 +08:00 committed by GitHub
parent 764bc16f3e
commit 05bb28aacf
3 changed files with 49 additions and 6 deletions

View File

@@ -1,4 +1,6 @@
from typing import Any, Dict, List
from typing import Any, Dict, List, Optional
import torch
from colossalai.gemini.memory_tracer import OrderedParamGenerator
@@ -10,6 +12,12 @@ class MemStats(object):
Store the non model data statistics used for Gemini and ZeroOptimizer.
"""
# (preop step, List[param])
self._step_param_dict = dict()
# (param, List[preop step])
self._param_step_dict = dict()
# p -> list of non-model data volumes visited in order.
self.param_non_model_data_map: Dict[Any, List[int]] = {}
self._model_data_cuda_list = []
@@ -23,6 +31,8 @@ class MemStats(object):
self._param_runtime_order = OrderedParamGenerator()
self._preop_step = 0
def param_order(self):
if self._param_runtime_order.is_empty():
raise RuntimeError
@@ -113,6 +123,38 @@ class MemStats(object):
else:
raise TypeError
def increase_preop_step(self, param_list: List[torch.nn.Parameter]):
"""
Record the parameters used at the current preop time step and advance the step counter.
The parameters in param_list are the ones used between the current time step and the next one.
Args:
param_list (List[torch.nn.Parameter]): a list of torch parameters.
"""
for p in param_list:
if p not in self._param_step_dict:
self._param_step_dict[p] = [self._preop_step]
else:
self._param_step_dict[p].append(self._preop_step)
self._param_runtime_order.append(p)
self._step_param_dict[self._preop_step] = param_list
self._preop_step += 1
def param_used_timestep(self, param: torch.nn.Parameter) -> Optional[List[int]]:
"""param_used_timestep
Get the list of preop time steps at which the param is used.
Args:
param (torch.nn.Parameter): a torch parameter.
Returns:
Optional[List[int]]: the preop time steps at which the param is used, or None if the param has not been recorded.
"""
if param not in self._param_step_dict:
return None
else:
return self._param_step_dict[param]
def clear(self):
self._model_data_cuda_list = []
self._overall_cuda_list = []
@@ -124,3 +166,6 @@ class MemStats(object):
self._non_model_data_cuda_list = []
self._param_runtime_order.clear()
self._step_param_dict.clear()
self._param_step_dict.clear()
self._preop_step = 0
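
For reference, a minimal usage sketch of the new step/param mapping. This is not part of the commit; it assumes MemStats is importable from colossalai.gemini.memory_tracer and uses a toy nn.Linear instead of a real Gemini model:

import torch
from colossalai.gemini.memory_tracer import MemStats  # assumed import path

memstats = MemStats()
layer = torch.nn.Linear(4, 4)
params = list(layer.parameters())

# Simulate two preop hooks that touch the same parameters.
memstats.increase_preop_step(params)    # records preop step 0
memstats.increase_preop_step(params)    # records preop step 1

for p in params:
    print(memstats.param_used_timestep(p))    # [0, 1] for both weight and bias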

View File

@@ -98,10 +98,7 @@ class ParamMemTracerHook(ColoParamOpHook):
self._allocate_params_on_cuda(params)
self.sample_model_data(params)
self.mem_monitor.start()
# register the order of visited.
for p in params:
self._memstats._param_runtime_order.append(p)
self._memstats.increase_preop_step(params)
def post_op(self, params):
self._free_cuda_params(params)

View File

@@ -45,7 +45,8 @@ def run_gemini_use_rmt(placement_policy, keep_gather, model_name: str, use_grad_
run_fwd_bwd(runtime_mem_tracer, input_ids, label, criterion, runtime_mem_tracer)
memstats = runtime_mem_tracer.memstats()
runtime_tracer_non_model_data = runtime_mem_tracer._memstats._non_model_data_cuda_list
print('runtime tracer non model data points: ', len(runtime_tracer_non_model_data))
print('runtime tracer: ', runtime_tracer_non_model_data)
print([memstats.param_used_timestep(p) for p in model.parameters()])
model = GeminiDDP(model, device='cuda', placement_policy=placement_policy, search_range_mb=1, memstats=memstats)
zero_optim = GeminiAdamOptimizer(model, lr=1e-3, initial_scale=1)
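
As an aside, a hedged sketch of how the per-param timestep lists printed in the test above could be inverted into a step -> params schedule. The helper build_step_schedule is hypothetical and not part of this commit; it only relies on the param_used_timestep API added here:

from collections import defaultdict
from typing import Dict, List

import torch


def build_step_schedule(memstats, params) -> Dict[int, List[torch.nn.Parameter]]:
    """Group parameters by the preop time steps at which the tracer saw them."""
    schedule = defaultdict(list)
    for p in params:
        steps = memstats.param_used_timestep(p)
        if steps is None:    # the tracer never visited this param
            continue
        for step in steps:
            schedule[step].append(p)
    return dict(schedule)

# e.g. after the tracer run above:
# schedule = build_step_schedule(memstats, model.parameters())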