|
|
|
@ -29,6 +29,7 @@ from colossalai.logging import get_dist_logger
|
|
|
|
|
from colossalai.nn.optimizer import cast_to_distributed
|
|
|
|
|
from colossalai.pipeline.schedule.interleaved_pp import InterleavedSchedule
|
|
|
|
|
from colossalai.pipeline.schedule.one_f_one_b import OneForwardOneBackwardSchedule
|
|
|
|
|
from colossalai.pipeline.schedule.zero_bubble_pp import ZeroBubbleVPipeScheduler
|
|
|
|
|
from colossalai.pipeline.stage_manager import PipelineStageManager
|
|
|
|
|
from colossalai.shardformer.policies.base_policy import Policy
|
|
|
|
|
from colossalai.shardformer.shard.grad_ckpt_config import GradientCheckpointConfig
|
|
|
|
@ -207,6 +208,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
|
|
|
|
|
custom_policy: Policy = None,
|
|
|
|
|
pp_style: str = "1f1b",
|
|
|
|
|
num_model_chunks: int = 1,
|
|
|
|
|
scheduler_nodes: List = None,
|
|
|
|
|
num_layers_per_stage: Optional[List[int]] = None,
|
|
|
|
|
gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None,
|
|
|
|
|
enable_metadata_cache: bool = True,
|
|
|
|
@ -282,8 +284,10 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
|
|
|
|
|
self.custom_policy = custom_policy
|
|
|
|
|
assert zero_stage in (0, 1, 2)
|
|
|
|
|
if self.pp_size > 1:
|
|
|
|
|
assert pp_style in ["1f1b", "interleaved"], "Unsupported pipeline parallelism style"
|
|
|
|
|
assert pp_style == "interleaved" or num_model_chunks == 1, "num_model_chunks must be 1 when using 1f1b"
|
|
|
|
|
assert pp_style in ["1f1b", "interleaved", "zbv"], "Unsupported pipeline parallelism style"
|
|
|
|
|
assert (
|
|
|
|
|
pp_style == "interleaved" or pp_style == "zbv"
|
|
|
|
|
) or num_model_chunks == 1, "num_model_chunks must be 1 when using 1f1b"
|
|
|
|
|
assert (
|
|
|
|
|
num_microbatches is not None or microbatch_size is not None
|
|
|
|
|
), "num_microbatches or microbatch_size must be specified when using pipeline parallelism"
|
|
|
|
@ -293,7 +297,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
|
|
|
|
|
self.stage_manager = PipelineStageManager(
|
|
|
|
|
self.pg_mesh,
|
|
|
|
|
pipeline_axis=self.pp_axis,
|
|
|
|
|
enable_interleave=pp_style == "interleaved",
|
|
|
|
|
enable_interleave=(pp_style == "interleaved" or pp_style == "zbv"),
|
|
|
|
|
num_model_chunks=num_model_chunks,
|
|
|
|
|
num_layers_per_stage=num_layers_per_stage,
|
|
|
|
|
)
|
|
|
|
@ -315,6 +319,14 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
|
|
|
|
|
microbatch_size=microbatch_size,
|
|
|
|
|
enable_metadata_cache=enable_metadata_cache,
|
|
|
|
|
)
|
|
|
|
|
elif pp_style == "zbv":
|
|
|
|
|
self.schedule = ZeroBubbleVPipeScheduler(
|
|
|
|
|
schedule=scheduler_nodes,
|
|
|
|
|
stage_manager=self.stage_manager,
|
|
|
|
|
num_model_chunks=num_model_chunks,
|
|
|
|
|
num_microbatch=num_microbatches,
|
|
|
|
|
overlap_p2p=overlap_p2p,
|
|
|
|
|
)
|
|
|
|
|
else:
|
|
|
|
|
raise NotImplementedError()
|
|
|
|
|
|
|
|
|
|