mirror of https://github.com/hpcaitech/ColossalAI
[Gemini] mapping of preop timestep and param (#2124)
parent 764bc16f3e
commit 05bb28aacf
@@ -1,4 +1,6 @@
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Optional
 
+import torch
+
 from colossalai.gemini.memory_tracer import OrderedParamGenerator
 
@@ -10,6 +12,12 @@ class MemStats(object):
         Store the non model data statistics used for Gemini and ZeroOptimizer.
         """
         # p -> list of non_model data volume visited in order.
+
+        # (preop_moment, List[param])
+        self._step_param_dict = dict()
+        # (param, List[preop_moment])
+        self._param_step_dict = dict()
+
         self.param_non_model_data_map: Dict[Any, List[int]] = {}
 
         self._model_data_cuda_list = []
@@ -23,6 +31,8 @@ class MemStats(object):
 
         self._param_runtime_order = OrderedParamGenerator()
 
+        self._preop_step = 0
+
     def param_order(self):
         if self._param_runtime_order.is_empty():
             raise RuntimeError
@@ -113,6 +123,38 @@ class MemStats(object):
         else:
             raise TypeError
 
+    def increase_preop_step(self, param_list: List[torch.nn.Parameter]):
+        """
+        Increase the preop time step. The parameters in param_list are used
+        between the current time step and the next one.
+
+        Args:
+            param_list (List[torch.nn.Parameter]): a list of torch parameters.
+        """
+        for p in param_list:
+            if p not in self._param_step_dict:
+                self._param_step_dict[p] = [self._preop_step]
+            else:
+                self._param_step_dict[p].append(self._preop_step)
+            self._param_runtime_order.append(p)
+        self._step_param_dict[self._preop_step] = param_list
+        self._preop_step += 1
+
+    def param_used_timestep(self, param: torch.nn.Parameter) -> Optional[List[int]]:
+        """
+        Get the list of preop time steps at which the param is used.
+
+        Args:
+            param (torch.nn.Parameter): a torch parameter.
+
+        Returns:
+            Optional[List[int]]: the preop-hook time steps using the param, or None if it is never used.
+        """
+        if param not in self._param_step_dict:
+            return None
+        else:
+            return self._param_step_dict[param]
+
     def clear(self):
         self._model_data_cuda_list = []
         self._overall_cuda_list = []
@@ -124,3 +166,6 @@ class MemStats(object):
         self._non_model_data_cuda_list = []
 
         self._param_runtime_order.clear()
+        self._step_param_dict.clear()
+        self._param_step_dict.clear()
+        self._preop_step = 0
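
The two dicts added above form a bidirectional mapping: _step_param_dict answers "which params does preop step t touch?" and _param_step_dict answers "at which steps is param p touched?". A minimal sketch of the new API, assuming MemStats is exported from colossalai.gemini.memory_tracer (the package this file already imports OrderedParamGenerator from):

    import torch

    from colossalai.gemini.memory_tracer import MemStats  # assumed export

    p0 = torch.nn.Parameter(torch.zeros(4))
    p1 = torch.nn.Parameter(torch.zeros(4))

    memstats = MemStats()
    memstats.increase_preop_step([p0, p1])    # preop step 0 touches p0 and p1
    memstats.increase_preop_step([p1])        # preop step 1 touches only p1

    print(memstats.param_used_timestep(p0))   # [0]
    print(memstats.param_used_timestep(p1))   # [0, 1]

Note that clear() now resets this bookkeeping as well, so a MemStats instance can be reused across tracing runs.
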
@@ -98,10 +98,7 @@ class ParamMemTracerHook(ColoParamOpHook):
         self._allocate_params_on_cuda(params)
         self.sample_model_data(params)
         self.mem_monitor.start()
-        # register the order of visited.
-        for p in params:
-            self._memstats._param_runtime_order.append(p)
-
+        self._memstats.increase_preop_step(params)
 
     def post_op(self, params):
         self._free_cuda_params(params)
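
The pre_op refactor replaces the explicit visit-order loop with a single call: increase_preop_step appends each param to _param_runtime_order exactly as the removed loop did, and additionally records the step-to-param mapping and advances _preop_step. A hedged before/after sketch using only names visible in this diff:

    import torch

    from colossalai.gemini.memory_tracer import MemStats  # assumed export

    memstats = MemStats()
    params = [torch.nn.Parameter(torch.zeros(2)) for _ in range(2)]

    # old pre_op tail: visit order only, no step bookkeeping
    # for p in params:
    #     memstats._param_runtime_order.append(p)

    # new pre_op tail: same appends, plus the timestep mapping
    memstats.increase_preop_step(params)
    print(memstats.param_used_timestep(params[0]))  # [0]
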
@@ -45,7 +45,8 @@ def run_gemini_use_rmt(placement_policy, keep_gather, model_name: str, use_grad_
     run_fwd_bwd(runtime_mem_tracer, input_ids, label, criterion, runtime_mem_tracer)
     memstats = runtime_mem_tracer.memstats()
     runtime_tracer_non_model_data = runtime_mem_tracer._memstats._non_model_data_cuda_list
-    print('runtime tracer non model data points: ', len(runtime_tracer_non_model_data))
+    print('runtime tracer: ', runtime_tracer_non_model_data)
+    print([memstats.param_used_timestep(p) for p in model.parameters()])
 
     model = GeminiDDP(model, device='cuda', placement_policy=placement_policy, search_range_mb=1, memstats=memstats)
     zero_optim = GeminiAdamOptimizer(model, lr=1e-3, initial_scale=1)
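
The added print queries every model parameter through the new API; each entry is a list of preop steps, or None for a parameter the preop hook never touched. A hedged illustration with a plain nn.Linear standing in for the traced test model:

    import torch

    from colossalai.gemini.memory_tracer import MemStats  # assumed export

    model = torch.nn.Linear(4, 4)
    memstats = MemStats()
    # pretend the tracer's pre_op hook fired once over all parameters
    memstats.increase_preop_step(list(model.parameters()))

    # mirrors the added test line: one timestep list (or None) per param
    print([memstats.param_used_timestep(p) for p in model.parameters()])
    # -> [[0], [0]] for weight and bias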