mirror of https://github.com/hpcaitech/ColossalAI
Add docstr for zero3 chunk search utils (#3572)
parent 9edeadfb24
commit d329c294ec
@@ -11,8 +11,13 @@ from colossalai.zero.gemini.memory_tracer import MemStats, OrderedParamGenerator
 
 
 def _filter_exlarge_params(model: nn.Module, size_dict: Dict[int, List[int]]) -> None:
-    """
+    """_filter_exlarge_params
+
     Filter those parameters whose size is too large (more than 3x standard deviations) from others.
+
+    Args:
+        model (nn.Module): the model.
+        size_dict (Dict[int, List[int]]): the size dict of parameters.
     """
     agg_size_list = []
     for key in size_dict:
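The docstring added above states a 3x-standard-deviation cutoff. As a hedged illustration only, not ColossalAI's implementation, the rule could look like the sketch below; the helper name and the in-place edit of size_dict are assumptions.

import math
from typing import Dict, List


def _filter_exlarge_sizes_sketch(size_dict: Dict[int, List[int]]) -> None:
    # Hypothetical sketch of the 3x-standard-deviation rule described in the
    # docstring above; not the library's actual implementation.
    all_sizes = [s for sizes in size_dict.values() for s in sizes]
    if len(all_sizes) < 2:
        return
    mean = sum(all_sizes) / len(all_sizes)
    std = math.sqrt(sum((s - mean) ** 2 for s in all_sizes) / len(all_sizes))
    threshold = mean + 3 * std
    for key in size_dict:
        # keep only parameter sizes within 3 standard deviations of the mean
        size_dict[key] = [s for s in size_dict[key] if s <= threshold]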
@@ -33,7 +38,16 @@ def _filter_exlarge_params(model: nn.Module, size_dict: Dict[int, List[int]]) -> None:
 
 
 def _get_unused_byte(size_list: List[int], chunk_size: int) -> int:
-    """Get unused byte for a certain chunk size.
+    """_get_unused_byte
+
+    Get unused byte for a certain chunk size.
+
+    Args:
+        size_list (List[int]): the size list of parameters.
+        chunk_size (int): the chunk size.
+
+    Returns:
+        int: the unused byte.
     """
     acc = 0
     left = 0
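The documented behaviour, together with the acc/left accumulators visible in the diff, suggests a greedy packing-waste computation. A hedged sketch under that assumption:

from typing import List


def _get_unused_byte_sketch(size_list: List[int], chunk_size: int) -> int:
    # Hedged sketch of the waste computation the docstring describes: pack the
    # sizes into fixed-size chunks in order; space that cannot fit the next
    # tensor is wasted. Assumes every size fits into a single chunk.
    acc = 0    # bytes wasted in chunks that are already closed
    left = 0   # free bytes remaining in the chunk currently being filled
    for size in size_list:
        if size > left:
            acc += left        # the tail of the current chunk goes unused
            left = chunk_size  # open a fresh chunk
        left -= size
    return left + acc          # the tail of the last chunk is unused as well

For example, sizes [60, 60] with chunk_size 100 give 80 unused bytes under this sketch: 40 wasted at the end of the first chunk plus 40 still free in the second.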
@@ -45,7 +59,18 @@ def _get_unused_byte(size_list: List[int], chunk_size: int) -> int:
     return left + acc
 
 
-def _tensor_numel(local_param: ColoParameter, strict_ddp_flag: bool):
+def _tensor_numel(local_param: ColoParameter, strict_ddp_flag: bool) -> int:
+    """_tensor_numel
+
+    Get the number of elements of a tensor.
+
+    Args:
+        local_param (ColoParameter): The local parameter.
+        strict_ddp_flag (bool): whether to enable the strict ddp mode.
+
+    Returns:
+        int: the number of elements.
+    """
     if strict_ddp_flag and type(local_param) is ColoParameter:
         return local_param.numel_global()
     else:
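The rule the docstring states, counting global elements under strict DDP and local elements otherwise, can be sketched standalone; the helper name is hypothetical and numel_global() is taken from the diff above.

def _tensor_numel_sketch(local_param, strict_ddp_flag: bool) -> int:
    # Hedged sketch: under strict DDP a ColoParameter is sized by its global
    # number of elements (every rank keeps a full replica); otherwise only the
    # locally held elements are counted.
    if strict_ddp_flag and hasattr(local_param, "numel_global"):
        return local_param.numel_global()
    return local_param.numel()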
@@ -61,6 +86,7 @@ def classify_params_by_dp_degree(param_order: OrderedParamGenerator,
 
     Args:
         param_order (OrderedParamGenerator): the order of param be visied
+        strict_ddp_flag (bool, optional): whether to enable the strict ddp mode. Defaults to False.
 
     Returns:
         Dict[int, List[ColoParameter]]: a dict contains the classification results.
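The return type in this docstring, Dict[int, List[ColoParameter]], implies a grouping of parameters by their data-parallel degree. A hedged sketch of that grouping; the dp_world_size() accessor is an assumption about the ColoParameter API, not taken from the diff:

from collections import defaultdict
from typing import Dict, List


def _classify_by_dp_degree_sketch(params) -> Dict[int, List]:
    # Hedged sketch: bucket each visited parameter by its data-parallel world
    # size so that parameters sharing a dp degree can be packed into the same
    # set of chunks.
    buckets: Dict[int, List] = defaultdict(list)
    for param in params:
        buckets[param.process_group.dp_world_size()].append(param)  # assumed API
    return dict(buckets)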
@@ -96,6 +122,8 @@ def search_chunk_configuration(
         memstas: Optional[MemStats] = None) -> Tuple[Dict, int, int]:
     """search_chunk_configuration
 
+    Search the chunk configuration for a model.
+
     Args:
         model (nn.Module): torch module
         search_range_mb (float): searching range in mega byte.
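A hedged usage sketch of the function being documented. Only model, search_range_mb, memstas and the Tuple[Dict, int, int] return type appear in the diff; the import path, the search_interval_byte argument, and the names given to the returned values are assumptions.

import torch.nn as nn

from colossalai.zero.gemini.chunk import search_chunk_configuration  # assumed path

model = nn.Sequential(nn.Linear(1024, 4096), nn.GELU(), nn.Linear(4096, 1024))
config_dict, total_size, wasted_size = search_chunk_configuration(
    model,
    search_range_mb=32,          # searching range in mega byte (from the docstring)
    search_interval_byte=1024,   # assumed: granularity of candidate chunk sizes
)
print(config_dict, total_size, wasted_size)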