Add docstr for zero3 chunk search utils (#3572)

pull/3579/head
YH 2 years ago committed by GitHub
parent 9edeadfb24
commit d329c294ec
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -11,8 +11,13 @@ from colossalai.zero.gemini.memory_tracer import MemStats, OrderedParamGenerator
def _filter_exlarge_params(model: nn.Module, size_dict: Dict[int, List[int]]) -> None: def _filter_exlarge_params(model: nn.Module, size_dict: Dict[int, List[int]]) -> None:
""" """_filter_exlarge_params
Filter those parameters whose size is too large (more than 3x standard deviations) from others. Filter those parameters whose size is too large (more than 3x standard deviations) from others.
Args:
model (nn.Module): the model.
size_dict (Dict[int, List[int]]): the size dict of parameters.
""" """
agg_size_list = [] agg_size_list = []
for key in size_dict: for key in size_dict:
@ -33,7 +38,16 @@ def _filter_exlarge_params(model: nn.Module, size_dict: Dict[int, List[int]]) ->
def _get_unused_byte(size_list: List[int], chunk_size: int) -> int: def _get_unused_byte(size_list: List[int], chunk_size: int) -> int:
"""Get unused byte for a certain chunk size. """_get_unused_byte
Get unused byte for a certain chunk size.
Args:
size_list (List[int]): the size list of parameters.
chunk_size (int): the chunk size.
Returns:
int: the unused byte.
""" """
acc = 0 acc = 0
left = 0 left = 0
@ -45,7 +59,18 @@ def _get_unused_byte(size_list: List[int], chunk_size: int) -> int:
return left + acc return left + acc
def _tensor_numel(local_param: ColoParameter, strict_ddp_flag: bool): def _tensor_numel(local_param: ColoParameter, strict_ddp_flag: bool) -> int:
"""_tensor_numel
Get the number of elements of a tensor.
Args:
local_param (ColoParameter): The local parameter.
strict_ddp_flag (bool): whether to enable the strict ddp mode.
Returns:
int: the number of elements.
"""
if strict_ddp_flag and type(local_param) is ColoParameter: if strict_ddp_flag and type(local_param) is ColoParameter:
return local_param.numel_global() return local_param.numel_global()
else: else:
@ -61,6 +86,7 @@ def classify_params_by_dp_degree(param_order: OrderedParamGenerator,
Args: Args:
param_order (OrderedParamGenerator): the order in which the params are visited param_order (OrderedParamGenerator): the order in which the params are visited
strict_ddp_flag (bool, optional): whether to enable the strict ddp mode. Defaults to False.
Returns: Returns:
Dict[int, List[ColoParameter]]: a dict containing the classification results. Dict[int, List[ColoParameter]]: a dict containing the classification results.
@ -96,6 +122,8 @@ def search_chunk_configuration(
memstas: Optional[MemStats] = None) -> Tuple[Dict, int, int]: memstas: Optional[MemStats] = None) -> Tuple[Dict, int, int]:
"""search_chunk_configuration """search_chunk_configuration
Search the chunk configuration for a model.
Args: Args:
model (nn.Module): torch module model (nn.Module): torch module
search_range_mb (float): searching range in mega byte. search_range_mb (float): searching range in mega byte.

Loading…
Cancel
Save