@@ -760,7 +760,6 @@ def run_fwd_bwd_vschedule_with_optim(test_config):
     ],
 )
 def run_with_booster_moehybridplugin(config: Tuple[int, ...]):
-    test_config = config
     stage, ep_size, pp_size, tp_size, sp_size = config
     num_microbatches = pp_size
     dist.get_world_size()
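Each `config` tuple unpacked above carries the parallel dimensions for one test case: a ZeRO stage plus expert-, pipeline-, tensor-, and sequence-parallel sizes (the concrete tuples in the parametrize list are not shown in this excerpt). A hypothetical example of the mapping, for illustration only:

    # Hypothetical config tuple, not taken from the test's parametrize list:
    # (stage, ep_size, pp_size, tp_size, sp_size)
    config = (1, 2, 2, 2, 1)
    stage, ep_size, pp_size, tp_size, sp_size = config
    num_microbatches = pp_size  # one microbatch per pipeline stage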
@@ -877,7 +876,6 @@ def run_with_booster_moehybridplugin(config: Tuple[int, ...]):
             parallel_output = sharded_output["loss"]
         else:
             parallel_output = torch.tensor(12345.0, device="cuda")
-        print(f"rank {dist.get_rank()} parallel_output {parallel_output}")
         # broadcast along pp axis
         dist.broadcast(parallel_output, src=dist.get_process_group_ranks(plugin.pp_group)[0], group=plugin.pp_group)
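The hunk above shows a common collective idiom: only one pipeline stage holds the real loss, but `dist.broadcast` requires every rank in the group to supply a tensor of matching shape and dtype, so the other ranks pass a throwaway placeholder (`12345.0`) that the broadcast overwrites. A minimal self-contained sketch of the same pattern, with names assumed rather than taken from the test:

    import torch
    import torch.distributed as dist

    def broadcast_loss_along_pp(loss, pp_group):
        # Ranks that did not compute the loss pass a placeholder tensor;
        # it is overwritten by the value broadcast from the source rank.
        if loss is None:
            loss = torch.tensor(12345.0, device="cuda")
        # get_process_group_ranks returns global ranks; [0] is the group's first member
        src = dist.get_process_group_ranks(pp_group)[0]
        dist.broadcast(loss, src=src, group=pp_group)
        return loss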
@@ -905,7 +903,6 @@ def run_with_booster_moehybridplugin(config: Tuple[int, ...]):
         torch_optimizer.step()
         torch_optimizer.zero_grad()
     assert_loose_close(parallel_output, torch_output_sum, dtype=dtype)
-    print(f"rank {dist.get_rank()} config {test_config} test passed")
     clear_layout_converter()
     Randomizer.reset_index()
     torch.cuda.empty_cache()
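`assert_loose_close` is the suite's tolerance-aware comparison helper: it checks the pipeline-parallel loss against the single-device baseline up to a dtype-dependent tolerance. A sketch of what such a helper typically looks like; the tolerance values below are assumptions, not ColossalAI's actual numbers:

    import torch

    # Assumed per-dtype tolerances, for illustration; the real helper may differ.
    _RTOL = {torch.float16: 5e-3, torch.bfloat16: 4e-2, torch.float32: 1e-5}

    def assert_loose_close(actual, expected, dtype=torch.float32, name=""):
        rtol = _RTOL.get(dtype, 1e-5)
        torch.testing.assert_close(
            actual, expected, rtol=rtol, atol=rtol,
            msg=f"{name} mismatch: {actual} vs {expected}",
        )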
@@ -1060,10 +1057,8 @@ def run_with_booster_hybridplugin(config: Tuple[int, ...]):
             p.grad /= dp_size
         torch_optimizer.step()
         torch_optimizer.zero_grad()
 
-    print(f"parallel_output {parallel_output}, torch_output_sum {torch_output_sum}")
     assert_loose_close(parallel_output, torch_output_sum, dtype=dtype)
-    print(f"rank {dist.get_rank()} pp_size:{pp_size}, tp_size {tp_size}, sp_size :{sp_size} test passed")
     clear_layout_converter()
     Randomizer.reset_index()
     torch.cuda.empty_cache()
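The `p.grad /= dp_size` above rescales the plain-torch baseline's summed gradients to match the mean reduction that the parallel plugin applies across data-parallel ranks before `step()`. The hunk only shows the rescale; a general version of the idiom, including the all_reduce that data-parallel frameworks perform internally (names assumed), might look like:

    import torch
    import torch.distributed as dist

    def average_grads_across_dp(model: torch.nn.Module, dp_group=None):
        # Sum gradients over the data-parallel group, then divide by its size,
        # mirroring the mean reduction that DDP/ZeRO apply under the hood.
        dp_size = dist.get_world_size(group=dp_group)
        for p in model.parameters():
            if p.grad is not None:
                dist.all_reduce(p.grad, op=dist.ReduceOp.SUM, group=dp_group)
                p.grad /= dp_size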