From a7b767b071e78180a290966c5f3fcd43ae8968a5 Mon Sep 17 00:00:00 2001 From: duanjunwen <935724073@qq.com> Date: Fri, 30 Aug 2024 05:56:02 +0000 Subject: [PATCH] [fix] fix communication_map; --- .../pipeline/schedule/zero_bubble_pp.py | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/colossalai/pipeline/schedule/zero_bubble_pp.py b/colossalai/pipeline/schedule/zero_bubble_pp.py index ef3977691..41a886a90 100644 --- a/colossalai/pipeline/schedule/zero_bubble_pp.py +++ b/colossalai/pipeline/schedule/zero_bubble_pp.py @@ -60,6 +60,14 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): # P2P communication self.comm = PipelineP2PCommunication(stage_manager, overlap_p2p=overlap_p2p) + # init communication map + self.communication_map = { + "SEND_FORWARD": self.send_forward, + "RECV_FORWARD": self.recv_forward, + "SEND_BACKWARD": self.send_backward, + "RECV_BACKWARD": self.recv_backward, + } + # init buffer self._free_buffers() @@ -162,14 +170,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): model_chunk_id = self.num_model_chunks - model_chunk_id - 1 return model_chunk_id - def communication_func_map(self, node_type: str): - return { - "SEND_FORWARD": self.send_forward, - "RECV_FORWARD": self.recv_forward, - "SEND_BACKWARD": self.send_backward, - "RECV_BACKWARD": self.recv_backward, - }[node_type] - def recv_forward(self, model_chunk_id: int, prev_rank: int = None) -> Tuple[Any, List]: """Copy the forward output from the previous stage in pipeline as the input tensor of this stage. For ZBV. @@ -718,7 +718,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): if scheduled_node.type in {"RECV_FORWARD", "SEND_FORWARD"}: # communication - communication_func = self.communication_func_map(scheduled_node.type) + communication_func = self.communication_map[scheduled_node.type] communication_func(scheduled_node.chunk) if scheduled_node.type == "F": self.schedule_f( @@ -770,7 +770,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule): ) if scheduled_node.type in AUTO_SCHEDULE_COMMUNICATION_TYPES: # communication - communication_func = self.communication_func_map(scheduled_node.type) + communication_func = self.communication_map[scheduled_node.type] communication_func(scheduled_node.chunk) if scheduled_node.type == "F":