From 9a21f87ed6e161b88378490c026210b4f261c98b Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Mon, 18 Nov 2024 02:50:14 +0000 Subject: [PATCH] [fix] fix wait handle in run_fwd_bwd --- colossalai/pipeline/schedule/zero_bubble_pp.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/colossalai/pipeline/schedule/zero_bubble_pp.py b/colossalai/pipeline/schedule/zero_bubble_pp.py index 92d214bad..498240878 100644 --- a/colossalai/pipeline/schedule/zero_bubble_pp.py +++ b/colossalai/pipeline/schedule/zero_bubble_pp.py @@ -899,7 +899,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): # communication communication_func = self.communication_map[scheduled_node.type] wait_handle = communication_func(scheduled_node.chunk) - self.wait_handles.append(wait_handle) + # We wait recv handle in fwd step and bwd step. Here only need to wait for send handle + if scheduled_node.type in {"SEND_FORWARD", "SEND_BACKWARD"}: + self.wait_handles.append(wait_handle) elif scheduled_node.type == "F": self.schedule_f( scheduled_node=scheduled_node,