From 1bc4dba3a3a8911f05eea8c8eb68cf5807ca75c8 Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Thu, 14 Nov 2024 09:40:38 +0000 Subject: [PATCH] [fix] fix p2p error in zbv --- colossalai/pipeline/schedule/zero_bubble_pp.py | 8 +++----- tests/test_pipeline/test_schedule/test_zerobubble_pp.py | 5 +---- 2 files changed, 4 insertions(+), 9 deletions(-) diff --git a/colossalai/pipeline/schedule/zero_bubble_pp.py b/colossalai/pipeline/schedule/zero_bubble_pp.py index f678d7d7f..31e6cfb38 100644 --- a/colossalai/pipeline/schedule/zero_bubble_pp.py +++ b/colossalai/pipeline/schedule/zero_bubble_pp.py @@ -45,10 +45,11 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): num_model_chunks: int, num_microbatch: Optional[int] = None, microbatch_size: Optional[int] = None, - enable_metadata_cache: bool = False, - overlap_p2p: bool = True, + enable_metadata_cache: bool = True, + overlap_p2p: bool = False, ): super().__init__(stage_manager) + # Not support overlap_p2p so far # batch info self.num_microbatch = num_microbatch self.microbatch_size = microbatch_size @@ -906,9 +907,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): model_chunk_id=scheduled_node.chunk, optimizer=optimizer, ) - for h in self.wait_handles: - for hh in h: - hh.wait() # print(f"stage {self.stage_manager.stage}; self.tensor_metadata_recv[0] {self.tensor_metadata_recv[0]}; self.tensor_metadata_recv[1] {self.tensor_metadata_recv[1]}; self.grad_metadata_recv[0] {self.grad_metadata_recv[0]}; self.grad_metadata_recv[1] {self.grad_metadata_recv[1]}") # return loss & output if outputs is not None: diff --git a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py index ddb70e5f2..b630d30b1 100644 --- a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py +++ b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py @@ -770,13 +770,11 @@ def run_fwd_bwd_vschedule_with_optim(test_config): @parameterize( "config", [ - # Pass (1, 2, 1, 1, 2), (1, 1, 2, 2, 1), (1, 2, 1, 2, 1), (1, 2, 2, 1, 1), - # # TODO: adapt mixtral with no TP Linear - (0, 1, 4, 1, 1), + (1, 1, 4, 1, 1), ], ) def run_with_booster_moehybridplugin(config: Tuple[int, ...]): @@ -938,7 +936,6 @@ def run_with_booster_moehybridplugin(config: Tuple[int, ...]): (1, 2, 2, 1), (1, 2, 1, 2), (1, 1, 2, 2), - # TODO: support overlap p2p in pp4 (1, 4, 1, 1), ], )