[fix] fix communication_map;

pull/6034/head
duanjunwen 2024-08-30 05:56:02 +00:00
parent 8eb6eac225
commit a7b767b071
1 changed files with 10 additions and 10 deletions

View File

@ -60,6 +60,14 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
# P2P communication
self.comm = PipelineP2PCommunication(stage_manager, overlap_p2p=overlap_p2p)
# init communication map
self.communication_map = {
"SEND_FORWARD": self.send_forward,
"RECV_FORWARD": self.recv_forward,
"SEND_BACKWARD": self.send_backward,
"RECV_BACKWARD": self.recv_backward,
}
# init buffer
self._free_buffers()
@ -162,14 +170,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
model_chunk_id = self.num_model_chunks - model_chunk_id - 1
return model_chunk_id
def communication_func_map(self, node_type: str):
return {
"SEND_FORWARD": self.send_forward,
"RECV_FORWARD": self.recv_forward,
"SEND_BACKWARD": self.send_backward,
"RECV_BACKWARD": self.recv_backward,
}[node_type]
def recv_forward(self, model_chunk_id: int, prev_rank: int = None) -> Tuple[Any, List]:
"""Copy the forward output from the previous stage in pipeline as the input tensor of this stage.
For ZBV.
@ -718,7 +718,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
if scheduled_node.type in {"RECV_FORWARD", "SEND_FORWARD"}:
# communication
communication_func = self.communication_func_map(scheduled_node.type)
communication_func = self.communication_map[scheduled_node.type]
communication_func(scheduled_node.chunk)
if scheduled_node.type == "F":
self.schedule_f(
@ -770,7 +770,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
)
if scheduled_node.type in AUTO_SCHEDULE_COMMUNICATION_TYPES:
# communication
communication_func = self.communication_func_map(scheduled_node.type)
communication_func = self.communication_map[scheduled_node.type]
communication_func(scheduled_node.chunk)
if scheduled_node.type == "F":