From 1bc4dba3a3a8911f05eea8c8eb68cf5807ca75c8 Mon Sep 17 00:00:00 2001
From: duanjunwen <935724073@qq.com>
Date: Thu, 14 Nov 2024 09:40:38 +0000
Subject: [PATCH] [fix] fix p2p error in zbv

---
 colossalai/pipeline/schedule/zero_bubble_pp.py          | 8 +++-----
 tests/test_pipeline/test_schedule/test_zerobubble_pp.py | 5 +----
 2 files changed, 4 insertions(+), 9 deletions(-)

diff --git a/colossalai/pipeline/schedule/zero_bubble_pp.py b/colossalai/pipeline/schedule/zero_bubble_pp.py
index f678d7d7f..31e6cfb38 100644
--- a/colossalai/pipeline/schedule/zero_bubble_pp.py
+++ b/colossalai/pipeline/schedule/zero_bubble_pp.py
@@ -45,10 +45,11 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         num_model_chunks: int,
         num_microbatch: Optional[int] = None,
         microbatch_size: Optional[int] = None,
-        enable_metadata_cache: bool = False,
-        overlap_p2p: bool = True,
+        enable_metadata_cache: bool = True,
+        overlap_p2p: bool = False,
     ):
         super().__init__(stage_manager)
+        # Not support overlap_p2p so far
         # batch info
         self.num_microbatch = num_microbatch
         self.microbatch_size = microbatch_size
@@ -906,9 +907,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
                     model_chunk_id=scheduled_node.chunk,
                     optimizer=optimizer,
                 )
-        for h in self.wait_handles:
-            for hh in h:
-                hh.wait()
         # print(f"stage {self.stage_manager.stage}; self.tensor_metadata_recv[0] {self.tensor_metadata_recv[0]}; self.tensor_metadata_recv[1] {self.tensor_metadata_recv[1]}; self.grad_metadata_recv[0] {self.grad_metadata_recv[0]}; self.grad_metadata_recv[1] {self.grad_metadata_recv[1]}")
         # return loss & output
         if outputs is not None:
diff --git a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py
index ddb70e5f2..b630d30b1 100644
--- a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py
+++ b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py
@@ -770,13 +770,11 @@ def run_fwd_bwd_vschedule_with_optim(test_config):
 @parameterize(
     "config",
     [
-        # Pass
         (1, 2, 1, 1, 2),
         (1, 1, 2, 2, 1),
         (1, 2, 1, 2, 1),
         (1, 2, 2, 1, 1),
-        # # TODO: adapt mixtral with no TP Linear
-        (0, 1, 4, 1, 1),
+        (1, 1, 4, 1, 1),
     ],
 )
 def run_with_booster_moehybridplugin(config: Tuple[int, ...]):
@@ -938,7 +936,6 @@ def run_with_booster_moehybridplugin(config: Tuple[int, ...]):
         (1, 2, 2, 1),
         (1, 2, 1, 2),
         (1, 1, 2, 2),
-        # TODO: support overlap p2p in pp4
         (1, 4, 1, 1),
     ],
 )