
[fix] remove debug info;

pull/6114/head
duanjunwen committed 4 days ago
commit dafda0fb70
1 changed file with 7 changed lines:
tests/test_pipeline/test_schedule/test_zerobubble_pp.py
@@ -760,7 +760,6 @@ def run_fwd_bwd_vschedule_with_optim(test_config):
     ],
 )
 def run_with_booster_moehybridplugin(config: Tuple[int, ...]):
-    test_config = config
     stage, ep_size, pp_size, tp_size, sp_size = config
     num_microbatches = pp_size
     dist.get_world_size()
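
Note on the hunk above: each parameterized case packs (stage, ep_size, pp_size, tp_size, sp_size) into a single tuple, and the test sets num_microbatches equal to pp_size. A small, self-contained unpacking sketch (the example tuples below are illustrative assumptions, not the test's actual cases):

from typing import Tuple

def unpack_config(config: Tuple[int, ...]):
    # Same unpacking order as in the test above.
    stage, ep_size, pp_size, tp_size, sp_size = config
    # The test drives the schedule with as many microbatches as the pipeline size.
    num_microbatches = pp_size
    return stage, ep_size, pp_size, tp_size, sp_size, num_microbatches

for cfg in [(1, 2, 2, 1, 1), (1, 2, 1, 2, 1)]:  # illustrative tuples only
    print(unpack_config(cfg))
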
@@ -877,7 +876,6 @@ def run_with_booster_moehybridplugin(config: Tuple[int, ...]):
         parallel_output = sharded_output["loss"]
     else:
         parallel_output = torch.tensor(12345.0, device="cuda")
-    print(f"rank {dist.get_rank()} parallel_output {parallel_output}")
     # broadcast along pp axis
     dist.broadcast(parallel_output, src=dist.get_process_group_ranks(plugin.pp_group)[0], group=plugin.pp_group)
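
Note on the broadcast in the hunk above: only one rank in the pp group holds the real loss; the other ranks keep the 12345.0 placeholder, and dist.broadcast along the pp axis makes every rank see the same value before it is checked against the single-process reference. A minimal, self-contained sketch of that pattern with plain torch.distributed (assumptions: gloo backend, torchrun launch, and illustrative names/values rather than the test's own):

import torch
import torch.distributed as dist

def broadcast_loss(loss_holder_rank: int) -> torch.Tensor:
    # Every rank allocates a placeholder; only `loss_holder_rank` has the real loss.
    loss = torch.tensor(12345.0)
    if dist.get_rank() == loss_holder_rank:
        loss = torch.tensor(0.123)  # stand-in for the actual pipeline loss
    # After this call all ranks hold the value from `loss_holder_rank`,
    # so each of them can run the same tolerance check.
    dist.broadcast(loss, src=loss_holder_rank)
    return loss

if __name__ == "__main__":
    dist.init_process_group(backend="gloo")
    out = broadcast_loss(loss_holder_rank=0)
    print(f"rank {dist.get_rank()} sees loss {out.item()}")
    dist.destroy_process_group()

Run with, for example, torchrun --nproc_per_node=2 broadcast_sketch.py; every rank then prints the value held by rank 0.
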
@@ -905,7 +903,6 @@ def run_with_booster_moehybridplugin(config: Tuple[int, ...]):
     torch_optimizer.step()
     torch_optimizer.zero_grad()
     assert_loose_close(parallel_output, torch_output_sum, dtype=dtype)
-    print(f"rank {dist.get_rank()} config {test_config} test passed")
     clear_layout_converter()
     Randomizer.reset_index()
     torch.cuda.empty_cache()
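
Note on assert_loose_close in the hunk above: it compares the pipeline-parallel loss against the accumulated single-process loss within a dtype-dependent tolerance. A hedged sketch of what such a check typically looks like (the helper name and tolerance values below are illustrative assumptions, not ColossalAI's actual implementation):

import torch

def loose_close(a: torch.Tensor, b: torch.Tensor, dtype: torch.dtype) -> None:
    # Wider tolerances for reduced-precision dtypes, tighter for fp32.
    rtol, atol = {
        torch.float16: (5e-3, 5e-3),
        torch.bfloat16: (4e-2, 4e-2),
    }.get(dtype, (1e-5, 1e-8))
    a = a.detach().float()
    b = b.detach().float().to(a.device)
    assert torch.allclose(a, b, rtol=rtol, atol=atol), f"{a} vs {b} (dtype={dtype})"
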
@@ -1060,10 +1057,8 @@ def run_with_booster_hybridplugin(config: Tuple[int, ...]):
         p.grad /= dp_size
     torch_optimizer.step()
     torch_optimizer.zero_grad()
-    print(f"parallel_output {parallel_output}, torch_output_sum {torch_output_sum}")
     assert_loose_close(parallel_output, torch_output_sum, dtype=dtype)
-    print(f"rank {dist.get_rank()} pp_size:{pp_size}, tp_size {tp_size}, sp_size :{sp_size} test passed")
     clear_layout_converter()
     Randomizer.reset_index()
     torch.cuda.empty_cache()
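
Note on the reference-model update in the hunk above: the torch baseline's gradients are divided by dp_size so that, after being summed across data-parallel replicas, they match the average a single-process run over the full batch would see. A minimal sketch of that pattern (assumption: the gradients were already summed across dp_size replicas, e.g. via an all_reduce):

import torch

def step_reference_model(model: torch.nn.Module,
                         optimizer: torch.optim.Optimizer,
                         dp_size: int) -> None:
    # Turn summed gradients into per-replica averages before the update.
    for p in model.parameters():
        if p.grad is not None:
            p.grad /= dp_size
    optimizer.step()
    optimizer.zero_grad()
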
