diff --git a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py index 65aa0db5a..7f02ca477 100644 --- a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py +++ b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py @@ -36,8 +36,8 @@ def get_model_numel(model: torch.nn.Module) -> Tuple[int, int]: return num_params, num_params_trainable -# Test iter input & multiple microbatch -def test_run_fwd_bwd_iter_input( +# Test manual v_schedule with multiple microbatch +def run_fwd_bwd_iter_input( rank: int, world_size: int, port: int, @@ -474,8 +474,8 @@ def test_run_fwd_bwd_iter_input( assert_close(local_chunk[1].weight.grad, model_base.layers[4].weight.grad) -# T -def test_run_fwd_bwd_with_vschedule( +# Test v_schedule generated by graph with multiple microbatch +def run_fwd_bwd_with_vschedule( rank: int, world_size: int, port: int, @@ -623,7 +623,7 @@ def test_run_fwd_bwd_with_vschedule( @rerun_if_address_is_in_use() def test_pp(num_microbatch: int, batch_size: int, num_model_chunk: int): spawn( - test_run_fwd_bwd_with_vschedule, + run_fwd_bwd_with_vschedule, nprocs=4, num_microbatch=num_microbatch, batch_size=batch_size,