diff --git a/colossalai/zero/gemini/chunk/search_utils.py b/colossalai/zero/gemini/chunk/search_utils.py index c4deec8fe..da58e038c 100644 --- a/colossalai/zero/gemini/chunk/search_utils.py +++ b/colossalai/zero/gemini/chunk/search_utils.py @@ -11,8 +11,13 @@ from colossalai.zero.gemini.memory_tracer import MemStats, OrderedParamGenerator def _filter_exlarge_params(model: nn.Module, size_dict: Dict[int, List[int]]) -> None: - """ + """_filter_exlarge_params + Filter those parameters whose size is too large (more than 3x standard deviations) from others. + + Args: + model (nn.Module): the model. + size_dict (Dict[int, List[int]]): the size dict of parameters. """ agg_size_list = [] for key in size_dict: @@ -33,7 +38,16 @@ def _filter_exlarge_params(model: nn.Module, size_dict: Dict[int, List[int]]) -> def _get_unused_byte(size_list: List[int], chunk_size: int) -> int: - """Get unused byte for a certain chunk size. + """_get_unused_byte + + Get unused byte for a certain chunk size. + + Args: + size_list (List[int]): the size list of parameters. + chunk_size (int): the chunk size. + + Returns: + int: the unused byte. """ acc = 0 left = 0 @@ -45,7 +59,18 @@ def _get_unused_byte(size_list: List[int], chunk_size: int) -> int: return left + acc -def _tensor_numel(local_param: ColoParameter, strict_ddp_flag: bool): +def _tensor_numel(local_param: ColoParameter, strict_ddp_flag: bool) -> int: + """_tensor_numel + + Get the number of elements of a tensor. + + Args: + local_param (ColoParameter): The local parameter. + strict_ddp_flag (bool): whether to enable the strict ddp mode. + + Returns: + int: the number of elements. + """ if strict_ddp_flag and type(local_param) is ColoParameter: return local_param.numel_global() else: @@ -61,6 +86,7 @@ def classify_params_by_dp_degree(param_order: OrderedParamGenerator, Args: param_order (OrderedParamGenerator): the order of param be visied + strict_ddp_flag (bool, optional): whether to enable the strict ddp mode. Defaults to False. Returns: Dict[int, List[ColoParameter]]: a dict contains the classification results. @@ -96,6 +122,8 @@ def search_chunk_configuration( memstas: Optional[MemStats] = None) -> Tuple[Dict, int, int]: """search_chunk_configuration + Search the chunk configuration for a model. + Args: model (nn.Module): torch module search_range_mb (float): searching range in mega byte.