[shardformer] refactor pipeline grad ckpt config (#5646)

* [shardformer] refactor pipeline grad ckpt config

* [shardformer] refactor pipeline grad ckpt config

* [pipeline] fix stage manager
Hongxin Liu 2024-04-25 15:19:30 +08:00 committed by GitHub
parent 7ef91606e1
commit 1b387ca9fe
11 changed files with 59 additions and 102 deletions

View File

@@ -983,6 +983,7 @@ class HybridParallelPlugin(PipelinePluginBase):
custom_policy: Policy = None,
pp_style: str = "1f1b",
num_model_chunks: int = 1,
num_layers_per_stage: Optional[List[int]] = None,
gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None,
enable_metadata_cache: bool = True,
make_vocab_size_divisible_by: int = 64,
@@ -1056,6 +1057,7 @@ class HybridParallelPlugin(PipelinePluginBase):
pipeline_axis=self.pp_axis,
enable_interleave=pp_style == "interleaved",
num_model_chunks=num_model_chunks,
num_layers_per_stage=num_layers_per_stage,
)
if pp_style == "interleaved":
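
Note: taken together, the two hunks above let callers hand an explicit per-stage layer split straight to the plugin instead of encoding pipeline topology in the checkpoint config. A minimal sketch of the resulting call site; the sizes, the 4-stage split, and the surrounding distributed launch are illustrative assumptions, not part of this commit:

    # sketch: run under torchrun/colossalai run with 8 ranks (tp=2 x pp=4),
    # after the process group has been initialized (e.g. colossalai.launch_from_torch)
    from colossalai.booster import Booster
    from colossalai.booster.plugin import HybridParallelPlugin
    from colossalai.shardformer import PipelineGradientCheckpointConfig

    plugin = HybridParallelPlugin(
        tp_size=2,
        pp_size=4,
        microbatch_size=1,
        precision="bf16",
        num_layers_per_stage=[19, 20, 20, 21],  # layer placement, new in this commit
        gradient_checkpoint_config=PipelineGradientCheckpointConfig(
            num_ckpt_layers_per_stage=[19, 19, 19, 13],  # how many of those layers are recomputed
        ),
    )
    booster = Booster(plugin=plugin)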

View File

@@ -27,16 +27,18 @@ class PipelineStageManager:
pipeline_axis: int,
enable_interleave: bool = False,
num_model_chunks: int = 1,
num_layers_per_stage: Optional[List[int]] = None,
) -> None:
assert enable_interleave or num_model_chunks == 1, "num_model_chunks must be 1 when enable_interleave is False"
self.num_layers_per_stage = None
self.pg_mesh = pg_mesh
self.pipeline_axis = pipeline_axis
self.prev_rank: Optional[Tuple[int, ...]] = None
self.next_rank: Optional[Tuple[int, ...]] = None
self.p2p_groups: Dict[Tuple[int, int], ProcessGroup] = {}
if num_layers_per_stage is not None:
assert len(num_layers_per_stage) == self.num_stages
self.num_layers_per_stage = num_layers_per_stage
# init prev and next coord
coord = self.pg_mesh.coordinate()
@@ -56,6 +58,8 @@ class PipelineStageManager:
self.p2p_groups[tuple(ranks_in_group)] = group
self.is_interleave = enable_interleave
# for interleaved pipeline parallel, each device is responsible for multiple chunk of layers
self.num_model_chunks: int = num_model_chunks
if enable_interleave:
# use circle p2p communication
# add the process group of the first rank and the last rank
@@ -64,59 +68,11 @@ class PipelineStageManager:
ranks_in_group = self.pg_mesh.get_ranks_in_group(group)
self.p2p_groups[tuple(ranks_in_group)] = group
# for interleaved pipeline parallel, each device is responsible for multiple chunk of layers
self.num_model_chunks: int = num_model_chunks
# for shardformer, hold stage indices of model
self.stage_indices: List[Tuple[int, int]]
# for shardformer, hold model chunk id
self.model_chunk_id: Optional[int] = None
@property
def control_distribute_layers(self) -> bool:
return self.num_layers_per_stage is not None
def set_distribution_config(self, num_model_layers: int, num_layers_per_stage: List[int]) -> None:
"""Set the distribution configuration.
This allows user to customize the number of layers for each stage.
Args:
num_model_layers (int): Number of layers in the model.
num_layers_per_stage (List[int]): Number of layers for each stage.
"""
assert all([0 < num_layers < num_model_layers for num_layers in num_layers_per_stage])
assert sum(num_layers_per_stage) == num_model_layers
assert len(num_layers_per_stage) == self.num_stages * (self.num_model_chunks if self.is_interleave else 1)
self.num_model_layers = num_model_layers
self.num_layers_per_stage = num_layers_per_stage
def distribute_layers(
self, num_layers: int, num_stages: Optional[int] = None, num_model_chunks: Optional[int] = None
) -> List[int]:
"""Divide layers into stages"""
num_stages = self.num_stages if num_stages is None else num_stages
num_model_chunks = (
(self.num_model_chunks if self.is_interleave else 1) if num_model_chunks is None else num_model_chunks
)
if self.control_distribute_layers:
assert num_layers == self.num_model_layers
return self.num_layers_per_stage
else:
quotient = num_layers // (num_stages * num_model_chunks)
remainder = num_layers % (num_stages * num_model_chunks)
# calculate the num_layers per stage
layers_per_stage = [quotient] * num_stages * num_model_chunks
# deal with the rest layers
if remainder > 0:
start_position = (num_stages * num_model_chunks) // 2 - remainder // 2
for i in range(start_position, start_position + remainder):
layers_per_stage[i] += 1
return layers_per_stage
def get_stage_index(
self,
layers_per_stage: List[int],
@@ -139,9 +95,7 @@ class PipelineStageManager:
"""
stage = self.stage if stage is None else stage
num_model_chunks = (
(self.num_model_chunks if self.is_interleave else 1) if num_model_chunks is None else num_model_chunks
)
num_model_chunks = self.num_model_chunks if num_model_chunks is None else num_model_chunks
num_stages = self.num_stages if num_stages is None else num_stages
num_layers_per_stage_accumulated = np.insert(np.cumsum(layers_per_stage), 0, 0)
@@ -261,3 +215,25 @@ class PipelineStageManager:
self.model_chunk_id = model_chunk_id
yield
self.model_chunk_id = old_model_chunk_id
def distribute_layers(
self, num_layers: int, num_stages: Optional[int] = None, num_model_chunks: Optional[int] = None
) -> List[int]:
if self.num_layers_per_stage is not None:
assert sum(self.num_layers_per_stage) == num_layers
return self.num_layers_per_stage
num_stages = self.num_stages if num_stages is None else num_stages
num_model_chunks = self.num_model_chunks if num_model_chunks is None else num_model_chunks
quotient = num_layers // (num_stages * num_model_chunks)
remainder = num_layers % (num_stages * num_model_chunks)
# calculate the num_layers per stage
layers_per_stage = [quotient] * num_stages * num_model_chunks
# deal with the rest layers
if remainder > 0:
start_position = (num_stages * num_model_chunks) // 2 - remainder // 2
for i in range(start_position, start_position + remainder):
layers_per_stage[i] += 1
return layers_per_stage
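
Note: when num_layers_per_stage was given at construction, the method above just validates the sum and returns it; otherwise it splits the layers as evenly as possible and hands the leftovers to the middle of the pipeline. A self-contained replication of that fallback, purely for illustration (the helper name is made up):

    # mirrors the even-split fallback of distribute_layers, outside the class
    from typing import List

    def split_layers(num_layers: int, num_stages: int, num_model_chunks: int = 1) -> List[int]:
        quotient, remainder = divmod(num_layers, num_stages * num_model_chunks)
        layers_per_stage = [quotient] * (num_stages * num_model_chunks)
        # give the leftover layers to the middle stages
        start = (num_stages * num_model_chunks) // 2 - remainder // 2
        for i in range(start, start + remainder):
            layers_per_stage[i] += 1
        return layers_per_stage

    print(split_layers(32, 4))     # [8, 8, 8, 8]
    print(split_layers(30, 4))     # [7, 8, 8, 7]
    print(split_layers(80, 4, 2))  # [10, 10, 10, 10, 10, 10, 10, 10] for interleaved schedules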

View File

@@ -168,8 +168,10 @@ class LlamaPipelineForwards:
if shard_config.gradient_checkpoint_config is not None:
num_ckpt_layers = shard_config.gradient_checkpoint_config.get_num_ckpt_layers(
stage=stage_manager.stage,
num_stages=stage_manager.num_stages,
num_layers=end_idx - start_idx,
model_chunk_id=(stage_manager.model_chunk_id if stage_manager.is_interleave else 0),
num_model_chunks=stage_manager.num_model_chunks,
)
assert num_ckpt_layers <= end_idx - start_idx
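
Note: with the extra num_stages/num_model_chunks arguments, the forward no longer depends on the config being pre-filled with pipeline topology; it reports its own. The returned count caps how many of this stage's layers run with activation checkpointing (the assert above guards exactly that). A toy, self-contained illustration of that pattern in plain PyTorch, not the actual Llama layers:

    import torch
    import torch.nn as nn
    from torch.utils.checkpoint import checkpoint

    # stand-in for the decoder layers assigned to one pipeline stage
    layers = nn.ModuleList([nn.Linear(16, 16) for _ in range(4)])
    hidden = torch.randn(2, 16, requires_grad=True)

    num_ckpt_layers = 2  # e.g. the value returned by get_num_ckpt_layers for this stage
    for idx, layer in enumerate(layers):
        if idx < num_ckpt_layers:
            # activations are recomputed during backward instead of being stored
            hidden = checkpoint(layer, hidden, use_reentrant=False)
        else:
            hidden = layer(hidden)
    hidden.sum().backward()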

View File

@@ -129,8 +129,10 @@ class MistralForwards:
if shard_config.gradient_checkpoint_config is not None:
num_ckpt_layers = shard_config.gradient_checkpoint_config.get_num_ckpt_layers(
stage=stage_manager.stage,
num_stages=stage_manager.num_stages,
num_layers=end_idx - start_idx,
model_chunk_id=(stage_manager.model_chunk_id if stage_manager.is_interleave else 0),
num_model_chunks=stage_manager.num_model_chunks,
)
assert num_ckpt_layers <= end_idx - start_idx

View File

@@ -28,6 +28,7 @@ class SubModuleReplacementDescription:
kwargs (Dict[str, Any]): the dictionary used to pass extra arguments to the `ParallelModule.from_native_module` method.
ignore_if_not_exist (bool): if the submodule does not exist, ignore it or raise an exception
"""
suffix: str
target_module: Union[ParallelModule, BaseLayerNorm]
kwargs: Dict[str, Any] = None
@@ -54,6 +55,7 @@ class ModulePolicyDescription:
object which specifies the module to be replaced and the target module used to replacement.
method_replace (Dict[str, Callable]): key is the method name, value is the method for replacement
"""
attribute_replacement: Dict[str, Any] = None
param_replacement: List[Callable] = None
sub_module_replacement: List[SubModuleReplacementDescription] = None
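
Note: for context, these two dataclasses are how shardformer policies declare module replacements. A minimal hedged sketch of their use; the suffix, attribute value, and Linear1D_Col target are illustrative choices, not taken from this commit:

    from colossalai.shardformer.layer import Linear1D_Col
    from colossalai.shardformer.policies.base_policy import (
        ModulePolicyDescription,
        SubModuleReplacementDescription,
    )

    # swap one attention projection for its column-parallel counterpart
    attn_policy = ModulePolicyDescription(
        attribute_replacement={"self_attn.num_heads": 16},  # illustrative attribute override
        sub_module_replacement=[
            SubModuleReplacementDescription(
                suffix="self_attn.q_proj",
                target_module=Linear1D_Col,
            )
        ],
    )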

View File

@@ -47,46 +47,33 @@ class PipelineGradientCheckpointConfig(GradientCheckpointConfig):
...
"""
num_stages: Optional[int] = None
num_model_chunks: Optional[int] = None
num_model_layers: Optional[int] = None
num_layers_per_stage: Optional[List[int]] = None
num_ckpt_layers_per_stage: Optional[List[int]] = None
def __post_init__(self):
if self._enable_gradient_checkpointing_ratio:
if self._enable_customized_ckpt_layers_per_stage:
assert all([num_ckpt_layers >= 0 for num_ckpt_layers in self.num_ckpt_layers_per_stage])
elif self._enable_gradient_checkpointing_ratio:
if not (0 <= self.gradient_checkpointing_ratio <= 1):
raise ValueError("gradient_checkpointing_ratio should be in 0% to 100%")
if self._enable_customized_ckpt_layers_per_stage:
assert (
self.num_stages is not None and self.num_model_chunks is not None and self.num_model_layers is not None
)
assert len(self.num_ckpt_layers_per_stage) == self.num_stages * self.num_model_chunks
assert all(
[0 <= num_ckpt_layers < self.num_model_layers for num_ckpt_layers in self.num_ckpt_layers_per_stage]
)
self.gradient_checkpointing_ratio = sum(self.num_ckpt_layers_per_stage) / self.num_model_layers
@property
def _enable_gradient_checkpointing_ratio(self) -> bool:
return self.gradient_checkpointing_ratio is not None
@property
def _customize_num_layers_per_stage(self) -> bool:
return self.num_layers_per_stage is not None and self.num_model_layers is not None
@property
def _enable_customized_ckpt_layers_per_stage(self) -> bool:
return self.num_ckpt_layers_per_stage is not None
def get_num_ckpt_layers(self, stage: int, num_layers: int, model_chunk_id: int = 0) -> int:
def get_num_ckpt_layers(
self, stage: int, num_stages: int, num_layers: int, model_chunk_id: int = 0, num_model_chunks: int = 1
) -> int:
if not self._enable_gradient_checkpointing_ratio and not self._enable_customized_ckpt_layers_per_stage:
raise RuntimeError("No checkpointed layers information is provided")
if self._enable_customized_ckpt_layers_per_stage:
assert stage <= self.num_stages and model_chunk_id <= self.num_model_chunks
num_ckpt_layers = self.num_ckpt_layers_per_stage[stage + model_chunk_id * self.num_stages]
assert len(self.num_ckpt_layers_per_stage) == num_stages * num_model_chunks
assert stage <= num_stages and model_chunk_id <= num_model_chunks
num_ckpt_layers = self.num_ckpt_layers_per_stage[stage + model_chunk_id * num_stages]
assert num_ckpt_layers <= num_layers
return num_ckpt_layers
else:
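
Note: the config is now purely declarative. It no longer stores pipeline topology; get_num_ckpt_layers receives num_stages and num_model_chunks from the caller and indexes the flat list as stage + model_chunk_id * num_stages, while the ratio branch (elided in the hunk) is expected to scale the stage's layer count by the ratio. A small sketch of both paths; the import path and the concrete numbers are assumptions for illustration:

    from colossalai.shardformer import PipelineGradientCheckpointConfig

    # explicit per-stage counts: 4 stages, 1 model chunk
    cfg = PipelineGradientCheckpointConfig(num_ckpt_layers_per_stage=[19, 19, 19, 13])
    print(cfg.get_num_ckpt_layers(stage=3, num_stages=4, num_layers=21))  # -> 13

    # ratio-based fallback: checkpoint roughly half of the layers on each stage
    half = PipelineGradientCheckpointConfig(gradient_checkpointing_ratio=0.5)
    print(half.get_num_ckpt_layers(stage=0, num_stages=4, num_layers=20))  # expected: 10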

View File

@@ -7,7 +7,7 @@ from torch.distributed import ProcessGroup
from colossalai.pipeline.stage_manager import PipelineStageManager
from .grad_ckpt_config import GradientCheckpointConfig, PipelineGradientCheckpointConfig
from .grad_ckpt_config import GradientCheckpointConfig
__all__ = ["ShardConfig"]
SUPPORT_SP_MODE = ["split_gather", "ring", "all_to_all"]
@@ -105,16 +105,6 @@ class ShardConfig:
else:
self._sequence_parallel_size = dist.get_world_size(self.sequence_parallel_process_group)
if (
self.pipeline_stage_manager is not None
and isinstance(self.gradient_checkpoint_config, PipelineGradientCheckpointConfig)
and self.gradient_checkpoint_config._customize_num_layers_per_stage
):
self.pipeline_stage_manager.set_distribution_config(
self.gradient_checkpoint_config.num_model_layers,
self.gradient_checkpoint_config.num_layers_per_stage,
)
def _turn_on_all_optimization(self):
"""
Turn on all optimization.

View File

@@ -88,16 +88,15 @@ def main():
pass
# ckpt config for LLaMA3-70B on 64 H100 GPUs
ckpt_config = (
PipelineGradientCheckpointConfig(
num_stages=args.pp,
num_model_chunks=1,
num_model_layers=80,
num_layers_per_stage=[19, 20, 20, 21],
num_ckpt_layers_per_stage=[19, 19, 19, 13],
)
hybrid_kwargs = (
{
"gradient_checkpoint_config": PipelineGradientCheckpointConfig(
num_ckpt_layers_per_stage=[19, 19, 19, 13],
),
"num_layers_per_stage": [19, 20, 20, 21],
}
if args.custom_ckpt
else None
else {}
)
# ==============================
@@ -173,7 +172,7 @@ def main():
microbatch_size=args.mbs,
precision="bf16",
dp_outside=False,
gradient_checkpoint_config=ckpt_config,
**hybrid_kwargs,
)
elif args.plugin == "3d_cpu":
plugin = HybridParallelPlugin(
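
Note: a quick arithmetic check on the numbers in that benchmark config (plain Python, not part of the script): the explicit split covers all 80 decoder layers of LLaMA3-70B, and the per-stage checkpoint counts recompute 70 of them, with the last stage checkpointing only 13 of its 21 layers.

    layers_per_stage = [19, 20, 20, 21]
    ckpt_layers_per_stage = [19, 19, 19, 13]
    assert sum(layers_per_stage) == 80  # LLaMA3-70B has 80 decoder layers
    assert all(c <= n for c, n in zip(ckpt_layers_per_stage, layers_per_stage))
    print(sum(ckpt_layers_per_stage), sum(ckpt_layers_per_stage) / sum(layers_per_stage))
    # -> 70 0.875, i.e. 87.5% of the layers are recomputed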

View File

@@ -14,6 +14,7 @@ class _PipelineStageManager(PipelineStageManager):
def __init__(self):
self.is_interleave = False
self.num_layers_per_stage = None
self.num_model_chunks = 1
@property
def num_stages(self):

View File

@@ -14,6 +14,7 @@ class _PipelineStageManager(PipelineStageManager):
def __init__(self):
self.is_interleave = False
self.num_layers_per_stage = None
self.num_model_chunks = 1
@property
def num_stages(self):

View File

@@ -217,9 +217,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
"use_lazy_init": False,
"precision": "fp32",
"enable_gradient_checkpointing": True,
"gradient_checkpoint_config": PipelineGradientCheckpointConfig(
num_stages=2, num_model_chunks=1, num_model_layers=8, num_ckpt_layers_per_stage=[4, 0]
),
"gradient_checkpoint_config": PipelineGradientCheckpointConfig(num_ckpt_layers_per_stage=[4, 0]),
},
{
"tp_size": 4,
@@ -303,9 +301,6 @@ def run_llama_test(test_config):
"initial_scale": 1,
"enable_gradient_checkpointing": True,
"gradient_checkpoint_config": PipelineGradientCheckpointConfig(
num_stages=2,
num_model_chunks=2,
num_model_layers=8,
num_ckpt_layers_per_stage=[0, 1, 2, 2],
),
},
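
Note: the interleaved test case keeps the same flat-list layout. With 2 stages and 2 model chunks (the values that used to be spelled out explicitly and are now taken from the stage manager), entry stage + model_chunk_id * num_stages is consulted. A quick illustration of how [0, 1, 2, 2] is read:

    num_stages, num_model_chunks = 2, 2  # assumed from the removed arguments above
    num_ckpt_layers_per_stage = [0, 1, 2, 2]
    for chunk in range(num_model_chunks):
        for stage in range(num_stages):
            n = num_ckpt_layers_per_stage[stage + chunk * num_stages]
            print(f"stage {stage}, chunk {chunk}: {n} checkpointed layers")
    # stage 0 checkpoints 0 + 2 layers across its two chunks; stage 1 checkpoints 1 + 2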