[fix] fix wait handle in run_fwd_bwd

pull/6114/head
duanjunwen 2024-11-18 02:50:14 +00:00
parent f48a85e91d
commit 9a21f87ed6
1 changed files with 3 additions and 1 deletions

View File

@ -899,7 +899,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
# communication
communication_func = self.communication_map[scheduled_node.type]
wait_handle = communication_func(scheduled_node.chunk)
self.wait_handles.append(wait_handle)
# We wait recv handle in fwd step and bwd step. Here only need to wait for send handle
if scheduled_node.type in {"SEND_FORWARD", "SEND_BACKWARD"}:
self.wait_handles.append(wait_handle)
elif scheduled_node.type == "F":
self.schedule_f(
scheduled_node=scheduled_node,