From 9bc3b6e2202b2b63a76b1967ddfd702f77bbbf1c Mon Sep 17 00:00:00 2001
From: duanjunwen <935724073@qq.com>
Date: Thu, 12 Sep 2024 02:51:46 +0000
Subject: [PATCH] [feat] MoeHybridParallelPlugin supports the zero-bubble (zbv) pipeline schedule

---
 .../plugin/moe_hybrid_parallel_plugin.py      | 18 ++++-
 .../test_schedule/test_zerobubble_pp.py       | 70 +++++++++++++++++--
 2 files changed, 81 insertions(+), 7 deletions(-)

diff --git a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
index 36973b240..56405ed47 100644
--- a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
@@ -29,6 +29,7 @@ from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer import cast_to_distributed
 from colossalai.pipeline.schedule.interleaved_pp import InterleavedSchedule
 from colossalai.pipeline.schedule.one_f_one_b import OneForwardOneBackwardSchedule
+from colossalai.pipeline.schedule.zero_bubble_pp import ZeroBubbleVPipeScheduler
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer.policies.base_policy import Policy
 from colossalai.shardformer.shard.grad_ckpt_config import GradientCheckpointConfig
@@ -207,6 +208,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
         custom_policy: Policy = None,
         pp_style: str = "1f1b",
         num_model_chunks: int = 1,
+        scheduler_nodes: List = None,
         num_layers_per_stage: Optional[List[int]] = None,
         gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None,
         enable_metadata_cache: bool = True,
@@ -282,8 +284,10 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
         self.custom_policy = custom_policy
         assert zero_stage in (0, 1, 2)
         if self.pp_size > 1:
-            assert pp_style in ["1f1b", "interleaved"], "Unsupported pipeline parallelism style"
-            assert pp_style == "interleaved" or num_model_chunks == 1, "num_model_chunks must be 1 when using 1f1b"
+            assert pp_style in ["1f1b", "interleaved", "zbv"], "Unsupported pipeline parallelism style"
+            assert (
+                pp_style == "interleaved" or pp_style == "zbv"
+            ) or num_model_chunks == 1, "num_model_chunks must be 1 when using 1f1b"
             assert (
                 num_microbatches is not None or microbatch_size is not None
             ), "num_microbatches or microbatch_size must be specified when using pipeline parallelism"
@@ -293,7 +297,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
             self.stage_manager = PipelineStageManager(
                 self.pg_mesh,
                 pipeline_axis=self.pp_axis,
-                enable_interleave=pp_style == "interleaved",
+                enable_interleave=(pp_style == "interleaved" or pp_style == "zbv"),
                 num_model_chunks=num_model_chunks,
                 num_layers_per_stage=num_layers_per_stage,
             )
@@ -315,6 +319,14 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
                     microbatch_size=microbatch_size,
                     enable_metadata_cache=enable_metadata_cache,
                 )
+            elif pp_style == "zbv":
+                self.schedule = ZeroBubbleVPipeScheduler(
+                    schedule=scheduler_nodes,
+                    stage_manager=self.stage_manager,
+                    num_model_chunks=num_model_chunks,
+                    num_microbatch=num_microbatches,
+                    overlap_p2p=overlap_p2p,
+                )
             else:
                 raise NotImplementedError()
 
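Usage sketch (not part of the patch): with the plugin changes above, the zero-bubble schedule is selected by passing pp_style="zbv" together with a v-schedule produced by PipelineGraph.get_v_schedule(). The sketch below mirrors the test configuration further down; the model dimensions, seq_len, and ep_size=1 are illustrative assumptions, and the memory coefficients follow the commented-out test code.

    from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
    from colossalai.pipeline.schedule.v_schedule import PipelineGraph

    pp_size, num_microbatches = 4, 4
    hidden_size, num_attention_heads, seq_len = 768, 12, 1024        # assumed model dimensions
    mem_f = 34 * hidden_size + 5 * num_attention_heads * seq_len     # forward activation estimate
    mem_w = -32 * hidden_size                                        # weight-grad pass frees memory
    mem_b = -mem_w - mem_f                                           # backward releases the rest

    # Build the zero-bubble "V" schedule for the given stage/microbatch layout.
    graph = PipelineGraph(
        n_stage=pp_size,
        n_micro=num_microbatches,
        f_cost=1,
        b_cost=1,
        w_cost=1,
        c_cost=1,
        f_mem=mem_f,
        b_mem=mem_b,
        w_mem=mem_w,
    )
    scheduler_nodes = graph.get_v_schedule()

    # Hand the precomputed schedule to the plugin via the new scheduler_nodes argument.
    plugin = MoeHybridParallelPlugin(
        tp_size=1,
        pp_size=pp_size,
        ep_size=1,                 # assumed; not exercised by the test config below
        pp_style="zbv",
        num_model_chunks=2,        # zbv places two model chunks (a "V") on each stage
        scheduler_nodes=scheduler_nodes,
        num_microbatches=num_microbatches,
        zero_stage=1,
        precision="bf16",
        overlap_p2p=True,
    )

From here the plugin is driven through the usual Booster flow (Booster(plugin=plugin), booster.boost(...), booster.execute_pipeline(...)), as the test changes below start to exercise.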
diff --git a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py
index 825c192d8..1e5cdb3e5 100644
--- a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py
+++ b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py
@@ -14,6 +14,7 @@ from colossalai.logging import disable_existing_loggers
 from colossalai.pipeline.schedule.v_schedule import PipelineGraph, ScheduledNode
 from colossalai.pipeline.schedule.zero_bubble_pp import ZeroBubbleVPipeScheduler
 from colossalai.pipeline.stage_manager import PipelineStageManager
+from colossalai.tensor.d_tensor.api import clear_layout_converter
 from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
 from tests.kit.model_zoo import model_zoo
 
@@ -724,23 +725,83 @@ def run_with_hybridplugin(test_config):
     "test_config",
     [
         {
-            "batch_size": 8,
+            "pp_style": "zbv",
             "tp_size": 1,
             "pp_size": 4,
             "num_microbatches": 4,
             "zero_stage": 1,
             "precision": "bf16",
-            "num_model_chunk": 2,
+            "num_model_chunks": 2,
         },
     ],
 )
 def run_with_moehybridplugin(test_config):
-    model_zoo.get_sub_registry("transformers_bert")
+    sub_model_zoo = model_zoo.get_sub_registry("transformers_bert")
     test_config["use_lazy_init"] = False
     test_config["initial_scale"] = 2**16
     model_list = [
         "transformers_bert",
     ]
+    clear_layout_converter()
+
+    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
+        if name in model_list:
+            # base param
+            model = model_fn()
+            data = data_gen_fn()
+            criterion = loss_fn
+            optimizer = torch.optim.SGD(model.parameters(), momentum=0.1, lr=1e-5)
+
+            output = model(**data)
+            loss = criterion(output)
+            loss.backward()
+            optimizer.step()
+            print(f"output {output}")
+
+            # # pp param
+            # model_pp = deepcopy(model)
+            # data_pp = deepcopy(data)
+            # optimizer_pp = OptimizerWrapper(torch.optim.SGD(model_pp.parameters(), momentum=0.1, lr=1e-5))
+
+            # # init pipeline graph
+            # h, a, s = model.config.hidden_size, model.config.num_attention_heads, 1024
+            # mem_f = 34 * h + 5 * a * s
+            # mem_w = -32 * h
+            # mem_b = -mem_w - mem_f
+            # graph = PipelineGraph(
+            #     n_stage=test_config["pp_size"],
+            #     n_micro=test_config["num_microbatches"],
+            #     f_cost=1,
+            #     b_cost=1,
+            #     w_cost=1,
+            #     c_cost=1,
+            #     f_mem=mem_f,
+            #     b_mem=mem_b,
+            #     w_mem=mem_w,
+            #     # max_mem=mem_f * (p * 2 + m_offset),
+            # )
+
+            # zbv_schedule = graph.get_v_schedule()
+
+            # test_config["scheduler_nodes"] = zbv_schedule
+            # plugin = MoeHybridParallelPlugin(
+            #     **test_config
+            # )
+            # model_pp, optimizer_pp, criterion, data_pp = plugin.configure(
+            #     model = model_pp,
+            #     optimizer = optimizer_pp,
+            #     criterion = criterion,
+            #     dataloader = data_pp,
+            # )
+
+            # output_pp = plugin.execute_pipeline(
+            #     data_iter=iter(data),
+            #     model=model,
+            #     criterion=criterion,
+            #     optimizer=optimizer,
+            #     return_loss = True,
+            #     return_outputs = True,
+            # )
 
 
 # TODO:6) support booster & Hybrid base 4)
@@ -752,8 +813,9 @@ def run_dist(rank, world_size, port):
     disable_existing_loggers()
     colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
     # run_fwd_bwd_iter_input()
-    run_fwd_bwd_vschedule_with_optim()
+    # run_fwd_bwd_vschedule_with_optim()
     # run_with_moehybridplugin()
+    run_with_moehybridplugin()
 
 
 @pytest.mark.dist
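For when the TODO above is picked up, here is a hedged sketch of the pipeline-parallel branch that the commented-out block in run_with_moehybridplugin is working toward. It assumes MoeHybridParallelPlugin keeps HybridParallelPlugin's configure()/execute_pipeline() interface (configure() returns a five-element tuple ending with an lr_scheduler slot, which the commented code does not unpack); run_pp_reference, its arguments, and the single-batch data_iter are illustrative, not part of the patch.

    from copy import deepcopy

    import torch

    from colossalai.interface import OptimizerWrapper


    def run_pp_reference(plugin, model, data, criterion):
        # Sketch only: drive one forward/backward/step through the zbv pipeline
        # on copies of the baseline model and batch.
        model_pp = deepcopy(model)
        data_pp = deepcopy(data)
        optimizer_pp = OptimizerWrapper(torch.optim.SGD(model_pp.parameters(), momentum=0.1, lr=1e-5))

        # configure() returns (model, optimizer, criterion, dataloader, lr_scheduler).
        model_pp, optimizer_pp, criterion, data_pp, _ = plugin.configure(
            model=model_pp,
            optimizer=optimizer_pp,
            criterion=criterion,
            dataloader=data_pp,
        )

        # Feed a one-batch iterator; the schedule splits it into microbatches.
        output_pp = plugin.execute_pipeline(
            data_iter=iter([data_pp]),
            model=model_pp,
            criterion=criterion,
            optimizer=optimizer_pp,
            return_loss=True,
            return_outputs=True,
        )
        optimizer_pp.step()
        return output_pp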