diff --git a/colossalai/gemini/update/search_utils.py b/colossalai/gemini/update/search_utils.py index 9b95590c6..fdbbf0817 100644 --- a/colossalai/gemini/update/search_utils.py +++ b/colossalai/gemini/update/search_utils.py @@ -48,12 +48,11 @@ def clasify_params(model: nn.Module) -> Dict[int, List[ColoParameter]]: def search_chunk_configuration( model: nn.Module, search_range_mb: int, - search_interval_byte: int, # hidden size is the best value for the interval + search_interval_byte: int, # hidden size is the best value for the interval min_chunk_size_mb: int = 32, - filter_exlarge_params: bool = True -): - search_range_byte = search_range_mb * 1024 ** 2 - min_chunk_size_byte = min_chunk_size_mb * 1024 ** 2 + filter_exlarge_params: bool = True): + search_range_byte = search_range_mb * 1024**2 + min_chunk_size_byte = min_chunk_size_mb * 1024**2 assert search_range_byte % search_interval_byte == 0 params_dict = clasify_params(model)