[fix] fix p2p error in zbv

Branch: pull/6114/head
Author: duanjunwen (2 weeks ago)
Parent: b6d5e61809
Commit: 1bc4dba3a3
@@ -45,10 +45,11 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
         num_model_chunks: int,
         num_microbatch: Optional[int] = None,
         microbatch_size: Optional[int] = None,
-        enable_metadata_cache: bool = False,
-        overlap_p2p: bool = True,
+        enable_metadata_cache: bool = True,
+        overlap_p2p: bool = False,
     ):
         super().__init__(stage_manager)
+        # Not support overlap_p2p so far
         # batch info
         self.num_microbatch = num_microbatch
         self.microbatch_size = microbatch_size
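
This hunk flips the constructor defaults: metadata caching is now on and overlapped p2p communication is off, matching the new "Not support overlap_p2p so far" comment. A minimal sketch of constructing the scheduler under the new defaults; the import paths and the `schedule` parameter are assumptions based on the surrounding code, and the stage manager and zbv schedule are passed in because building them is out of scope here:

```python
# Minimal sketch, assuming these import paths in ColossalAI; not taken from
# this diff. The caller supplies the stage manager and the precomputed
# zero-bubble-V node schedule.
from colossalai.pipeline.schedule.zero_bubble_pp import ZeroBubbleVPipeScheduler
from colossalai.pipeline.stage_manager import PipelineStageManager


def build_scheduler(stage_manager: PipelineStageManager, zbv_schedule) -> ZeroBubbleVPipeScheduler:
    return ZeroBubbleVPipeScheduler(
        stage_manager=stage_manager,
        schedule=zbv_schedule,       # assumed parameter: the zbv node schedule
        num_model_chunks=2,
        num_microbatch=4,
        # After this commit, omitting the next two arguments is equivalent to:
        enable_metadata_cache=True,
        overlap_p2p=False,           # p2p now runs blocking; overlap is unsupported for now
    )
```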
@@ -906,9 +907,6 @@ class ZeroBubbleVPipeScheduler(PipelineSchedule):
                     model_chunk_id=scheduled_node.chunk,
                     optimizer=optimizer,
                 )
-        for h in self.wait_handles:
-            for hh in h:
-                hh.wait()
         # print(f"stage {self.stage_manager.stage}; self.tensor_metadata_recv[0] {self.tensor_metadata_recv[0]}; self.tensor_metadata_recv[1] {self.tensor_metadata_recv[1]}; self.grad_metadata_recv[0] {self.grad_metadata_recv[0]}; self.grad_metadata_recv[1] {self.grad_metadata_recv[1]}")
         # return loss & output
         if outputs is not None:
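
The deleted loop drained every outstanding p2p work handle at the end of the run; with `overlap_p2p` now defaulting to `False`, sends and receives complete before the schedule moves on, so there is nothing left to wait on. For reference, a self-contained sketch of the pattern the removed loop implemented, using torch.distributed batched async p2p (the function and its arguments are illustrative, not the scheduler's actual code, and it requires an initialized process group to call):

```python
# Illustrative sketch, not ColossalAI's implementation. Each call to
# dist.batch_isend_irecv returns one Work handle per op, so accumulating
# several calls yields a list of lists, drained exactly like the deleted loop.
import torch
import torch.distributed as dist


def exchange_and_drain(send_tensor: torch.Tensor, recv_tensor: torch.Tensor,
                       prev_rank: int, next_rank: int) -> None:
    wait_handles = []  # list of lists of dist.Work
    ops = [
        dist.P2POp(dist.isend, send_tensor, peer=next_rank),
        dist.P2POp(dist.irecv, recv_tensor, peer=prev_rank),
    ]
    wait_handles.append(dist.batch_isend_irecv(ops))  # -> list of Work handles

    # Equivalent of the removed loop: block until every outstanding op is done.
    for handles in wait_handles:
        for handle in handles:
            handle.wait()
```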

@@ -770,13 +770,11 @@ def run_fwd_bwd_vschedule_with_optim(test_config):
 @parameterize(
     "config",
     [
-        # Pass
         (1, 2, 1, 1, 2),
         (1, 1, 2, 2, 1),
         (1, 2, 1, 2, 1),
         (1, 2, 2, 1, 1),
-        # # TODO: adapt mixtral with no TP Linear
-        (0, 1, 4, 1, 1),
+        (1, 1, 4, 1, 1),
     ],
 )
 def run_with_booster_moehybridplugin(config: Tuple[int, ...]):
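
For context on how these tuples reach the test: `colossalai.testing.parameterize` is ColossalAI's own decorator, not pytest's, and re-runs the decorated function once per entry in the list, binding each value to the named argument. A toy sketch of that assumed behavior (the tuples are copied from the list above; the body is a stand-in):

```python
# Toy sketch of the assumed @parameterize semantics: calling the wrapped
# function executes it once per configured value of `config`.
from typing import Tuple

from colossalai.testing import parameterize


@parameterize("config", [(1, 1, 4, 1, 1), (1, 2, 2, 1, 1)])
def run_demo(config: Tuple[int, ...]):
    print("running with config:", config)


run_demo()  # prints once per tuple in the list
```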
@@ -938,7 +936,6 @@ def run_with_booster_moehybridplugin(config: Tuple[int, ...]):
         (1, 2, 2, 1),
         (1, 2, 1, 2),
         (1, 1, 2, 2),
-        # TODO: support overlap p2p in pp4
         (1, 4, 1, 1),
     ],
 )
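
With overlapped p2p disabled by default, the pp=4 case exercised by (1, 4, 1, 1) presumably no longer depends on overlap-p2p support, which would explain dropping the TODO while keeping the configuration enabled.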
