|
|
|
@ -11,8 +11,13 @@ from colossalai.zero.gemini.memory_tracer import MemStats, OrderedParamGenerator
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _filter_exlarge_params(model: nn.Module, size_dict: Dict[int, List[int]]) -> None:
|
|
|
|
|
"""
|
|
|
|
|
"""_filter_exlarge_params
|
|
|
|
|
|
|
|
|
|
Filter those parameters whose size is too large (more than 3x standard deviations) from others.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
model (nn.Module): the model.
|
|
|
|
|
size_dict (Dict[int, List[int]]): the size dict of parameters.
|
|
|
|
|
"""
|
|
|
|
|
agg_size_list = []
|
|
|
|
|
for key in size_dict:
|
|
|
|
@ -33,7 +38,16 @@ def _filter_exlarge_params(model: nn.Module, size_dict: Dict[int, List[int]]) ->
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _get_unused_byte(size_list: List[int], chunk_size: int) -> int:
|
|
|
|
|
"""Get unused byte for a certain chunk size.
|
|
|
|
|
"""_get_unused_byte
|
|
|
|
|
|
|
|
|
|
Get unused byte for a certain chunk size.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
size_list (List[int]): the size list of parameters.
|
|
|
|
|
chunk_size (int): the chunk size.
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
int: the unused byte.
|
|
|
|
|
"""
|
|
|
|
|
acc = 0
|
|
|
|
|
left = 0
|
|
|
|
@ -45,7 +59,18 @@ def _get_unused_byte(size_list: List[int], chunk_size: int) -> int:
|
|
|
|
|
return left + acc
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _tensor_numel(local_param: ColoParameter, strict_ddp_flag: bool):
|
|
|
|
|
def _tensor_numel(local_param: ColoParameter, strict_ddp_flag: bool) -> int:
|
|
|
|
|
"""_tensor_numel
|
|
|
|
|
|
|
|
|
|
Get the number of elements of a tensor.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
local_param (ColoParameter): The local parameter.
|
|
|
|
|
strict_ddp_flag (bool): whether to enable the strict ddp mode.
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
int: the number of elements.
|
|
|
|
|
"""
|
|
|
|
|
if strict_ddp_flag and type(local_param) is ColoParameter:
|
|
|
|
|
return local_param.numel_global()
|
|
|
|
|
else:
|
|
|
|
@ -61,6 +86,7 @@ def classify_params_by_dp_degree(param_order: OrderedParamGenerator,
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
param_order (OrderedParamGenerator): the order of param be visied
|
|
|
|
|
strict_ddp_flag (bool, optional): whether to enable the strict ddp mode. Defaults to False.
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
Dict[int, List[ColoParameter]]: a dict contains the classification results.
|
|
|
|
@ -96,6 +122,8 @@ def search_chunk_configuration(
|
|
|
|
|
memstas: Optional[MemStats] = None) -> Tuple[Dict, int, int]:
|
|
|
|
|
"""search_chunk_configuration
|
|
|
|
|
|
|
|
|
|
Search the chunk configuration for a model.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
model (nn.Module): torch module
|
|
|
|
|
search_range_mb (float): searching range in mega byte.
|
|
|
|
|