[fix] fix p2p error in zbv

pull/6114/head
duanjunwen 2 weeks ago
parent b6d5e61809
commit 1bc4dba3a3

@ -45,10 +45,11 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
num_model_chunks: int, num_model_chunks: int,
num_microbatch: Optional[int] = None, num_microbatch: Optional[int] = None,
microbatch_size: Optional[int] = None, microbatch_size: Optional[int] = None,
enable_metadata_cache: bool = False, enable_metadata_cache: bool = True,
overlap_p2p: bool = True, overlap_p2p: bool = False,
): ):
super().__init__(stage_manager) super().__init__(stage_manager)
# Not support overlap_p2p so far
# batch info # batch info
self.num_microbatch = num_microbatch self.num_microbatch = num_microbatch
self.microbatch_size = microbatch_size self.microbatch_size = microbatch_size
@ -906,9 +907,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
model_chunk_id=scheduled_node.chunk, model_chunk_id=scheduled_node.chunk,
optimizer=optimizer, optimizer=optimizer,
) )
for h in self.wait_handles:
for hh in h:
hh.wait()
# print(f"stage {self.stage_manager.stage}; self.tensor_metadata_recv[0] {self.tensor_metadata_recv[0]}; self.tensor_metadata_recv[1] {self.tensor_metadata_recv[1]}; self.grad_metadata_recv[0] {self.grad_metadata_recv[0]}; self.grad_metadata_recv[1] {self.grad_metadata_recv[1]}") # print(f"stage {self.stage_manager.stage}; self.tensor_metadata_recv[0] {self.tensor_metadata_recv[0]}; self.tensor_metadata_recv[1] {self.tensor_metadata_recv[1]}; self.grad_metadata_recv[0] {self.grad_metadata_recv[0]}; self.grad_metadata_recv[1] {self.grad_metadata_recv[1]}")
# return loss & output # return loss & output
if outputs is not None: if outputs is not None:

@ -770,13 +770,11 @@ def run_fwd_bwd_vschedule_with_optim(test_config):
@parameterize( @parameterize(
"config", "config",
[ [
# Pass
(1, 2, 1, 1, 2), (1, 2, 1, 1, 2),
(1, 1, 2, 2, 1), (1, 1, 2, 2, 1),
(1, 2, 1, 2, 1), (1, 2, 1, 2, 1),
(1, 2, 2, 1, 1), (1, 2, 2, 1, 1),
# # TODO: adapt mixtral with no TP Linear (1, 1, 4, 1, 1),
(0, 1, 4, 1, 1),
], ],
) )
def run_with_booster_moehybridplugin(config: Tuple[int, ...]): def run_with_booster_moehybridplugin(config: Tuple[int, ...]):
@ -938,7 +936,6 @@ def run_with_booster_moehybridplugin(config: Tuple[int, ...]):
(1, 2, 2, 1), (1, 2, 2, 1),
(1, 2, 1, 2), (1, 2, 1, 2),
(1, 1, 2, 2), (1, 1, 2, 2),
# TODO: support overlap p2p in pp4
(1, 4, 1, 1), (1, 4, 1, 1),
], ],
) )

Loading…
Cancel
Save