diff --git a/colossalai/gemini/chunk/manager.py b/colossalai/gemini/chunk/manager.py index e73c59b25..30ac4d354 100644 --- a/colossalai/gemini/chunk/manager.py +++ b/colossalai/gemini/chunk/manager.py @@ -72,6 +72,9 @@ class ChunkManager: if tensor.numel() > chunk_size: chunk_size = tensor.numel() + dp_size = tensor.process_group.dp_world_size() + chunk_size = chunk_size + (-chunk_size % dp_size) + chunk = Chunk( chunk_size=chunk_size, process_group=tensor.process_group, diff --git a/colossalai/gemini/chunk/search_utils.py b/colossalai/gemini/chunk/search_utils.py index 57a708135..fe9650721 100644 --- a/colossalai/gemini/chunk/search_utils.py +++ b/colossalai/gemini/chunk/search_utils.py @@ -119,6 +119,7 @@ def search_chunk_configuration( assert search_range_byte >= 0 params_dict = classify_params_by_dp_degree(param_order, strict_ddp_flag) + size_lcm = np.lcm.reduce(list(params_dict.keys())) config_dict: Dict[int, Dict] = dict() total_param_size = 0 @@ -154,6 +155,8 @@ def search_chunk_configuration( min_chunk_waste = temp_waste best_chunk_size = chunk_size + # the chunk size must be divisible by each group's size + best_chunk_size = best_chunk_size + (-best_chunk_size % size_lcm) for dp_degree in params_dict: if dp_degree in config_dict: continue