From e094933da1d0a574eda105ab6ec0f171d8ddaebb Mon Sep 17 00:00:00 2001
From: Hongxin Liu
Date: Mon, 22 Apr 2024 11:25:39 +0800
Subject: [PATCH] [shardformer] fix pipeline grad ckpt (#5620)

* [shardformer] fix pipeline grad ckpt

---
 colossalai/shardformer/shard/grad_ckpt_config.py |  6 ++++++
 colossalai/shardformer/shard/shard_config.py     | 13 ++++++++++++-
 2 files changed, 18 insertions(+), 1 deletion(-)

diff --git a/colossalai/shardformer/shard/grad_ckpt_config.py b/colossalai/shardformer/shard/grad_ckpt_config.py
index 9c6c2b54e..9fc857d19 100644
--- a/colossalai/shardformer/shard/grad_ckpt_config.py
+++ b/colossalai/shardformer/shard/grad_ckpt_config.py
@@ -22,6 +22,7 @@ class PipelineGradientCheckpointConfig(GradientCheckpointConfig):
         2. Customize # ckpt layers assigned to each stage. This takes precedence over `gradient_checkpointing_ratio`.
 
     """
+
     """
     Args:
         gradient_checkpointing_ratio (Optional[float]): The ratio of gradient checkpointing. It can only be used in pipeline parallelism. Defaults to None.
@@ -49,6 +50,7 @@ class PipelineGradientCheckpointConfig(GradientCheckpointConfig):
     num_stages: Optional[int] = None
     num_model_chunks: Optional[int] = None
     num_model_layers: Optional[int] = None
+    num_layers_per_stage: Optional[List[int]] = None
     num_ckpt_layers_per_stage: Optional[List[int]] = None
 
     def __post_init__(self):
@@ -70,6 +72,10 @@ class PipelineGradientCheckpointConfig(GradientCheckpointConfig):
     def _enable_gradient_checkpointing_ratio(self) -> bool:
         return self.gradient_checkpointing_ratio is not None
 
+    @property
+    def _customize_num_layers_per_stage(self) -> bool:
+        return self.num_layers_per_stage is not None and self.num_model_layers is not None
+
     @property
     def _enable_customized_ckpt_layers_per_stage(self) -> bool:
         return self.num_ckpt_layers_per_stage is not None
diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py
index 963732543..597dd9c26 100644
--- a/colossalai/shardformer/shard/shard_config.py
+++ b/colossalai/shardformer/shard/shard_config.py
@@ -7,7 +7,7 @@ from torch.distributed import ProcessGroup
 
 from colossalai.pipeline.stage_manager import PipelineStageManager
 
-from .grad_ckpt_config import GradientCheckpointConfig
+from .grad_ckpt_config import GradientCheckpointConfig, PipelineGradientCheckpointConfig
 
 __all__ = ["ShardConfig"]
 SUPPORT_SP_MODE = ["split_gather", "ring", "all_to_all"]
@@ -30,6 +30,7 @@ class ShardConfig:
         gradient_checkpoint_config (Optional[GradientCheckpointConfig]): The gradient checkpoint config. Defaults to None.
         enable_all_optimization (bool): Whether to turn on all optimization tools including 'fused normalization', 'flash attention', 'JIT fused operators', 'sequence parallelism' and 'sequence overlap'. Defaults to False.
     """
+
     tensor_parallel_process_group: Optional[ProcessGroup] = None
     sequence_parallel_process_group: Optional[ProcessGroup] = None
     pipeline_stage_manager: Optional[PipelineStageManager] = None
@@ -104,6 +105,16 @@ class ShardConfig:
         else:
             self._sequence_parallel_size = dist.get_world_size(self.sequence_parallel_process_group)
 
+        if (
+            self.pipeline_stage_manager is not None
+            and isinstance(self.gradient_checkpoint_config, PipelineGradientCheckpointConfig)
+            and self.gradient_checkpoint_config._customize_num_layers_per_stage
+        ):
+            self.pipeline_stage_manager.set_distribution_config(
+                self.gradient_checkpoint_config.num_model_layers,
+                self.gradient_checkpoint_config.num_layers_per_stage,
+            )
+
     def _turn_on_all_optimization(self):
         """
         Turn on all optimization.
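
The sketch below is illustrative usage, not part of the patch: it shows how the new num_layers_per_stage field and the _customize_num_layers_per_stage property added above might be exercised. The class, field, and property names are taken from the diff; the concrete stage and layer counts are invented for the example.

    # Hypothetical usage sketch for the config extended by this patch.
    # All numeric values below are made up; only the names come from the diff.
    from colossalai.shardformer.shard.grad_ckpt_config import PipelineGradientCheckpointConfig

    gc_cfg = PipelineGradientCheckpointConfig(
        num_stages=4,                       # pipeline stages
        num_model_chunks=1,                 # one model chunk per stage (1F1B)
        num_model_layers=32,                # total transformer layers in the model
        num_layers_per_stage=[7, 8, 8, 9],  # uneven split that sums to num_model_layers
    )

    # The property added in this patch reports that a custom layer distribution is set.
    assert gc_cfg._customize_num_layers_per_stage

When such a config is passed to ShardConfig as gradient_checkpoint_config together with a pipeline_stage_manager, the new block in ShardConfig.__post_init__ forwards num_model_layers and num_layers_per_stage to PipelineStageManager.set_distribution_config, so the stage manager uses the same layer distribution that the checkpoint config describes.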