mirror of https://github.com/hpcaitech/ColossalAI
[fix] fix communication_map;
parent
8eb6eac225
commit
a7b767b071
|
@ -60,6 +60,14 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||
# P2P communication
|
||||
self.comm = PipelineP2PCommunication(stage_manager, overlap_p2p=overlap_p2p)
|
||||
|
||||
# init communication map
|
||||
self.communication_map = {
|
||||
"SEND_FORWARD": self.send_forward,
|
||||
"RECV_FORWARD": self.recv_forward,
|
||||
"SEND_BACKWARD": self.send_backward,
|
||||
"RECV_BACKWARD": self.recv_backward,
|
||||
}
|
||||
|
||||
# init buffer
|
||||
self._free_buffers()
|
||||
|
||||
|
@ -162,14 +170,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||
model_chunk_id = self.num_model_chunks - model_chunk_id - 1
|
||||
return model_chunk_id
|
||||
|
||||
def communication_func_map(self, node_type: str):
|
||||
return {
|
||||
"SEND_FORWARD": self.send_forward,
|
||||
"RECV_FORWARD": self.recv_forward,
|
||||
"SEND_BACKWARD": self.send_backward,
|
||||
"RECV_BACKWARD": self.recv_backward,
|
||||
}[node_type]
|
||||
|
||||
def recv_forward(self, model_chunk_id: int, prev_rank: int = None) -> Tuple[Any, List]:
|
||||
"""Copy the forward output from the previous stage in pipeline as the input tensor of this stage.
|
||||
For ZBV.
|
||||
|
@ -718,7 +718,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||
|
||||
if scheduled_node.type in {"RECV_FORWARD", "SEND_FORWARD"}:
|
||||
# communication
|
||||
communication_func = self.communication_func_map(scheduled_node.type)
|
||||
communication_func = self.communication_map[scheduled_node.type]
|
||||
communication_func(scheduled_node.chunk)
|
||||
if scheduled_node.type == "F":
|
||||
self.schedule_f(
|
||||
|
@ -770,7 +770,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|||
)
|
||||
if scheduled_node.type in AUTO_SCHEDULE_COMMUNICATION_TYPES:
|
||||
# communication
|
||||
communication_func = self.communication_func_map(scheduled_node.type)
|
||||
communication_func = self.communication_map[scheduled_node.type]
|
||||
communication_func(scheduled_node.chunk)
|
||||
|
||||
if scheduled_node.type == "F":
|
||||
|
|
Loading…
Reference in New Issue