@@ -46,7 +46,7 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         num_microbatch: Optional[int] = None,
         microbatch_size: Optional[int] = None,
         enable_metadata_cache: bool = True,
-        overlap_p2p: bool = False,
+        overlap_p2p: bool = True,
     ):
         super().__init__(stage_manager)
         # Not support overlap_p2p so far
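The first hunk flips the default of `overlap_p2p` to `True`: point-to-point sends and receives are now issued asynchronously, and each call hands back work handles that the scheduler must wait on before the communicated tensors may be read. (The `# Not support overlap_p2p so far` comment survives as a stale context line.) Below is a minimal sketch of the underlying pattern using plain `torch.distributed` rather than ColossalAI's own p2p wrappers; the helper name is hypothetical.

    import torch
    import torch.distributed as dist

    def isend_forward(tensor: torch.Tensor, next_rank: int):
        """Hypothetical helper: issue a non-blocking send, return its handles.

        Requires an initialized process group. batch_isend_irecv returns a
        list of Work objects (one per op), which is why each entry appended
        to wait_handles below is itself a list.
        """
        op = dist.P2POp(dist.isend, tensor, next_rank)
        return dist.batch_isend_irecv([op])

    # consumer side, before compute may safely touch the data:
    # for work in handles:
    #     work.wait()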
@@ -879,12 +879,16 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         schedule = self.schedules[self.stage_manager.stage]  # get schedule by stage (rank)
         for it in range(len(schedule)):
             scheduled_node = schedule[it]
             # print(f"stage {self.stage_manager.stage} {scheduled_node.type}")
             if scheduled_node.type in AUTO_SCHEDULE_COMMUNICATION_TYPES:
                 # communication
                 communication_func = self.communication_map[scheduled_node.type]
                 wait_handle = communication_func(scheduled_node.chunk)
+                self.wait_handles.append(wait_handle)
             elif scheduled_node.type == "F":
+                for h in self.wait_handles:
+                    for hh in h:
+                        hh.wait()
                 self.schedule_f(
                     scheduled_node=scheduled_node,
                     model_chunk=model_chunk,
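Two details in this hunk: `communication_map` is a dispatch table from a scheduled node's type to the bound send/recv method, and every communication call returns a list of handles, so `self.wait_handles` is a list of lists. That is why a forward node drains it with two nested loops before `schedule_f` runs. A runnable toy version of the dispatch-and-drain pattern (all names here are illustrative, not ColossalAI's API):

    class _FakeWork:
        """Stands in for a torch.distributed Work object."""
        def wait(self):
            pass  # the real Work.wait() blocks until the op completes

    class TinyScheduler:
        def __init__(self):
            # one inner list per issued op batch, as batch_isend_irecv returns
            self.wait_handles = []
            self.communication_map = {
                "SEND_FORWARD": self.send_forward,
                "RECV_FORWARD": self.recv_forward,
            }

        def send_forward(self, chunk):
            return [_FakeWork()]  # real code would return async op handles

        def recv_forward(self, chunk):
            return [_FakeWork()]

        def step(self, node_type, chunk=0):
            if node_type in self.communication_map:
                self.wait_handles.append(self.communication_map[node_type](chunk))
            else:
                # compute node: drain all pending communication first
                for h in self.wait_handles:
                    for hh in h:
                        hh.wait()

Re-waiting an already-completed handle is a no-op in PyTorch, which is why the scheduler can keep draining the same list without clearing it between compute nodes.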
@@ -894,6 +898,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
                     outputs=outputs,
                 )
             elif scheduled_node.type == "B":
+                for h in self.wait_handles:
+                    for hh in h:
+                        hh.wait()
                 self.schedule_b(
                     scheduled_node=scheduled_node,
                     model_chunk=model_chunk,
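The separate "B" node is half of the zero-bubble trick: backward is split into the input-gradient pass (B), which the upstream stage is blocked on, and the weight-gradient pass (W, handled by `schedule_w` in the next hunk), which can be deferred to fill what would otherwise be pipeline bubbles. A toy single-layer illustration of the split; the shapes are arbitrary and this is not the scheduler's actual implementation.

    import torch

    x = torch.randn(4, 8, requires_grad=True)  # stands in for the stage input
    w = torch.randn(8, 8, requires_grad=True)  # stands in for the stage weights
    loss = (x @ w).sum()

    # "B": input gradient, computed immediately so it can be sent upstream
    (dx,) = torch.autograd.grad(loss, x, retain_graph=True)

    # "W": weight gradient, deferred until the schedule has an idle slot
    (dw,) = torch.autograd.grad(loss, w)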
@@ -907,7 +914,9 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
                     model_chunk_id=scheduled_node.chunk,
                     optimizer=optimizer,
                 )
-        # print(f"stage {self.stage_manager.stage}; self.tensor_metadata_recv[0] {self.tensor_metadata_recv[0]}; self.tensor_metadata_recv[1] {self.tensor_metadata_recv[1]}; self.grad_metadata_recv[0] {self.grad_metadata_recv[0]}; self.grad_metadata_recv[1] {self.grad_metadata_recv[1]}")
+        for h in self.wait_handles:
+            for hh in h:
+                hh.wait()
         # return loss & output
         if outputs is not None:
             outputs = merge_batch(outputs)
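The final hunk swaps a leftover metadata debug print for one last drain of outstanding handles, after which the per-microbatch outputs are merged back into a single batch via `merge_batch`. A simplified stand-in for that merge, assuming tensor outputs only (the real helper, which I believe lives in ColossalAI's schedule utilities, also handles nested structures):

    from typing import List
    import torch

    def merge_batch_sketch(outputs: List[torch.Tensor], dim: int = 0) -> torch.Tensor:
        """Hypothetical stand-in for merge_batch: concatenate the
        per-microbatch outputs along the batch dimension."""
        return torch.cat(outputs, dim=dim)

    # e.g. four microbatches of shape (2, 8) -> one batch of shape (8, 8)
    merged = merge_batch_sketch([torch.randn(2, 8) for _ in range(4)])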