[feat] fix ci; add assert;

pull/6034/head
duanjunwen 3 months ago
parent 29383b2de0
commit d6e3d7d2a3

@ -479,15 +479,20 @@ def test_run_fwd_bwd_with_vschedule(
rank: int,
world_size: int,
port: int,
num_microbatch: int,
batch_size: int,
num_model_chunk: int,
):
# init dist
colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")
rank = dist.get_rank()
pp_size = world_size
pg_mesh = ProcessGroupMesh(pp_size)
num_microbatch = 4
num_microbatch = num_microbatch
# stage_manager
stage_manager = PipelineStageManager(pg_mesh, pipeline_axis=0, enable_interleave=True, num_model_chunks=pp_size)
stage_manager = PipelineStageManager(
pg_mesh, pipeline_axis=0, enable_interleave=True, num_model_chunks=num_model_chunk
)
h, a, s = 4096, 32, 1024
mem_f = 34 * h + 5 * a * s
@ -511,7 +516,7 @@ def test_run_fwd_bwd_with_vschedule(
scheduler = ZeroBubbleVPipeScheduler(
schedule=zbv_schedule[rank], # hint: send whole schedule or local schedule only ?
stage_manager=stage_manager,
num_model_chunks=pp_size,
num_model_chunks=num_model_chunk,
num_microbatch=num_microbatch,
overlap_p2p=False,
)
@ -520,8 +525,9 @@ def test_run_fwd_bwd_with_vschedule(
return (x * x).mean()
# init model and input
batch_size = 4
batch_size = batch_size
num_layers = 8
assert num_layers % num_model_chunk == 0, f"Model with {num_layers} layer can not dist on {num_model_chunk} chunk"
in_dim = out_dim = 8
print(f"Before init Model: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};")
model = MlpModel(in_dim=in_dim, out_dim=out_dim, num_layers=num_layers).to(rank)
@ -611,16 +617,19 @@ def test_run_fwd_bwd_with_vschedule(
@pytest.mark.dist
# @pytest.mark.parametrize("num_microbatch", [4])
# @pytest.mark.parametrize("batch_size", [4])
# @pytest.mark.parametrize("num_model_chunk", [2])
@pytest.mark.parametrize("num_microbatch", [4])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("num_model_chunk", [4])
@rerun_if_address_is_in_use()
def test_pp():
def test_pp(num_microbatch: int, batch_size: int, num_model_chunk: int):
spawn(
test_run_fwd_bwd_with_vschedule,
nprocs=4,
num_microbatch=num_microbatch,
batch_size=batch_size,
num_model_chunk=num_model_chunk,
)
if __name__ == "__main__":
test_pp()
test_pp(num_microbatch=4, batch_size=4, num_model_chunk=4)

Loading…
Cancel
Save