From 9bc3b6e2202b2b63a76b1967ddfd702f77bbbf1c Mon Sep 17 00:00:00 2001
From: duanjunwen <935724073@qq.com>
Date: Thu, 12 Sep 2024 02:51:46 +0000
Subject: [PATCH] [feat] moehybrid support zerobubble;

---
 .../plugin/moe_hybrid_parallel_plugin.py      | 18 ++++-
 .../test_schedule/test_zerobubble_pp.py       | 70 +++++++++++++++++--
 2 files changed, 81 insertions(+), 7 deletions(-)

diff --git a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
index 36973b240..56405ed47 100644
--- a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
@@ -29,6 +29,7 @@ from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer import cast_to_distributed
 from colossalai.pipeline.schedule.interleaved_pp import InterleavedSchedule
 from colossalai.pipeline.schedule.one_f_one_b import OneForwardOneBackwardSchedule
+from colossalai.pipeline.schedule.zero_bubble_pp import ZeroBubbleVPipeScheduler
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer.policies.base_policy import Policy
 from colossalai.shardformer.shard.grad_ckpt_config import GradientCheckpointConfig
@@ -207,6 +208,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
         custom_policy: Policy = None,
         pp_style: str = "1f1b",
         num_model_chunks: int = 1,
+        scheduler_nodes: List = None,
         num_layers_per_stage: Optional[List[int]] = None,
         gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None,
         enable_metadata_cache: bool = True,
@@ -282,8 +284,10 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
         self.custom_policy = custom_policy
         assert zero_stage in (0, 1, 2)
         if self.pp_size > 1:
-            assert pp_style in ["1f1b", "interleaved"], "Unsupported pipeline parallelism style"
-            assert pp_style == "interleaved" or num_model_chunks == 1, "num_model_chunks must be 1 when using 1f1b"
+            assert pp_style in ["1f1b", "interleaved", "zbv"], "Unsupported pipeline parallelism style"
+            assert (
+                pp_style == "interleaved" or pp_style == "zbv"
+            ) or num_model_chunks == 1, "num_model_chunks must be 1 when using 1f1b"
             assert (
                 num_microbatches is not None or microbatch_size is not None
             ), "num_microbatches or microbatch_size must be specified when using pipeline parallelism"
@@ -293,7 +297,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
             self.stage_manager = PipelineStageManager(
                 self.pg_mesh,
                 pipeline_axis=self.pp_axis,
-                enable_interleave=pp_style == "interleaved",
+                enable_interleave=(pp_style == "interleaved" or pp_style == "zbv"),
                 num_model_chunks=num_model_chunks,
                 num_layers_per_stage=num_layers_per_stage,
             )
@@ -315,6 +319,14 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
                     microbatch_size=microbatch_size,
                     enable_metadata_cache=enable_metadata_cache,
                 )
+            elif pp_style == "zbv":
+                self.schedule = ZeroBubbleVPipeScheduler(
+                    schedule=scheduler_nodes,
+                    stage_manager=self.stage_manager,
+                    num_model_chunks=num_model_chunks,
+                    num_microbatch=num_microbatches,
+                    overlap_p2p=overlap_p2p,
+                )
             else:
                 raise NotImplementedError()
 
diff --git a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py
index 825c192d8..1e5cdb3e5 100644
--- a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py
+++ b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py
@@ -14,6 +14,7 @@ from colossalai.logging import disable_existing_loggers
 from colossalai.pipeline.schedule.v_schedule import PipelineGraph, ScheduledNode
 from colossalai.pipeline.schedule.zero_bubble_pp import ZeroBubbleVPipeScheduler
 from colossalai.pipeline.stage_manager import PipelineStageManager
+from colossalai.tensor.d_tensor.api import clear_layout_converter
 from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
 from tests.kit.model_zoo import model_zoo
 
@@ -724,23 +725,83 @@ def run_with_hybridplugin(test_config):
     "test_config",
     [
         {
-            "batch_size": 8,
+            "pp_style": "zbv",
             "tp_size": 1,
             "pp_size": 4,
             "num_microbatches": 4,
             "zero_stage": 1,
             "precision": "bf16",
-            "num_model_chunk": 2,
+            "num_model_chunks": 2,
         },
     ],
 )
 def run_with_moehybridplugin(test_config):
-    model_zoo.get_sub_registry("transformers_bert")
+    sub_model_zoo = model_zoo.get_sub_registry("transformers_bert")
     test_config["use_lazy_init"] = False
     test_config["initial_scale"] = 2**16
     model_list = [
         "transformers_bert",
     ]
+    clear_layout_converter()
+
+    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
+        if name in model_list:
+            # base param
+            model = model_fn()
+            data = data_gen_fn()
+            criterion = loss_fn
+            optimizer = torch.optim.SGD(model.parameters(), momentum=0.1, lr=1e-5)
+
+            output = model(**data)
+            loss = criterion(output)
+            loss.backward()
+            optimizer.step()
+            print(f"output {output}")
+
+            # # pp param
+            # model_pp = deepcopy(model)
+            # data_pp = deepcopy(data)
+            # optimizer_pp = OptimizerWrapper(torch.optim.SGD(model_pp.parameters(), momentum=0.1, lr=1e-5))
+
+            # # init pipeline graph
+            # h, a, s = model.config.hidden_size, model.config.num_attention_heads, 1024
+            # mem_f = 34 * h + 5 * a * s
+            # mem_w = -32 * h
+            # mem_b = -mem_w - mem_f
+            # graph = PipelineGraph(
+            #     n_stage=test_config["pp_size"],
+            #     n_micro=test_config["num_microbatches"],
+            #     f_cost=1,
+            #     b_cost=1,
+            #     w_cost=1,
+            #     c_cost=1,
+            #     f_mem=mem_f,
+            #     b_mem=mem_b,
+            #     w_mem=mem_w,
+            #     # max_mem=mem_f * (p * 2 + m_offset),
+            # )
+
+            # zbv_schedule = graph.get_v_schedule()
+
+            # test_config["scheduler_nodes"] = zbv_schedule
+            # plugin = MoeHybridParallelPlugin(
+            #     **test_config
+            # )
+            # model_pp, optimizer_pp, criterion, data_pp = plugin.configure(
+            #     model = model_pp,
+            #     optimizer = optimizer_pp,
+            #     criterion = criterion,
+            #     dataloader = data_pp,
+            # )
+
+            # output_pp = plugin.execute_pipeline(
+            #     data_iter=iter(data),
+            #     model=model,
+            #     criterion=criterion,
+            #     optimizer=optimizer,
+            #     return_loss = True,
+            #     return_outputs = True,
+            # )
 
 
 # TODO:6) support booster & Hybrid base 4)
@@ -752,8 +813,9 @@ def run_dist(rank, world_size, port):
     disable_existing_loggers()
     colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
     # run_fwd_bwd_iter_input()
-    run_fwd_bwd_vschedule_with_optim()
+    # run_fwd_bwd_vschedule_with_optim()
     # run_with_moehybridplugin()
+    run_with_moehybridplugin()
 
 
 @pytest.mark.dist