diff --git a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py index e09805dee..65aa0db5a 100644 --- a/tests/test_pipeline/test_schedule/test_zerobubble_pp.py +++ b/tests/test_pipeline/test_schedule/test_zerobubble_pp.py @@ -479,15 +479,20 @@ def test_run_fwd_bwd_with_vschedule( rank: int, world_size: int, port: int, + num_microbatch: int, + batch_size: int, + num_model_chunk: int, ): # init dist colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost") rank = dist.get_rank() pp_size = world_size pg_mesh = ProcessGroupMesh(pp_size) - num_microbatch = 4 + num_microbatch = num_microbatch # stage_manager - stage_manager = PipelineStageManager(pg_mesh, pipeline_axis=0, enable_interleave=True, num_model_chunks=pp_size) + stage_manager = PipelineStageManager( + pg_mesh, pipeline_axis=0, enable_interleave=True, num_model_chunks=num_model_chunk + ) h, a, s = 4096, 32, 1024 mem_f = 34 * h + 5 * a * s @@ -511,7 +516,7 @@ def test_run_fwd_bwd_with_vschedule( scheduler = ZeroBubbleVPipeScheduler( schedule=zbv_schedule[rank], # hint: send whole schedule or local schedule only ? stage_manager=stage_manager, - num_model_chunks=pp_size, + num_model_chunks=num_model_chunk, num_microbatch=num_microbatch, overlap_p2p=False, ) @@ -520,8 +525,9 @@ def test_run_fwd_bwd_with_vschedule( return (x * x).mean() # init model and input - batch_size = 4 + batch_size = batch_size num_layers = 8 + assert num_layers % num_model_chunk == 0, f"Model with {num_layers} layer can not dist on {num_model_chunk} chunk" in_dim = out_dim = 8 print(f"Before init Model: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};") model = MlpModel(in_dim=in_dim, out_dim=out_dim, num_layers=num_layers).to(rank) @@ -611,16 +617,19 @@ def test_run_fwd_bwd_with_vschedule( @pytest.mark.dist -# @pytest.mark.parametrize("num_microbatch", [4]) -# @pytest.mark.parametrize("batch_size", [4]) -# @pytest.mark.parametrize("num_model_chunk", [2]) +@pytest.mark.parametrize("num_microbatch", [4]) +@pytest.mark.parametrize("batch_size", [4]) +@pytest.mark.parametrize("num_model_chunk", [4]) @rerun_if_address_is_in_use() -def test_pp(): +def test_pp(num_microbatch: int, batch_size: int, num_model_chunk: int): spawn( test_run_fwd_bwd_with_vschedule, nprocs=4, + num_microbatch=num_microbatch, + batch_size=batch_size, + num_model_chunk=num_model_chunk, ) if __name__ == "__main__": - test_pp() + test_pp(num_microbatch=4, batch_size=4, num_model_chunk=4)