|
|
|
@ -479,15 +479,20 @@ def test_run_fwd_bwd_with_vschedule(
|
|
|
|
|
rank: int,
|
|
|
|
|
world_size: int,
|
|
|
|
|
port: int,
|
|
|
|
|
num_microbatch: int,
|
|
|
|
|
batch_size: int,
|
|
|
|
|
num_model_chunk: int,
|
|
|
|
|
):
|
|
|
|
|
# init dist
|
|
|
|
|
colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")
|
|
|
|
|
rank = dist.get_rank()
|
|
|
|
|
pp_size = world_size
|
|
|
|
|
pg_mesh = ProcessGroupMesh(pp_size)
|
|
|
|
|
num_microbatch = 4
|
|
|
|
|
num_microbatch = num_microbatch
|
|
|
|
|
# stage_manager
|
|
|
|
|
stage_manager = PipelineStageManager(pg_mesh, pipeline_axis=0, enable_interleave=True, num_model_chunks=pp_size)
|
|
|
|
|
stage_manager = PipelineStageManager(
|
|
|
|
|
pg_mesh, pipeline_axis=0, enable_interleave=True, num_model_chunks=num_model_chunk
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
h, a, s = 4096, 32, 1024
|
|
|
|
|
mem_f = 34 * h + 5 * a * s
|
|
|
|
@ -511,7 +516,7 @@ def test_run_fwd_bwd_with_vschedule(
|
|
|
|
|
scheduler = ZeroBubbleVPipeScheduler(
|
|
|
|
|
schedule=zbv_schedule[rank], # hint: send whole schedule or local schedule only ?
|
|
|
|
|
stage_manager=stage_manager,
|
|
|
|
|
num_model_chunks=pp_size,
|
|
|
|
|
num_model_chunks=num_model_chunk,
|
|
|
|
|
num_microbatch=num_microbatch,
|
|
|
|
|
overlap_p2p=False,
|
|
|
|
|
)
|
|
|
|
@ -520,8 +525,9 @@ def test_run_fwd_bwd_with_vschedule(
|
|
|
|
|
return (x * x).mean()
|
|
|
|
|
|
|
|
|
|
# init model and input
|
|
|
|
|
batch_size = 4
|
|
|
|
|
batch_size = batch_size
|
|
|
|
|
num_layers = 8
|
|
|
|
|
assert num_layers % num_model_chunk == 0, f"Model with {num_layers} layer can not dist on {num_model_chunk} chunk"
|
|
|
|
|
in_dim = out_dim = 8
|
|
|
|
|
print(f"Before init Model: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};")
|
|
|
|
|
model = MlpModel(in_dim=in_dim, out_dim=out_dim, num_layers=num_layers).to(rank)
|
|
|
|
@ -611,16 +617,19 @@ def test_run_fwd_bwd_with_vschedule(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.dist
|
|
|
|
|
# @pytest.mark.parametrize("num_microbatch", [4])
|
|
|
|
|
# @pytest.mark.parametrize("batch_size", [4])
|
|
|
|
|
# @pytest.mark.parametrize("num_model_chunk", [2])
|
|
|
|
|
@pytest.mark.parametrize("num_microbatch", [4])
|
|
|
|
|
@pytest.mark.parametrize("batch_size", [4])
|
|
|
|
|
@pytest.mark.parametrize("num_model_chunk", [4])
|
|
|
|
|
@rerun_if_address_is_in_use()
|
|
|
|
|
def test_pp():
|
|
|
|
|
def test_pp(num_microbatch: int, batch_size: int, num_model_chunk: int):
|
|
|
|
|
spawn(
|
|
|
|
|
test_run_fwd_bwd_with_vschedule,
|
|
|
|
|
nprocs=4,
|
|
|
|
|
num_microbatch=num_microbatch,
|
|
|
|
|
batch_size=batch_size,
|
|
|
|
|
num_model_chunk=num_model_chunk,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
|
|
test_pp()
|
|
|
|
|
test_pp(num_microbatch=4, batch_size=4, num_model_chunk=4)
|
|
|
|
|