[shardformer] refactor pipeline grad ckpt config (#5646)

* [shardformer] refactor pipeline grad ckpt config

* [shardformer] refactor pipeline grad ckpt config

* [pipeline] fix stage manager
Hongxin Liu 2024-04-25 15:19:30 +08:00 committed by GitHub
parent 7ef91606e1
commit 1b387ca9fe
11 changed files with 59 additions and 102 deletions

View File

@@ -983,6 +983,7 @@ class HybridParallelPlugin(PipelinePluginBase):
         custom_policy: Policy = None,
         pp_style: str = "1f1b",
         num_model_chunks: int = 1,
+        num_layers_per_stage: Optional[List[int]] = None,
         gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None,
         enable_metadata_cache: bool = True,
         make_vocab_size_divisible_by: int = 64,
@@ -1056,6 +1057,7 @@ class HybridParallelPlugin(PipelinePluginBase):
             pipeline_axis=self.pp_axis,
             enable_interleave=pp_style == "interleaved",
             num_model_chunks=num_model_chunks,
+            num_layers_per_stage=num_layers_per_stage,
         )
         if pp_style == "interleaved":
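
With this refactor, a custom per-stage layer split is passed directly to HybridParallelPlugin via num_layers_per_stage instead of going through the gradient checkpoint config. A minimal usage sketch; the surrounding arguments (tp_size, pp_size, precision, microbatch_size) and the import paths mirror the ColossalAI examples and are assumptions, not part of this diff:

from colossalai.booster.plugin import HybridParallelPlugin
from colossalai.shardformer import PipelineGradientCheckpointConfig

# 4-stage pipeline over an 80-layer model with an uneven, user-chosen split.
plugin = HybridParallelPlugin(
    tp_size=2,
    pp_size=4,
    num_layers_per_stage=[19, 20, 20, 21],  # must sum to the model's layer count
    gradient_checkpoint_config=PipelineGradientCheckpointConfig(
        num_ckpt_layers_per_stage=[19, 19, 19, 13],  # layers to checkpoint on each stage
    ),
    microbatch_size=1,
    precision="bf16",
)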

View File

@@ -27,16 +27,18 @@ class PipelineStageManager:
         pipeline_axis: int,
         enable_interleave: bool = False,
         num_model_chunks: int = 1,
+        num_layers_per_stage: Optional[List[int]] = None,
     ) -> None:
         assert enable_interleave or num_model_chunks == 1, "num_model_chunks must be 1 when enable_interleave is False"
-        self.num_layers_per_stage = None
-
         self.pg_mesh = pg_mesh
         self.pipeline_axis = pipeline_axis
         self.prev_rank: Optional[Tuple[int, ...]] = None
         self.next_rank: Optional[Tuple[int, ...]] = None
         self.p2p_groups: Dict[Tuple[int, int], ProcessGroup] = {}
+        if num_layers_per_stage is not None:
+            assert len(num_layers_per_stage) == self.num_stages
+        self.num_layers_per_stage = num_layers_per_stage

         # init prev and next coord
         coord = self.pg_mesh.coordinate()
@@ -56,6 +58,8 @@ class PipelineStageManager:
                 self.p2p_groups[tuple(ranks_in_group)] = group

         self.is_interleave = enable_interleave
+        # for interleaved pipeline parallel, each device is responsible for multiple chunk of layers
+        self.num_model_chunks: int = num_model_chunks
         if enable_interleave:
             # use circle p2p communication
             # add the process group of the first rank and the last rank
@@ -64,59 +68,11 @@ class PipelineStageManager:
                 ranks_in_group = self.pg_mesh.get_ranks_in_group(group)
                 self.p2p_groups[tuple(ranks_in_group)] = group

-            # for interleaved pipeline parallel, each device is responsible for multiple chunk of layers
-            self.num_model_chunks: int = num_model_chunks
-
             # for shardformer, hold stage indices of model
             self.stage_indices: List[Tuple[int, int]]
             # for shardformer, hold model chunk id
             self.model_chunk_id: Optional[int] = None

-    @property
-    def control_distribute_layers(self) -> bool:
-        return self.num_layers_per_stage is not None
-
-    def set_distribution_config(self, num_model_layers: int, num_layers_per_stage: List[int]) -> None:
-        """Set the distribution configuration.
-        This allows user to customize the number of layers for each stage.
-
-        Args:
-            num_model_layers (int): Number of layers in the model.
-            num_layers_per_stage (List[int]): Number of layers for each stage.
-        """
-        assert all([0 < num_layers < num_model_layers for num_layers in num_layers_per_stage])
-        assert sum(num_layers_per_stage) == num_model_layers
-        assert len(num_layers_per_stage) == self.num_stages * (self.num_model_chunks if self.is_interleave else 1)
-        self.num_model_layers = num_model_layers
-        self.num_layers_per_stage = num_layers_per_stage
-
-    def distribute_layers(
-        self, num_layers: int, num_stages: Optional[int] = None, num_model_chunks: Optional[int] = None
-    ) -> List[int]:
-        """Divide layers into stages"""
-        num_stages = self.num_stages if num_stages is None else num_stages
-        num_model_chunks = (
-            (self.num_model_chunks if self.is_interleave else 1) if num_model_chunks is None else num_model_chunks
-        )
-        if self.control_distribute_layers:
-            assert num_layers == self.num_model_layers
-            return self.num_layers_per_stage
-        else:
-            quotient = num_layers // (num_stages * num_model_chunks)
-            remainder = num_layers % (num_stages * num_model_chunks)
-
-            # calculate the num_layers per stage
-            layers_per_stage = [quotient] * num_stages * num_model_chunks
-            # deal with the rest layers
-            if remainder > 0:
-                start_position = (num_stages * num_model_chunks) // 2 - remainder // 2
-                for i in range(start_position, start_position + remainder):
-                    layers_per_stage[i] += 1
-            return layers_per_stage
-
     def get_stage_index(
         self,
         layers_per_stage: List[int],
@@ -139,9 +95,7 @@ class PipelineStageManager:
         """
         stage = self.stage if stage is None else stage
-        num_model_chunks = (
-            (self.num_model_chunks if self.is_interleave else 1) if num_model_chunks is None else num_model_chunks
-        )
+        num_model_chunks = self.num_model_chunks if num_model_chunks is None else num_model_chunks
         num_stages = self.num_stages if num_stages is None else num_stages

         num_layers_per_stage_accumulated = np.insert(np.cumsum(layers_per_stage), 0, 0)
@@ -261,3 +215,25 @@ class PipelineStageManager:
         self.model_chunk_id = model_chunk_id
         yield
         self.model_chunk_id = old_model_chunk_id
+
+    def distribute_layers(
+        self, num_layers: int, num_stages: Optional[int] = None, num_model_chunks: Optional[int] = None
+    ) -> List[int]:
+        if self.num_layers_per_stage is not None:
+            assert sum(self.num_layers_per_stage) == num_layers
+            return self.num_layers_per_stage
+
+        num_stages = self.num_stages if num_stages is None else num_stages
+        num_model_chunks = self.num_model_chunks if num_model_chunks is None else num_model_chunks
+        quotient = num_layers // (num_stages * num_model_chunks)
+        remainder = num_layers % (num_stages * num_model_chunks)
+
+        # calculate the num_layers per stage
+        layers_per_stage = [quotient] * num_stages * num_model_chunks
+
+        # deal with the rest layers
+        if remainder > 0:
+            start_position = (num_stages * num_model_chunks) // 2 - remainder // 2
+            for i in range(start_position, start_position + remainder):
+                layers_per_stage[i] += 1
+        return layers_per_stage
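
When num_layers_per_stage is not supplied, the default split above divides layers almost evenly and hands the remainder to the middle stages. A standalone re-implementation of that arithmetic for illustration (not the ColossalAI API itself):

from typing import List


def split_layers(num_layers: int, num_stages: int, num_model_chunks: int = 1) -> List[int]:
    # near-even split; any remainder goes to the middle stages
    parts = num_stages * num_model_chunks
    quotient, remainder = divmod(num_layers, parts)
    layers_per_stage = [quotient] * parts
    start = parts // 2 - remainder // 2
    for i in range(start, start + remainder):
        layers_per_stage[i] += 1
    return layers_per_stage


print(split_layers(80, 4))     # [20, 20, 20, 20]
print(split_layers(30, 4))     # [7, 8, 8, 7]   -> the two extra layers land in the middle
print(split_layers(26, 2, 2))  # [6, 7, 7, 6]   -> interleaved: 2 stages x 2 chunks

With a custom num_layers_per_stage, distribute_layers skips this arithmetic and only checks that the list sums to num_layers.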

View File

@ -168,8 +168,10 @@ class LlamaPipelineForwards:
if shard_config.gradient_checkpoint_config is not None: if shard_config.gradient_checkpoint_config is not None:
num_ckpt_layers = shard_config.gradient_checkpoint_config.get_num_ckpt_layers( num_ckpt_layers = shard_config.gradient_checkpoint_config.get_num_ckpt_layers(
stage=stage_manager.stage, stage=stage_manager.stage,
num_stages=stage_manager.num_stages,
num_layers=end_idx - start_idx, num_layers=end_idx - start_idx,
model_chunk_id=(stage_manager.model_chunk_id if stage_manager.is_interleave else 0), model_chunk_id=(stage_manager.model_chunk_id if stage_manager.is_interleave else 0),
num_model_chunks=stage_manager.num_model_chunks,
) )
assert num_ckpt_layers <= end_idx - start_idx assert num_ckpt_layers <= end_idx - start_idx

View File

@ -129,8 +129,10 @@ class MistralForwards:
if shard_config.gradient_checkpoint_config is not None: if shard_config.gradient_checkpoint_config is not None:
num_ckpt_layers = shard_config.gradient_checkpoint_config.get_num_ckpt_layers( num_ckpt_layers = shard_config.gradient_checkpoint_config.get_num_ckpt_layers(
stage=stage_manager.stage, stage=stage_manager.stage,
num_stages=stage_manager.num_stages,
num_layers=end_idx - start_idx, num_layers=end_idx - start_idx,
model_chunk_id=(stage_manager.model_chunk_id if stage_manager.is_interleave else 0), model_chunk_id=(stage_manager.model_chunk_id if stage_manager.is_interleave else 0),
num_model_chunks=stage_manager.num_model_chunks,
) )
assert num_ckpt_layers <= end_idx - start_idx assert num_ckpt_layers <= end_idx - start_idx
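
In both model forwards the returned count decides which of the stage's layers run under activation checkpointing: the first num_ckpt_layers of the slice [start_idx, end_idx). A simplified, self-contained sketch of that selection pattern (toy layers and variable names, not the actual modeling code):

import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint

# stand-ins for the stage's slice of decoder layers
decoder_layers = nn.ModuleList([nn.Linear(16, 16) for _ in range(8)])
start_idx, end_idx = 2, 6   # layers owned by this pipeline stage
num_ckpt_layers = 3         # value returned by get_num_ckpt_layers for this stage

hidden_states = torch.randn(4, 16, requires_grad=True)
for idx in range(start_idx, end_idx):
    layer = decoder_layers[idx]
    if idx - start_idx < num_ckpt_layers:
        # checkpointed: activations are recomputed during backward
        hidden_states = checkpoint.checkpoint(layer, hidden_states, use_reentrant=False)
    else:
        hidden_states = layer(hidden_states)
hidden_states.sum().backward()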

View File

@@ -28,6 +28,7 @@ class SubModuleReplacementDescription:
         kwargs (Dict[str, Any]): the dictionary used to pass extra arguments to the `ParallelModule.from_native_module` method.
         ignore_if_not_exist (bool): if the submodule does not exist, ignore it or raise an exception
     """
+
     suffix: str
     target_module: Union[ParallelModule, BaseLayerNorm]
     kwargs: Dict[str, Any] = None
@@ -54,6 +55,7 @@ class ModulePolicyDescription:
             object which specifies the module to be replaced and the target module used to replacement.
         method_replace (Dict[str, Callable]): key is the method name, value is the method for replacement
     """
+
     attribute_replacement: Dict[str, Any] = None
     param_replacement: List[Callable] = None
     sub_module_replacement: List[SubModuleReplacementDescription] = None

View File

@@ -47,46 +47,33 @@ class PipelineGradientCheckpointConfig(GradientCheckpointConfig):
         ...
     """

-    num_stages: Optional[int] = None
-    num_model_chunks: Optional[int] = None
-    num_model_layers: Optional[int] = None
-    num_layers_per_stage: Optional[List[int]] = None
     num_ckpt_layers_per_stage: Optional[List[int]] = None

     def __post_init__(self):
-        if self._enable_gradient_checkpointing_ratio:
+        if self._enable_customized_ckpt_layers_per_stage:
+            assert all([num_ckpt_layers >= 0 for num_ckpt_layers in self.num_ckpt_layers_per_stage])
+        elif self._enable_gradient_checkpointing_ratio:
             if not (0 <= self.gradient_checkpointing_ratio <= 1):
                 raise ValueError("gradient_checkpointing_ratio should be in 0% to 100%")

-        if self._enable_customized_ckpt_layers_per_stage:
-            assert (
-                self.num_stages is not None and self.num_model_chunks is not None and self.num_model_layers is not None
-            )
-            assert len(self.num_ckpt_layers_per_stage) == self.num_stages * self.num_model_chunks
-            assert all(
-                [0 <= num_ckpt_layers < self.num_model_layers for num_ckpt_layers in self.num_ckpt_layers_per_stage]
-            )
-            self.gradient_checkpointing_ratio = sum(self.num_ckpt_layers_per_stage) / self.num_model_layers
-
     @property
     def _enable_gradient_checkpointing_ratio(self) -> bool:
         return self.gradient_checkpointing_ratio is not None

-    @property
-    def _customize_num_layers_per_stage(self) -> bool:
-        return self.num_layers_per_stage is not None and self.num_model_layers is not None
-
     @property
     def _enable_customized_ckpt_layers_per_stage(self) -> bool:
         return self.num_ckpt_layers_per_stage is not None

-    def get_num_ckpt_layers(self, stage: int, num_layers: int, model_chunk_id: int = 0) -> int:
+    def get_num_ckpt_layers(
+        self, stage: int, num_stages: int, num_layers: int, model_chunk_id: int = 0, num_model_chunks: int = 1
+    ) -> int:
         if not self._enable_gradient_checkpointing_ratio and not self._enable_customized_ckpt_layers_per_stage:
             raise RuntimeError("No checkpointed layers information is provided")

         if self._enable_customized_ckpt_layers_per_stage:
-            assert stage <= self.num_stages and model_chunk_id <= self.num_model_chunks
-            num_ckpt_layers = self.num_ckpt_layers_per_stage[stage + model_chunk_id * self.num_stages]
+            assert len(self.num_ckpt_layers_per_stage) == num_stages * num_model_chunks
+            assert stage <= num_stages and model_chunk_id <= num_model_chunks
+            num_ckpt_layers = self.num_ckpt_layers_per_stage[stage + model_chunk_id * num_stages]
             assert num_ckpt_layers <= num_layers
             return num_ckpt_layers
         else:
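
The per-stage lookup is a plain index into the flat list, with entries grouped by model chunk (all stages of chunk 0 first, then chunk 1, and so on). A standalone sketch of that indexing, using the interleaved test configuration from the end of this diff (illustrative helper, not the ColossalAI class):

from typing import List


def lookup_num_ckpt_layers(
    num_ckpt_layers_per_stage: List[int], stage: int, num_stages: int, model_chunk_id: int = 0
) -> int:
    # same index formula as get_num_ckpt_layers: stage varies fastest within a chunk
    return num_ckpt_layers_per_stage[stage + model_chunk_id * num_stages]


cfg = [0, 1, 2, 2]  # 2 stages x 2 model chunks
for chunk in range(2):
    for stage in range(2):
        print(f"stage {stage}, chunk {chunk}: checkpoint {lookup_num_ckpt_layers(cfg, stage, 2, chunk)} layers")
# stage 0/chunk 0 -> 0, stage 1/chunk 0 -> 1, stage 0/chunk 1 -> 2, stage 1/chunk 1 -> 2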

View File

@@ -7,7 +7,7 @@ from torch.distributed import ProcessGroup

 from colossalai.pipeline.stage_manager import PipelineStageManager

-from .grad_ckpt_config import GradientCheckpointConfig, PipelineGradientCheckpointConfig
+from .grad_ckpt_config import GradientCheckpointConfig

 __all__ = ["ShardConfig"]
 SUPPORT_SP_MODE = ["split_gather", "ring", "all_to_all"]
@@ -105,16 +105,6 @@ class ShardConfig:
         else:
             self._sequence_parallel_size = dist.get_world_size(self.sequence_parallel_process_group)

-        if (
-            self.pipeline_stage_manager is not None
-            and isinstance(self.gradient_checkpoint_config, PipelineGradientCheckpointConfig)
-            and self.gradient_checkpoint_config._customize_num_layers_per_stage
-        ):
-            self.pipeline_stage_manager.set_distribution_config(
-                self.gradient_checkpoint_config.num_model_layers,
-                self.gradient_checkpoint_config.num_layers_per_stage,
-            )
-
     def _turn_on_all_optimization(self):
         """
         Turn on all optimization.

View File

@@ -88,16 +88,15 @@ def main():
         pass

     # ckpt config for LLaMA3-70B on 64 H100 GPUs
-    ckpt_config = (
-        PipelineGradientCheckpointConfig(
-            num_stages=args.pp,
-            num_model_chunks=1,
-            num_model_layers=80,
-            num_layers_per_stage=[19, 20, 20, 21],
-            num_ckpt_layers_per_stage=[19, 19, 19, 13],
-        )
+    hybrid_kwargs = (
+        {
+            "gradient_checkpoint_config": PipelineGradientCheckpointConfig(
+                num_ckpt_layers_per_stage=[19, 19, 19, 13],
+            ),
+            "num_layers_per_stage": [19, 20, 20, 21],
+        }
         if args.custom_ckpt
-        else None
+        else {}
     )

     # ==============================
@@ -173,7 +172,7 @@ def main():
             microbatch_size=args.mbs,
             precision="bf16",
             dp_outside=False,
-            gradient_checkpoint_config=ckpt_config,
+            **hybrid_kwargs,
         )
     elif args.plugin == "3d_cpu":
         plugin = HybridParallelPlugin(
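
The numbers in this benchmark configuration can be sanity-checked by hand (plain Python, not part of the script):

num_layers_per_stage = [19, 20, 20, 21]      # LLaMA3-70B: 80 decoder layers in total
num_ckpt_layers_per_stage = [19, 19, 19, 13]

assert sum(num_layers_per_stage) == 80
assert all(c <= n for c, n in zip(num_ckpt_layers_per_stage, num_layers_per_stage))
print(sum(num_ckpt_layers_per_stage) / sum(num_layers_per_stage))  # 0.875 of layers are recomputed

Later stages checkpoint fewer of their layers, presumably because in a 1F1B schedule the earlier stages hold activations for the most in-flight microbatches and benefit most from recomputation.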

View File

@@ -14,6 +14,7 @@ class _PipelineStageManager(PipelineStageManager):
     def __init__(self):
         self.is_interleave = False
         self.num_layers_per_stage = None
+        self.num_model_chunks = 1

     @property
     def num_stages(self):

View File

@@ -14,6 +14,7 @@ class _PipelineStageManager(PipelineStageManager):
     def __init__(self):
         self.is_interleave = False
         self.num_layers_per_stage = None
+        self.num_model_chunks = 1

     @property
     def num_stages(self):

View File

@@ -217,9 +217,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "use_lazy_init": False,
             "precision": "fp32",
             "enable_gradient_checkpointing": True,
-            "gradient_checkpoint_config": PipelineGradientCheckpointConfig(
-                num_stages=2, num_model_chunks=1, num_model_layers=8, num_ckpt_layers_per_stage=[4, 0]
-            ),
+            "gradient_checkpoint_config": PipelineGradientCheckpointConfig(num_ckpt_layers_per_stage=[4, 0]),
         },
         {
             "tp_size": 4,
@@ -303,9 +301,6 @@ def run_llama_test(test_config):
             "initial_scale": 1,
             "enable_gradient_checkpointing": True,
             "gradient_checkpoint_config": PipelineGradientCheckpointConfig(
-                num_stages=2,
-                num_model_chunks=2,
-                num_model_layers=8,
                 num_ckpt_layers_per_stage=[0, 1, 2, 2],
             ),
         },