mirror of https://github.com/hpcaitech/ColossalAI
[shardformer] refactor pipeline grad ckpt config (#5646)
* [shardformer] refactor pipeline grad ckpt config
* [shardformer] refactor pipeline grad ckpt config
* [pipeline] fix stage manager

branch: pull/5654/head
parent 7ef91606e1
commit 1b387ca9fe
@@ -983,6 +983,7 @@ class HybridParallelPlugin(PipelinePluginBase):
        custom_policy: Policy = None,
        pp_style: str = "1f1b",
        num_model_chunks: int = 1,
        num_layers_per_stage: Optional[List[int]] = None,
        gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None,
        enable_metadata_cache: bool = True,
        make_vocab_size_divisible_by: int = 64,
@@ -1056,6 +1057,7 @@ class HybridParallelPlugin(PipelinePluginBase):
            pipeline_axis=self.pp_axis,
            enable_interleave=pp_style == "interleaved",
            num_model_chunks=num_model_chunks,
            num_layers_per_stage=num_layers_per_stage,
        )

        if pp_style == "interleaved":
@@ -27,16 +27,18 @@ class PipelineStageManager:
        pipeline_axis: int,
        enable_interleave: bool = False,
        num_model_chunks: int = 1,
        num_layers_per_stage: Optional[List[int]] = None,
    ) -> None:
        assert enable_interleave or num_model_chunks == 1, "num_model_chunks must be 1 when enable_interleave is False"

        self.num_layers_per_stage = None

        self.pg_mesh = pg_mesh
        self.pipeline_axis = pipeline_axis
        self.prev_rank: Optional[Tuple[int, ...]] = None
        self.next_rank: Optional[Tuple[int, ...]] = None
        self.p2p_groups: Dict[Tuple[int, int], ProcessGroup] = {}
        if num_layers_per_stage is not None:
            assert len(num_layers_per_stage) == self.num_stages
            self.num_layers_per_stage = num_layers_per_stage

        # init prev and next coord
        coord = self.pg_mesh.coordinate()
@@ -56,6 +58,8 @@ class PipelineStageManager:
            self.p2p_groups[tuple(ranks_in_group)] = group

        self.is_interleave = enable_interleave
        # for interleaved pipeline parallel, each device is responsible for multiple chunk of layers
        self.num_model_chunks: int = num_model_chunks
        if enable_interleave:
            # use circle p2p communication
            # add the process group of the first rank and the last rank
@@ -64,59 +68,11 @@ class PipelineStageManager:
            ranks_in_group = self.pg_mesh.get_ranks_in_group(group)
            self.p2p_groups[tuple(ranks_in_group)] = group

        # for interleaved pipeline parallel, each device is responsible for multiple chunk of layers
        self.num_model_chunks: int = num_model_chunks

        # for shardformer, hold stage indices of model
        self.stage_indices: List[Tuple[int, int]]
        # for shardformer, hold model chunk id
        self.model_chunk_id: Optional[int] = None

    @property
    def control_distribute_layers(self) -> bool:
        return self.num_layers_per_stage is not None

    def set_distribution_config(self, num_model_layers: int, num_layers_per_stage: List[int]) -> None:
        """Set the distribution configuration.
        This allows user to customize the number of layers for each stage.

        Args:
            num_model_layers (int): Number of layers in the model.
            num_layers_per_stage (List[int]): Number of layers for each stage.
        """
        assert all([0 < num_layers < num_model_layers for num_layers in num_layers_per_stage])
        assert sum(num_layers_per_stage) == num_model_layers
        assert len(num_layers_per_stage) == self.num_stages * (self.num_model_chunks if self.is_interleave else 1)
        self.num_model_layers = num_model_layers
        self.num_layers_per_stage = num_layers_per_stage

    def distribute_layers(
        self, num_layers: int, num_stages: Optional[int] = None, num_model_chunks: Optional[int] = None
    ) -> List[int]:
        """Divide layers into stages"""
        num_stages = self.num_stages if num_stages is None else num_stages
        num_model_chunks = (
            (self.num_model_chunks if self.is_interleave else 1) if num_model_chunks is None else num_model_chunks
        )

        if self.control_distribute_layers:
            assert num_layers == self.num_model_layers
            return self.num_layers_per_stage

        else:
            quotient = num_layers // (num_stages * num_model_chunks)
            remainder = num_layers % (num_stages * num_model_chunks)

            # calculate the num_layers per stage
            layers_per_stage = [quotient] * num_stages * num_model_chunks

            # deal with the rest layers
            if remainder > 0:
                start_position = (num_stages * num_model_chunks) // 2 - remainder // 2
                for i in range(start_position, start_position + remainder):
                    layers_per_stage[i] += 1
            return layers_per_stage

    def get_stage_index(
        self,
        layers_per_stage: List[int],
@@ -139,9 +95,7 @@ class PipelineStageManager:
        """
        stage = self.stage if stage is None else stage
        num_model_chunks = (
            (self.num_model_chunks if self.is_interleave else 1) if num_model_chunks is None else num_model_chunks
        )
        num_model_chunks = self.num_model_chunks if num_model_chunks is None else num_model_chunks
        num_stages = self.num_stages if num_stages is None else num_stages

        num_layers_per_stage_accumulated = np.insert(np.cumsum(layers_per_stage), 0, 0)
@@ -261,3 +215,25 @@ class PipelineStageManager:
        self.model_chunk_id = model_chunk_id
        yield
        self.model_chunk_id = old_model_chunk_id

    def distribute_layers(
        self, num_layers: int, num_stages: Optional[int] = None, num_model_chunks: Optional[int] = None
    ) -> List[int]:
        if self.num_layers_per_stage is not None:
            assert sum(self.num_layers_per_stage) == num_layers
            return self.num_layers_per_stage

        num_stages = self.num_stages if num_stages is None else num_stages
        num_model_chunks = self.num_model_chunks if num_model_chunks is None else num_model_chunks
        quotient = num_layers // (num_stages * num_model_chunks)
        remainder = num_layers % (num_stages * num_model_chunks)

        # calculate the num_layers per stage
        layers_per_stage = [quotient] * num_stages * num_model_chunks

        # deal with the rest layers
        if remainder > 0:
            start_position = (num_stages * num_model_chunks) // 2 - remainder // 2
            for i in range(start_position, start_position + remainder):
                layers_per_stage[i] += 1
        return layers_per_stage
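Note: the relocated distribute_layers above either honors a user-supplied num_layers_per_stage or splits layers evenly, handing any remainder to the middle stages. A minimal standalone sketch of that logic (the function name and example values are illustrative, not part of this patch):

    from typing import List, Optional

    def distribute_layers_sketch(
        num_layers: int,
        num_stages: int,
        num_model_chunks: int = 1,
        num_layers_per_stage: Optional[List[int]] = None,
    ) -> List[int]:
        # A custom distribution must account for every layer exactly once.
        if num_layers_per_stage is not None:
            assert sum(num_layers_per_stage) == num_layers
            return num_layers_per_stage
        # Otherwise split evenly and give the leftover layers to the middle stages.
        quotient = num_layers // (num_stages * num_model_chunks)
        remainder = num_layers % (num_stages * num_model_chunks)
        layers_per_stage = [quotient] * num_stages * num_model_chunks
        if remainder > 0:
            start = (num_stages * num_model_chunks) // 2 - remainder // 2
            for i in range(start, start + remainder):
                layers_per_stage[i] += 1
        return layers_per_stage

    print(distribute_layers_sketch(80, 4))  # [20, 20, 20, 20]
    print(distribute_layers_sketch(82, 4))  # [20, 21, 21, 20]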
@@ -168,8 +168,10 @@ class LlamaPipelineForwards:
            if shard_config.gradient_checkpoint_config is not None:
                num_ckpt_layers = shard_config.gradient_checkpoint_config.get_num_ckpt_layers(
                    stage=stage_manager.stage,
                    num_stages=stage_manager.num_stages,
                    num_layers=end_idx - start_idx,
                    model_chunk_id=(stage_manager.model_chunk_id if stage_manager.is_interleave else 0),
                    num_model_chunks=stage_manager.num_model_chunks,
                )
            assert num_ckpt_layers <= end_idx - start_idx
@@ -129,8 +129,10 @@ class MistralForwards:
            if shard_config.gradient_checkpoint_config is not None:
                num_ckpt_layers = shard_config.gradient_checkpoint_config.get_num_ckpt_layers(
                    stage=stage_manager.stage,
                    num_stages=stage_manager.num_stages,
                    num_layers=end_idx - start_idx,
                    model_chunk_id=(stage_manager.model_chunk_id if stage_manager.is_interleave else 0),
                    num_model_chunks=stage_manager.num_model_chunks,
                )
            assert num_ckpt_layers <= end_idx - start_idx
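Note: in both model forwards the returned num_ckpt_layers decides how many of the stage's decoder layers run under activation checkpointing; the layer loop that consumes it sits outside these hunks. A toy sketch of that gating pattern with stand-in modules (an assumption about the surrounding code, not shown in the diff):

    import torch
    import torch.nn as nn
    import torch.utils.checkpoint as checkpoint

    # Stand-ins for the slice of decoder layers owned by one pipeline stage.
    layers = nn.ModuleList([nn.Linear(16, 16) for _ in range(4)])
    hidden_states = torch.randn(2, 16, requires_grad=True)

    num_ckpt_layers = 2  # e.g. the value get_num_ckpt_layers returned for this stage

    for idx, layer in enumerate(layers):
        if idx < num_ckpt_layers:
            # Recompute this layer's activations during backward instead of storing them.
            hidden_states = checkpoint.checkpoint(layer, hidden_states, use_reentrant=False)
        else:
            hidden_states = layer(hidden_states)

    hidden_states.sum().backward()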
@@ -28,6 +28,7 @@ class SubModuleReplacementDescription:
        kwargs (Dict[str, Any]): the dictionary used to pass extra arguments to the `ParallelModule.from_native_module` method.
        ignore_if_not_exist (bool): if the submodule does not exist, ignore it or raise an exception
    """

    suffix: str
    target_module: Union[ParallelModule, BaseLayerNorm]
    kwargs: Dict[str, Any] = None
@@ -54,6 +55,7 @@ class ModulePolicyDescription:
            object which specifies the module to be replaced and the target module used to replacement.
        method_replace (Dict[str, Callable]): key is the method name, value is the method for replacement
    """

    attribute_replacement: Dict[str, Any] = None
    param_replacement: List[Callable] = None
    sub_module_replacement: List[SubModuleReplacementDescription] = None
@@ -47,46 +47,33 @@ class PipelineGradientCheckpointConfig(GradientCheckpointConfig):
        ...

    """
    num_stages: Optional[int] = None
    num_model_chunks: Optional[int] = None
    num_model_layers: Optional[int] = None
    num_layers_per_stage: Optional[List[int]] = None
    num_ckpt_layers_per_stage: Optional[List[int]] = None

    def __post_init__(self):
        if self._enable_gradient_checkpointing_ratio:
        if self._enable_customized_ckpt_layers_per_stage:
            assert all([num_ckpt_layers >= 0 for num_ckpt_layers in self.num_ckpt_layers_per_stage])
        elif self._enable_gradient_checkpointing_ratio:
            if not (0 <= self.gradient_checkpointing_ratio <= 1):
                raise ValueError("gradient_checkpointing_ratio should be in 0% to 100%")

        if self._enable_customized_ckpt_layers_per_stage:
            assert (
                self.num_stages is not None and self.num_model_chunks is not None and self.num_model_layers is not None
            )
            assert len(self.num_ckpt_layers_per_stage) == self.num_stages * self.num_model_chunks
            assert all(
                [0 <= num_ckpt_layers < self.num_model_layers for num_ckpt_layers in self.num_ckpt_layers_per_stage]
            )
            self.gradient_checkpointing_ratio = sum(self.num_ckpt_layers_per_stage) / self.num_model_layers

    @property
    def _enable_gradient_checkpointing_ratio(self) -> bool:
        return self.gradient_checkpointing_ratio is not None

    @property
    def _customize_num_layers_per_stage(self) -> bool:
        return self.num_layers_per_stage is not None and self.num_model_layers is not None

    @property
    def _enable_customized_ckpt_layers_per_stage(self) -> bool:
        return self.num_ckpt_layers_per_stage is not None

    def get_num_ckpt_layers(self, stage: int, num_layers: int, model_chunk_id: int = 0) -> int:
    def get_num_ckpt_layers(
        self, stage: int, num_stages: int, num_layers: int, model_chunk_id: int = 0, num_model_chunks: int = 1
    ) -> int:
        if not self._enable_gradient_checkpointing_ratio and not self._enable_customized_ckpt_layers_per_stage:
            raise RuntimeError("No checkpointed layers information is provided")

        if self._enable_customized_ckpt_layers_per_stage:
            assert stage <= self.num_stages and model_chunk_id <= self.num_model_chunks
            num_ckpt_layers = self.num_ckpt_layers_per_stage[stage + model_chunk_id * self.num_stages]
            assert len(self.num_ckpt_layers_per_stage) == num_stages * num_model_chunks
            assert stage <= num_stages and model_chunk_id <= num_model_chunks
            num_ckpt_layers = self.num_ckpt_layers_per_stage[stage + model_chunk_id * num_stages]
            assert num_ckpt_layers <= num_layers
            return num_ckpt_layers
        else:
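Note: after this change get_num_ckpt_layers no longer reads stage and chunk counts from the config itself; the caller passes them in, and the per-stage list is indexed as stage + model_chunk_id * num_stages. A standalone sketch of that lookup (the helper name and example values are illustrative only):

    from typing import List

    def lookup_num_ckpt_layers(
        num_ckpt_layers_per_stage: List[int],
        stage: int,
        num_stages: int,
        num_layers: int,
        model_chunk_id: int = 0,
        num_model_chunks: int = 1,
    ) -> int:
        # One entry per (stage, model chunk) pair.
        assert len(num_ckpt_layers_per_stage) == num_stages * num_model_chunks
        num_ckpt_layers = num_ckpt_layers_per_stage[stage + model_chunk_id * num_stages]
        # A stage cannot checkpoint more layers than it owns.
        assert num_ckpt_layers <= num_layers
        return num_ckpt_layers

    # With the benchmark's [19, 19, 19, 13] over 4 stages and 1 chunk,
    # stage 3 (which owns 21 layers) checkpoints 13 of them.
    print(lookup_num_ckpt_layers([19, 19, 19, 13], stage=3, num_stages=4, num_layers=21))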
@@ -7,7 +7,7 @@ from torch.distributed import ProcessGroup
from colossalai.pipeline.stage_manager import PipelineStageManager

from .grad_ckpt_config import GradientCheckpointConfig, PipelineGradientCheckpointConfig
from .grad_ckpt_config import GradientCheckpointConfig

__all__ = ["ShardConfig"]
SUPPORT_SP_MODE = ["split_gather", "ring", "all_to_all"]
@@ -105,16 +105,6 @@ class ShardConfig:
        else:
            self._sequence_parallel_size = dist.get_world_size(self.sequence_parallel_process_group)

        if (
            self.pipeline_stage_manager is not None
            and isinstance(self.gradient_checkpoint_config, PipelineGradientCheckpointConfig)
            and self.gradient_checkpoint_config._customize_num_layers_per_stage
        ):
            self.pipeline_stage_manager.set_distribution_config(
                self.gradient_checkpoint_config.num_model_layers,
                self.gradient_checkpoint_config.num_layers_per_stage,
            )

    def _turn_on_all_optimization(self):
        """
        Turn on all optimization.
@@ -88,16 +88,15 @@ def main():
        pass

    # ckpt config for LLaMA3-70B on 64 H100 GPUs
    ckpt_config = (
        PipelineGradientCheckpointConfig(
            num_stages=args.pp,
            num_model_chunks=1,
            num_model_layers=80,
            num_layers_per_stage=[19, 20, 20, 21],
            num_ckpt_layers_per_stage=[19, 19, 19, 13],
        )
    hybrid_kwargs = (
        {
            "gradient_checkpoint_config": PipelineGradientCheckpointConfig(
                num_ckpt_layers_per_stage=[19, 19, 19, 13],
            ),
            "num_layers_per_stage": [19, 20, 20, 21],
        }
        if args.custom_ckpt
        else None
        else {}
    )

    # ==============================
@@ -173,7 +172,7 @@ def main():
            microbatch_size=args.mbs,
            precision="bf16",
            dp_outside=False,
            gradient_checkpoint_config=ckpt_config,
            **hybrid_kwargs,
        )
    elif args.plugin == "3d_cpu":
        plugin = HybridParallelPlugin(
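Note: taken together, the two benchmark hunks show the new calling convention: the layer split moves to the plugin's num_layers_per_stage argument, and the config only carries per-stage checkpoint counts. A sketch of what such a call might look like after this patch (import paths and the unrelated plugin arguments are assumptions from memory of the repository, not part of the diff; the list values are the benchmark's):

    from colossalai.booster.plugin import HybridParallelPlugin
    from colossalai.shardformer import PipelineGradientCheckpointConfig

    # 80 decoder layers over 4 pipeline stages: the stages own 19/20/20/21 layers
    # and checkpoint 19/19/19/13 of them, respectively.
    plugin = HybridParallelPlugin(
        tp_size=4,
        pp_size=4,
        num_layers_per_stage=[19, 20, 20, 21],
        gradient_checkpoint_config=PipelineGradientCheckpointConfig(
            num_ckpt_layers_per_stage=[19, 19, 19, 13],
        ),
        precision="bf16",
        microbatch_size=1,
    )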
@@ -14,6 +14,7 @@ class _PipelineStageManager(PipelineStageManager):
    def __init__(self):
        self.is_interleave = False
        self.num_layers_per_stage = None
        self.num_model_chunks = 1

    @property
    def num_stages(self):
@@ -14,6 +14,7 @@ class _PipelineStageManager(PipelineStageManager):
    def __init__(self):
        self.is_interleave = False
        self.num_layers_per_stage = None
        self.num_model_chunks = 1

    @property
    def num_stages(self):
@@ -217,9 +217,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
        "use_lazy_init": False,
        "precision": "fp32",
        "enable_gradient_checkpointing": True,
        "gradient_checkpoint_config": PipelineGradientCheckpointConfig(
            num_stages=2, num_model_chunks=1, num_model_layers=8, num_ckpt_layers_per_stage=[4, 0]
        ),
        "gradient_checkpoint_config": PipelineGradientCheckpointConfig(num_ckpt_layers_per_stage=[4, 0]),
    },
    {
        "tp_size": 4,
@@ -303,9 +301,6 @@ def run_llama_test(test_config):
        "initial_scale": 1,
        "enable_gradient_checkpointing": True,
        "gradient_checkpoint_config": PipelineGradientCheckpointConfig(
            num_stages=2,
            num_model_chunks=2,
            num_model_layers=8,
            num_ckpt_layers_per_stage=[0, 1, 2, 2],
        ),
    },