Add docstr for zero3 chunk search utils (#3572)

pull/3579/head
YH 2 years ago committed by GitHub
parent 9edeadfb24
commit d329c294ec
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -11,8 +11,13 @@ from colossalai.zero.gemini.memory_tracer import MemStats, OrderedParamGenerator
def _filter_exlarge_params(model: nn.Module, size_dict: Dict[int, List[int]]) -> None:
"""
"""_filter_exlarge_params
Filter those parameters whose size is too large (more than 3x standard deviations) from others.
Args:
model (nn.Module): the model.
size_dict (Dict[int, List[int]]): the size dict of parameters.
"""
agg_size_list = []
for key in size_dict:
@ -33,7 +38,16 @@ def _filter_exlarge_params(model: nn.Module, size_dict: Dict[int, List[int]]) ->
def _get_unused_byte(size_list: List[int], chunk_size: int) -> int:
"""Get unused byte for a certain chunk size.
"""_get_unused_byte
Get unused byte for a certain chunk size.
Args:
size_list (List[int]): the size list of parameters.
chunk_size (int): the chunk size.
Returns:
int: the unused byte.
"""
acc = 0
left = 0
@ -45,7 +59,18 @@ def _get_unused_byte(size_list: List[int], chunk_size: int) -> int:
return left + acc
def _tensor_numel(local_param: ColoParameter, strict_ddp_flag: bool):
def _tensor_numel(local_param: ColoParameter, strict_ddp_flag: bool) -> int:
"""_tensor_numel
Get the number of elements of a tensor.
Args:
local_param (ColoParameter): The local parameter.
strict_ddp_flag (bool): whether to enable the strict ddp mode.
Returns:
int: the number of elements.
"""
if strict_ddp_flag and type(local_param) is ColoParameter:
return local_param.numel_global()
else:
@ -61,6 +86,7 @@ def classify_params_by_dp_degree(param_order: OrderedParamGenerator,
Args:
param_order (OrderedParamGenerator): the order of param be visied
strict_ddp_flag (bool, optional): whether to enable the strict ddp mode. Defaults to False.
Returns:
Dict[int, List[ColoParameter]]: a dict contains the classification results.
@ -96,6 +122,8 @@ def search_chunk_configuration(
memstas: Optional[MemStats] = None) -> Tuple[Dict, int, int]:
"""search_chunk_configuration
Search the chunk configuration for a model.
Args:
model (nn.Module): torch module
search_range_mb (float): searching range in mega byte.

Loading…
Cancel
Save