mirror of https://github.com/hpcaitech/ColossalAI
commit e094933da1 (parent d83c633ca6)

@@ -22,6 +22,7 @@ class PipelineGradientCheckpointConfig(GradientCheckpointConfig):
         2. Customize # ckpt layers assigned to each stage. This takes precedence over `gradient_checkpointing_ratio`.
 
     """
 
     """
     Args:
         gradient_checkpointing_ratio (Optional[float]): The ratio of gradient checkpointing. It can only be used in pipeline parallelism. Defaults to None.
@@ -49,6 +50,7 @@ class PipelineGradientCheckpointConfig(GradientCheckpointConfig):
     num_stages: Optional[int] = None
     num_model_chunks: Optional[int] = None
     num_model_layers: Optional[int] = None
+    num_layers_per_stage: Optional[List[int]] = None
     num_ckpt_layers_per_stage: Optional[List[int]] = None
 
     def __post_init__(self):
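For orientation, a minimal sketch of how these fields could be filled in. It assumes the module path colossalai.shardformer.shard.grad_ckpt_config and uses purely illustrative numbers (4 stages, 1 model chunk, 32 layers) that are not part of this commit:

from colossalai.shardformer.shard.grad_ckpt_config import PipelineGradientCheckpointConfig

# Hypothetical 4-stage pipeline over a 32-layer model.
# num_layers_per_stage pins how many layers each stage owns, while
# num_ckpt_layers_per_stage chooses how many of those layers are gradient-checkpointed
# on each stage and takes precedence over gradient_checkpointing_ratio.
gradient_ckpt_config = PipelineGradientCheckpointConfig(
    num_stages=4,
    num_model_chunks=1,
    num_model_layers=32,
    num_layers_per_stage=[7, 8, 8, 9],
    num_ckpt_layers_per_stage=[4, 4, 2, 2],
)
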
@@ -70,6 +72,10 @@ class PipelineGradientCheckpointConfig(GradientCheckpointConfig):
     def _enable_gradient_checkpointing_ratio(self) -> bool:
         return self.gradient_checkpointing_ratio is not None
 
+    @property
+    def _customize_num_layers_per_stage(self) -> bool:
+        return self.num_layers_per_stage is not None and self.num_model_layers is not None
+
     @property
     def _enable_customized_ckpt_layers_per_stage(self) -> bool:
         return self.num_ckpt_layers_per_stage is not None
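As a quick illustration of how downstream code is expected to consume the new property, here is a hedged sketch reusing the hypothetical gradient_ckpt_config built above (not code from this commit):

# The customized layer split is only treated as active when both num_layers_per_stage
# and num_model_layers were provided; this mirrors the guard ShardConfig adds below.
if gradient_ckpt_config._customize_num_layers_per_stage:
    print("layer split per stage:", gradient_ckpt_config.num_layers_per_stage)
    print("total model layers:", gradient_ckpt_config.num_model_layers)
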
@@ -7,7 +7,7 @@ from torch.distributed import ProcessGroup
 from colossalai.pipeline.stage_manager import PipelineStageManager
 
-from .grad_ckpt_config import GradientCheckpointConfig
+from .grad_ckpt_config import GradientCheckpointConfig, PipelineGradientCheckpointConfig
 
 __all__ = ["ShardConfig"]
 SUPPORT_SP_MODE = ["split_gather", "ring", "all_to_all"]
@@ -30,6 +30,7 @@ class ShardConfig:
         gradient_checkpoint_config (Optional[GradientCheckpointConfig]): The gradient checkpoint config. Defaults to None.
         enable_all_optimization (bool): Whether to turn on all optimization tools including 'fused normalization', 'flash attention', 'JIT fused operators', 'sequence parallelism' and 'sequence overlap'. Defaults to False.
     """
 
     tensor_parallel_process_group: Optional[ProcessGroup] = None
     sequence_parallel_process_group: Optional[ProcessGroup] = None
     pipeline_stage_manager: Optional[PipelineStageManager] = None
@@ -104,6 +105,16 @@ class ShardConfig:
         else:
             self._sequence_parallel_size = dist.get_world_size(self.sequence_parallel_process_group)
 
+        if (
+            self.pipeline_stage_manager is not None
+            and isinstance(self.gradient_checkpoint_config, PipelineGradientCheckpointConfig)
+            and self.gradient_checkpoint_config._customize_num_layers_per_stage
+        ):
+            self.pipeline_stage_manager.set_distribution_config(
+                self.gradient_checkpoint_config.num_model_layers,
+                self.gradient_checkpoint_config.num_layers_per_stage,
+            )
+
     def _turn_on_all_optimization(self):
         """
         Turn on all optimization.
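To show where the new wiring fires in user code, here is a hedged end-to-end sketch (not part of the commit). It assumes a distributed run has already been launched, tp_group is an existing tensor-parallel ProcessGroup, stage_manager is an already-constructed PipelineStageManager, and gradient_ckpt_config is the hypothetical config from the earlier sketch:

from colossalai.shardformer.shard.shard_config import ShardConfig

# Because pipeline_stage_manager is set and gradient_ckpt_config customizes the layer
# split, __post_init__ forwards (num_model_layers, num_layers_per_stage), i.e.
# (32, [7, 8, 8, 9]) with the illustrative numbers, to
# stage_manager.set_distribution_config as added in the hunk above.
shard_config = ShardConfig(
    tensor_parallel_process_group=tp_group,
    pipeline_stage_manager=stage_manager,
    gradient_checkpoint_config=gradient_ckpt_config,
)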