diff --git a/colossalai/pipeline/schedule/zero_bubble_pp.py b/colossalai/pipeline/schedule/zero_bubble_pp.py index 92d214bad..498240878 100644 --- a/colossalai/pipeline/schedule/zero_bubble_pp.py +++ b/colossalai/pipeline/schedule/zero_bubble_pp.py @@ -899,7 +899,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): # communication communication_func = self.communication_map[scheduled_node.type] wait_handle = communication_func(scheduled_node.chunk) - self.wait_handles.append(wait_handle) + # We wait recv handle in fwd step and bwd step. Here only need to wait for send handle + if scheduled_node.type in {"SEND_FORWARD", "SEND_BACKWARD"}: + self.wait_handles.append(wait_handle) elif scheduled_node.type == "F": self.schedule_f( scheduled_node=scheduled_node,