|
|
|
@ -45,10 +45,11 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|
|
|
|
num_model_chunks: int,
|
|
|
|
|
num_microbatch: Optional[int] = None,
|
|
|
|
|
microbatch_size: Optional[int] = None,
|
|
|
|
|
enable_metadata_cache: bool = False,
|
|
|
|
|
overlap_p2p: bool = True,
|
|
|
|
|
enable_metadata_cache: bool = True,
|
|
|
|
|
overlap_p2p: bool = False,
|
|
|
|
|
):
|
|
|
|
|
super().__init__(stage_manager)
|
|
|
|
|
# Not support overlap_p2p so far
|
|
|
|
|
# batch info
|
|
|
|
|
self.num_microbatch = num_microbatch
|
|
|
|
|
self.microbatch_size = microbatch_size
|
|
|
|
@ -906,9 +907,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
|
|
|
|
|
model_chunk_id=scheduled_node.chunk,
|
|
|
|
|
optimizer=optimizer,
|
|
|
|
|
)
|
|
|
|
|
for h in self.wait_handles:
|
|
|
|
|
for hh in h:
|
|
|
|
|
hh.wait()
|
|
|
|
|
# print(f"stage {self.stage_manager.stage}; self.tensor_metadata_recv[0] {self.tensor_metadata_recv[0]}; self.tensor_metadata_recv[1] {self.tensor_metadata_recv[1]}; self.grad_metadata_recv[0] {self.grad_metadata_recv[0]}; self.grad_metadata_recv[1] {self.grad_metadata_recv[1]}")
|
|
|
|
|
# return loss & output
|
|
|
|
|
if outputs is not None:
|
|
|
|
|