From 1b387ca9fe2fe7f90459537b0cc19d5bb4edbdc5 Mon Sep 17 00:00:00 2001 From: Hongxin Liu Date: Thu, 25 Apr 2024 15:19:30 +0800 Subject: [PATCH] [shardformer] refactor pipeline grad ckpt config (#5646) * [shardformer] refactor pipeline grad ckpt config * [shardformer] refactor pipeline grad ckpt config * [pipeline] fix stage manager --- .../booster/plugin/hybrid_parallel_plugin.py | 2 + colossalai/pipeline/stage_manager.py | 82 +++++++------------ colossalai/shardformer/modeling/llama.py | 2 + colossalai/shardformer/modeling/mistral.py | 2 + .../shardformer/policies/base_policy.py | 2 + .../shardformer/shard/grad_ckpt_config.py | 31 ++----- colossalai/shardformer/shard/shard_config.py | 12 +-- examples/language/llama/benchmark.py | 19 ++--- .../test_t5_pipeline_utils.py | 1 + .../test_whisper_pipeline_utils.py | 1 + .../test_model/test_shard_llama.py | 7 +- 11 files changed, 59 insertions(+), 102 deletions(-) diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index 95fb2def1..5237734f0 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -983,6 +983,7 @@ class HybridParallelPlugin(PipelinePluginBase): custom_policy: Policy = None, pp_style: str = "1f1b", num_model_chunks: int = 1, + num_layers_per_stage: Optional[List[int]] = None, gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None, enable_metadata_cache: bool = True, make_vocab_size_divisible_by: int = 64, @@ -1056,6 +1057,7 @@ class HybridParallelPlugin(PipelinePluginBase): pipeline_axis=self.pp_axis, enable_interleave=pp_style == "interleaved", num_model_chunks=num_model_chunks, + num_layers_per_stage=num_layers_per_stage, ) if pp_style == "interleaved": diff --git a/colossalai/pipeline/stage_manager.py b/colossalai/pipeline/stage_manager.py index b0556669b..b7cbd67ab 100644 --- a/colossalai/pipeline/stage_manager.py +++ b/colossalai/pipeline/stage_manager.py @@ -27,16 +27,18 @@ class PipelineStageManager: pipeline_axis: int, enable_interleave: bool = False, num_model_chunks: int = 1, + num_layers_per_stage: Optional[List[int]] = None, ) -> None: assert enable_interleave or num_model_chunks == 1, "num_model_chunks must be 1 when enable_interleave is False" - self.num_layers_per_stage = None - self.pg_mesh = pg_mesh self.pipeline_axis = pipeline_axis self.prev_rank: Optional[Tuple[int, ...]] = None self.next_rank: Optional[Tuple[int, ...]] = None self.p2p_groups: Dict[Tuple[int, int], ProcessGroup] = {} + if num_layers_per_stage is not None: + assert len(num_layers_per_stage) == self.num_stages + self.num_layers_per_stage = num_layers_per_stage # init prev and next coord coord = self.pg_mesh.coordinate() @@ -56,6 +58,8 @@ class PipelineStageManager: self.p2p_groups[tuple(ranks_in_group)] = group self.is_interleave = enable_interleave + # for interleaved pipeline parallel, each device is responsible for multiple chunk of layers + self.num_model_chunks: int = num_model_chunks if enable_interleave: # use circle p2p communication # add the process group of the first rank and the last rank @@ -64,59 +68,11 @@ class PipelineStageManager: ranks_in_group = self.pg_mesh.get_ranks_in_group(group) self.p2p_groups[tuple(ranks_in_group)] = group - # for interleaved pipeline parallel, each device is responsible for multiple chunk of layers - self.num_model_chunks: int = num_model_chunks - # for shardformer, hold stage indices of model self.stage_indices: List[Tuple[int, int]] # for shardformer, 
hold model chunk id self.model_chunk_id: Optional[int] = None - @property - def control_distribute_layers(self) -> bool: - return self.num_layers_per_stage is not None - - def set_distribution_config(self, num_model_layers: int, num_layers_per_stage: List[int]) -> None: - """Set the distribution configuration. - This allows user to customize the number of layers for each stage. - - Args: - num_model_layers (int): Number of layers in the model. - num_layers_per_stage (List[int]): Number of layers for each stage. - """ - assert all([0 < num_layers < num_model_layers for num_layers in num_layers_per_stage]) - assert sum(num_layers_per_stage) == num_model_layers - assert len(num_layers_per_stage) == self.num_stages * (self.num_model_chunks if self.is_interleave else 1) - self.num_model_layers = num_model_layers - self.num_layers_per_stage = num_layers_per_stage - - def distribute_layers( - self, num_layers: int, num_stages: Optional[int] = None, num_model_chunks: Optional[int] = None - ) -> List[int]: - """Divide layers into stages""" - num_stages = self.num_stages if num_stages is None else num_stages - num_model_chunks = ( - (self.num_model_chunks if self.is_interleave else 1) if num_model_chunks is None else num_model_chunks - ) - - if self.control_distribute_layers: - assert num_layers == self.num_model_layers - return self.num_layers_per_stage - - else: - quotient = num_layers // (num_stages * num_model_chunks) - remainder = num_layers % (num_stages * num_model_chunks) - - # calculate the num_layers per stage - layers_per_stage = [quotient] * num_stages * num_model_chunks - - # deal with the rest layers - if remainder > 0: - start_position = (num_stages * num_model_chunks) // 2 - remainder // 2 - for i in range(start_position, start_position + remainder): - layers_per_stage[i] += 1 - return layers_per_stage - def get_stage_index( self, layers_per_stage: List[int], @@ -139,9 +95,7 @@ class PipelineStageManager: """ stage = self.stage if stage is None else stage - num_model_chunks = ( - (self.num_model_chunks if self.is_interleave else 1) if num_model_chunks is None else num_model_chunks - ) + num_model_chunks = self.num_model_chunks if num_model_chunks is None else num_model_chunks num_stages = self.num_stages if num_stages is None else num_stages num_layers_per_stage_accumulated = np.insert(np.cumsum(layers_per_stage), 0, 0) @@ -261,3 +215,25 @@ class PipelineStageManager: self.model_chunk_id = model_chunk_id yield self.model_chunk_id = old_model_chunk_id + + def distribute_layers( + self, num_layers: int, num_stages: Optional[int] = None, num_model_chunks: Optional[int] = None + ) -> List[int]: + if self.num_layers_per_stage is not None: + assert sum(self.num_layers_per_stage) == num_layers + return self.num_layers_per_stage + + num_stages = self.num_stages if num_stages is None else num_stages + num_model_chunks = self.num_model_chunks if num_model_chunks is None else num_model_chunks + quotient = num_layers // (num_stages * num_model_chunks) + remainder = num_layers % (num_stages * num_model_chunks) + + # calculate the num_layers per stage + layers_per_stage = [quotient] * num_stages * num_model_chunks + + # deal with the rest layers + if remainder > 0: + start_position = (num_stages * num_model_chunks) // 2 - remainder // 2 + for i in range(start_position, start_position + remainder): + layers_per_stage[i] += 1 + return layers_per_stage diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index 0eb08a043..8a6a7cf17 100644 --- 
a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -168,8 +168,10 @@ class LlamaPipelineForwards: if shard_config.gradient_checkpoint_config is not None: num_ckpt_layers = shard_config.gradient_checkpoint_config.get_num_ckpt_layers( stage=stage_manager.stage, + num_stages=stage_manager.num_stages, num_layers=end_idx - start_idx, model_chunk_id=(stage_manager.model_chunk_id if stage_manager.is_interleave else 0), + num_model_chunks=stage_manager.num_model_chunks, ) assert num_ckpt_layers <= end_idx - start_idx diff --git a/colossalai/shardformer/modeling/mistral.py b/colossalai/shardformer/modeling/mistral.py index ac7845400..d5f00fc9f 100644 --- a/colossalai/shardformer/modeling/mistral.py +++ b/colossalai/shardformer/modeling/mistral.py @@ -129,8 +129,10 @@ class MistralForwards: if shard_config.gradient_checkpoint_config is not None: num_ckpt_layers = shard_config.gradient_checkpoint_config.get_num_ckpt_layers( stage=stage_manager.stage, + num_stages=stage_manager.num_stages, num_layers=end_idx - start_idx, model_chunk_id=(stage_manager.model_chunk_id if stage_manager.is_interleave else 0), + num_model_chunks=stage_manager.num_model_chunks, ) assert num_ckpt_layers <= end_idx - start_idx diff --git a/colossalai/shardformer/policies/base_policy.py b/colossalai/shardformer/policies/base_policy.py index e976672bb..282cf0464 100644 --- a/colossalai/shardformer/policies/base_policy.py +++ b/colossalai/shardformer/policies/base_policy.py @@ -28,6 +28,7 @@ class SubModuleReplacementDescription: kwargs (Dict[str, Any]): the dictionary used to pass extra arguments to the `ParallelModule.from_native_module` method. ignore_if_not_exist (bool): if the submodule does not exist, ignore it or raise an exception """ + suffix: str target_module: Union[ParallelModule, BaseLayerNorm] kwargs: Dict[str, Any] = None @@ -54,6 +55,7 @@ class ModulePolicyDescription: object which specifies the module to be replaced and the target module used to replacement. method_replace (Dict[str, Callable]): key is the method name, value is the method for replacement """ + attribute_replacement: Dict[str, Any] = None param_replacement: List[Callable] = None sub_module_replacement: List[SubModuleReplacementDescription] = None diff --git a/colossalai/shardformer/shard/grad_ckpt_config.py b/colossalai/shardformer/shard/grad_ckpt_config.py index 9fc857d19..9167da795 100644 --- a/colossalai/shardformer/shard/grad_ckpt_config.py +++ b/colossalai/shardformer/shard/grad_ckpt_config.py @@ -47,46 +47,33 @@ class PipelineGradientCheckpointConfig(GradientCheckpointConfig): ... 
""" - num_stages: Optional[int] = None - num_model_chunks: Optional[int] = None - num_model_layers: Optional[int] = None - num_layers_per_stage: Optional[List[int]] = None num_ckpt_layers_per_stage: Optional[List[int]] = None def __post_init__(self): - if self._enable_gradient_checkpointing_ratio: + if self._enable_customized_ckpt_layers_per_stage: + assert all([num_ckpt_layers >= 0 for num_ckpt_layers in self.num_ckpt_layers_per_stage]) + elif self._enable_gradient_checkpointing_ratio: if not (0 <= self.gradient_checkpointing_ratio <= 1): raise ValueError("gradient_checkpointing_ratio should be in 0% to 100%") - if self._enable_customized_ckpt_layers_per_stage: - assert ( - self.num_stages is not None and self.num_model_chunks is not None and self.num_model_layers is not None - ) - assert len(self.num_ckpt_layers_per_stage) == self.num_stages * self.num_model_chunks - assert all( - [0 <= num_ckpt_layers < self.num_model_layers for num_ckpt_layers in self.num_ckpt_layers_per_stage] - ) - self.gradient_checkpointing_ratio = sum(self.num_ckpt_layers_per_stage) / self.num_model_layers - @property def _enable_gradient_checkpointing_ratio(self) -> bool: return self.gradient_checkpointing_ratio is not None - @property - def _customize_num_layers_per_stage(self) -> bool: - return self.num_layers_per_stage is not None and self.num_model_layers is not None - @property def _enable_customized_ckpt_layers_per_stage(self) -> bool: return self.num_ckpt_layers_per_stage is not None - def get_num_ckpt_layers(self, stage: int, num_layers: int, model_chunk_id: int = 0) -> int: + def get_num_ckpt_layers( + self, stage: int, num_stages: int, num_layers: int, model_chunk_id: int = 0, num_model_chunks: int = 1 + ) -> int: if not self._enable_gradient_checkpointing_ratio and not self._enable_customized_ckpt_layers_per_stage: raise RuntimeError("No checkpointed layers information is provided") if self._enable_customized_ckpt_layers_per_stage: - assert stage <= self.num_stages and model_chunk_id <= self.num_model_chunks - num_ckpt_layers = self.num_ckpt_layers_per_stage[stage + model_chunk_id * self.num_stages] + assert len(self.num_ckpt_layers_per_stage) == num_stages * num_model_chunks + assert stage <= num_stages and model_chunk_id <= num_model_chunks + num_ckpt_layers = self.num_ckpt_layers_per_stage[stage + model_chunk_id * num_stages] assert num_ckpt_layers <= num_layers return num_ckpt_layers else: diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py index e20b8e239..98e72d8b3 100644 --- a/colossalai/shardformer/shard/shard_config.py +++ b/colossalai/shardformer/shard/shard_config.py @@ -7,7 +7,7 @@ from torch.distributed import ProcessGroup from colossalai.pipeline.stage_manager import PipelineStageManager -from .grad_ckpt_config import GradientCheckpointConfig, PipelineGradientCheckpointConfig +from .grad_ckpt_config import GradientCheckpointConfig __all__ = ["ShardConfig"] SUPPORT_SP_MODE = ["split_gather", "ring", "all_to_all"] @@ -105,16 +105,6 @@ class ShardConfig: else: self._sequence_parallel_size = dist.get_world_size(self.sequence_parallel_process_group) - if ( - self.pipeline_stage_manager is not None - and isinstance(self.gradient_checkpoint_config, PipelineGradientCheckpointConfig) - and self.gradient_checkpoint_config._customize_num_layers_per_stage - ): - self.pipeline_stage_manager.set_distribution_config( - self.gradient_checkpoint_config.num_model_layers, - self.gradient_checkpoint_config.num_layers_per_stage, - ) - def 
_turn_on_all_optimization(self): """ Turn on all optimization. diff --git a/examples/language/llama/benchmark.py b/examples/language/llama/benchmark.py index ff94891f5..d26975fc5 100644 --- a/examples/language/llama/benchmark.py +++ b/examples/language/llama/benchmark.py @@ -88,16 +88,15 @@ def main(): pass # ckpt config for LLaMA3-70B on 64 H100 GPUs - ckpt_config = ( - PipelineGradientCheckpointConfig( - num_stages=args.pp, - num_model_chunks=1, - num_model_layers=80, - num_layers_per_stage=[19, 20, 20, 21], - num_ckpt_layers_per_stage=[19, 19, 19, 13], - ) + hybrid_kwargs = ( + { + "gradient_checkpoint_config": PipelineGradientCheckpointConfig( + num_ckpt_layers_per_stage=[19, 19, 19, 13], + ), + "num_layers_per_stage": [19, 20, 20, 21], + } if args.custom_ckpt - else None + else {} ) # ============================== @@ -173,7 +172,7 @@ def main(): microbatch_size=args.mbs, precision="bf16", dp_outside=False, - gradient_checkpoint_config=ckpt_config, + **hybrid_kwargs, ) elif args.plugin == "3d_cpu": plugin = HybridParallelPlugin( diff --git a/tests/test_pipeline/test_pipeline_utils/test_t5_pipeline_utils.py b/tests/test_pipeline/test_pipeline_utils/test_t5_pipeline_utils.py index 1b7b0073f..e2f71ff89 100644 --- a/tests/test_pipeline/test_pipeline_utils/test_t5_pipeline_utils.py +++ b/tests/test_pipeline/test_pipeline_utils/test_t5_pipeline_utils.py @@ -14,6 +14,7 @@ class _PipelineStageManager(PipelineStageManager): def __init__(self): self.is_interleave = False self.num_layers_per_stage = None + self.num_model_chunks = 1 @property def num_stages(self): diff --git a/tests/test_pipeline/test_pipeline_utils/test_whisper_pipeline_utils.py b/tests/test_pipeline/test_pipeline_utils/test_whisper_pipeline_utils.py index 9f8c1ad32..d39c5ea91 100644 --- a/tests/test_pipeline/test_pipeline_utils/test_whisper_pipeline_utils.py +++ b/tests/test_pipeline/test_pipeline_utils/test_whisper_pipeline_utils.py @@ -14,6 +14,7 @@ class _PipelineStageManager(PipelineStageManager): def __init__(self): self.is_interleave = False self.num_layers_per_stage = None + self.num_model_chunks = 1 @property def num_stages(self): diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index 2a10d86c7..394592688 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -217,9 +217,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, "use_lazy_init": False, "precision": "fp32", "enable_gradient_checkpointing": True, - "gradient_checkpoint_config": PipelineGradientCheckpointConfig( - num_stages=2, num_model_chunks=1, num_model_layers=8, num_ckpt_layers_per_stage=[4, 0] - ), + "gradient_checkpoint_config": PipelineGradientCheckpointConfig(num_ckpt_layers_per_stage=[4, 0]), }, { "tp_size": 4, @@ -303,9 +301,6 @@ def run_llama_test(test_config): "initial_scale": 1, "enable_gradient_checkpointing": True, "gradient_checkpoint_config": PipelineGradientCheckpointConfig( - num_stages=2, - num_model_chunks=2, - num_model_layers=8, num_ckpt_layers_per_stage=[0, 1, 2, 2], ), },
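
For reference, the net effect of this refactor on user code is that the per-stage layer split moves out of PipelineGradientCheckpointConfig and into HybridParallelPlugin (which forwards it to PipelineStageManager via the new num_layers_per_stage argument), while the checkpoint config now carries only num_ckpt_layers_per_stage. Below is a minimal usage sketch mirroring the benchmark.py hunk above; the tp_size, microbatch_size, and launch/booster steps are illustrative assumptions and not part of this patch.

# Minimal sketch of the refactored API (values taken from the LLaMA3-70B example
# in examples/language/llama/benchmark.py; other settings are assumptions).
# Assumes the distributed environment has already been initialized,
# e.g. via colossalai.launch_from_torch in the launching script.
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from colossalai.shardformer.shard.grad_ckpt_config import PipelineGradientCheckpointConfig

plugin = HybridParallelPlugin(
    tp_size=2,                              # illustrative; not implied by this patch
    pp_size=4,                              # must match len(num_layers_per_stage)
    microbatch_size=1,                      # required when pp_size > 1
    precision="bf16",
    # The layer distribution is now a plugin argument, forwarded to PipelineStageManager.
    num_layers_per_stage=[19, 20, 20, 21],
    # The checkpoint config keeps only the per-stage checkpointed-layer counts;
    # num_stages / num_model_chunks / num_model_layers are no longer passed here.
    gradient_checkpoint_config=PipelineGradientCheckpointConfig(
        num_ckpt_layers_per_stage=[19, 19, 19, 13],
    ),
)
booster = Booster(plugin=plugin)
# Activation checkpointing itself is still enabled on the model separately,
# e.g. model.gradient_checkpointing_enable() for Hugging Face models.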