[shardformer] fix pipeline grad ckpt (#5620)

* [shardformer] fix pipeline grad ckpt
Hongxin Liu, 2024-04-22 11:25:39 +08:00, committed by GitHub
parent d83c633ca6
commit e094933da1
2 changed files with 18 additions and 1 deletion

grad_ckpt_config.py

@@ -22,6 +22,7 @@ class PipelineGradientCheckpointConfig(GradientCheckpointConfig):
        2. Customize # ckpt layers assigned to each stage. This takes precedence over `gradient_checkpointing_ratio`.
    """
    """
    Args:
        gradient_checkpointing_ratio (Optional[float]): The ratio of gradient checkpointing. It can only be used in pipeline parallelism. Defaults to None.
@@ -49,6 +50,7 @@ class PipelineGradientCheckpointConfig(GradientCheckpointConfig):
    num_stages: Optional[int] = None
    num_model_chunks: Optional[int] = None
    num_model_layers: Optional[int] = None
+   num_layers_per_stage: Optional[List[int]] = None
    num_ckpt_layers_per_stage: Optional[List[int]] = None

    def __post_init__(self):
@@ -70,6 +72,10 @@ class PipelineGradientCheckpointConfig(GradientCheckpointConfig):
    def _enable_gradient_checkpointing_ratio(self) -> bool:
        return self.gradient_checkpointing_ratio is not None

+   @property
+   def _customize_num_layers_per_stage(self) -> bool:
+       return self.num_layers_per_stage is not None and self.num_model_layers is not None

    @property
    def _enable_customized_ckpt_layers_per_stage(self) -> bool:
        return self.num_ckpt_layers_per_stage is not None
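
For context, a minimal usage sketch of the new `num_layers_per_stage` field and the `_customize_num_layers_per_stage` helper; the import path and the per-stage numbers are illustrative assumptions rather than something taken from this commit:

# Hypothetical usage sketch; module path and layer counts are assumptions.
from colossalai.shardformer.shard.grad_ckpt_config import PipelineGradientCheckpointConfig

# 4 pipeline stages, 1 model chunk, 32 layers in total (illustrative numbers).
cfg = PipelineGradientCheckpointConfig(
    num_stages=4,
    num_model_chunks=1,
    num_model_layers=32,
    num_layers_per_stage=[7, 8, 8, 9],       # custom layer split across stages
    num_ckpt_layers_per_stage=[4, 4, 5, 5],  # checkpointed layers per stage
)

# The new property is True only when both num_layers_per_stage and
# num_model_layers are set; ShardConfig checks it before forwarding the
# split to the pipeline stage manager (see shard_config.py below).
assert cfg._customize_num_layers_per_stage
assert cfg._enable_customized_ckpt_layers_per_stage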

shard_config.py

@@ -7,7 +7,7 @@ from torch.distributed import ProcessGroup
from colossalai.pipeline.stage_manager import PipelineStageManager

-from .grad_ckpt_config import GradientCheckpointConfig
+from .grad_ckpt_config import GradientCheckpointConfig, PipelineGradientCheckpointConfig

__all__ = ["ShardConfig"]
SUPPORT_SP_MODE = ["split_gather", "ring", "all_to_all"]
@@ -30,6 +30,7 @@ class ShardConfig:
        gradient_checkpoint_config (Optional[GradientCheckpointConfig]): The gradient checkpoint config. Defaults to None.
        enable_all_optimization (bool): Whether to turn on all optimization tools including 'fused normalization', 'flash attention', 'JIT fused operators', 'sequence parallelism' and 'sequence overlap'. Defaults to False.
    """
    tensor_parallel_process_group: Optional[ProcessGroup] = None
    sequence_parallel_process_group: Optional[ProcessGroup] = None
    pipeline_stage_manager: Optional[PipelineStageManager] = None
@@ -104,6 +105,16 @@ class ShardConfig:
        else:
            self._sequence_parallel_size = dist.get_world_size(self.sequence_parallel_process_group)

+       if (
+           self.pipeline_stage_manager is not None
+           and isinstance(self.gradient_checkpoint_config, PipelineGradientCheckpointConfig)
+           and self.gradient_checkpoint_config._customize_num_layers_per_stage
+       ):
+           self.pipeline_stage_manager.set_distribution_config(
+               self.gradient_checkpoint_config.num_model_layers,
+               self.gradient_checkpoint_config.num_layers_per_stage,
+           )

    def _turn_on_all_optimization(self):
        """
        Turn on all optimization.
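
The block added to `ShardConfig.__post_init__` above only fires when a `PipelineGradientCheckpointConfig` with a customized layer split is supplied together with a pipeline stage manager. Below is a minimal sketch of that wiring using a stand-in stage manager object instead of a real `PipelineStageManager` (which needs an initialized process group); the module path and the numbers are assumptions:

from colossalai.shardformer.shard.grad_ckpt_config import PipelineGradientCheckpointConfig


class _DummyStageManager:
    """Stand-in for PipelineStageManager; records the forwarded layer distribution."""

    def set_distribution_config(self, num_model_layers, num_layers_per_stage):
        self.num_model_layers = num_model_layers
        self.num_layers_per_stage = num_layers_per_stage


ckpt_cfg = PipelineGradientCheckpointConfig(
    num_stages=2,
    num_model_chunks=1,
    num_model_layers=24,
    num_layers_per_stage=[11, 13],
    num_ckpt_layers_per_stage=[6, 6],
)
stage_manager = _DummyStageManager()

# Mirrors the guard added in this commit: only a PipelineGradientCheckpointConfig
# with a customized layer split triggers set_distribution_config.
if ckpt_cfg._customize_num_layers_per_stage:
    stage_manager.set_distribution_config(
        ckpt_cfg.num_model_layers,
        ckpt_cfg.num_layers_per_stage,
    )

assert stage_manager.num_layers_per_stage == [11, 13]

In the real code path, the same call happens automatically when such a config is passed as `gradient_checkpoint_config` to `ShardConfig` alongside a `pipeline_stage_manager`.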