mirror of https://github.com/hpcaitech/ColossalAI
[feat] fix ci; add assert;
parent
29383b2de0
commit
d6e3d7d2a3
|
@ -479,15 +479,20 @@ def test_run_fwd_bwd_with_vschedule(
|
|||
rank: int,
|
||||
world_size: int,
|
||||
port: int,
|
||||
num_microbatch: int,
|
||||
batch_size: int,
|
||||
num_model_chunk: int,
|
||||
):
|
||||
# init dist
|
||||
colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")
|
||||
rank = dist.get_rank()
|
||||
pp_size = world_size
|
||||
pg_mesh = ProcessGroupMesh(pp_size)
|
||||
num_microbatch = 4
|
||||
num_microbatch = num_microbatch
|
||||
# stage_manager
|
||||
stage_manager = PipelineStageManager(pg_mesh, pipeline_axis=0, enable_interleave=True, num_model_chunks=pp_size)
|
||||
stage_manager = PipelineStageManager(
|
||||
pg_mesh, pipeline_axis=0, enable_interleave=True, num_model_chunks=num_model_chunk
|
||||
)
|
||||
|
||||
h, a, s = 4096, 32, 1024
|
||||
mem_f = 34 * h + 5 * a * s
|
||||
|
@ -511,7 +516,7 @@ def test_run_fwd_bwd_with_vschedule(
|
|||
scheduler = ZeroBubbleVPipeScheduler(
|
||||
schedule=zbv_schedule[rank], # hint: send whole schedule or local schedule only ?
|
||||
stage_manager=stage_manager,
|
||||
num_model_chunks=pp_size,
|
||||
num_model_chunks=num_model_chunk,
|
||||
num_microbatch=num_microbatch,
|
||||
overlap_p2p=False,
|
||||
)
|
||||
|
@ -520,8 +525,9 @@ def test_run_fwd_bwd_with_vschedule(
|
|||
return (x * x).mean()
|
||||
|
||||
# init model and input
|
||||
batch_size = 4
|
||||
batch_size = batch_size
|
||||
num_layers = 8
|
||||
assert num_layers % num_model_chunk == 0, f"Model with {num_layers} layer can not dist on {num_model_chunk} chunk"
|
||||
in_dim = out_dim = 8
|
||||
print(f"Before init Model: {torch.cuda.memory_allocated()/1024**3 :.3f} GB on device {stage_manager.get_rank()};")
|
||||
model = MlpModel(in_dim=in_dim, out_dim=out_dim, num_layers=num_layers).to(rank)
|
||||
|
@ -611,16 +617,19 @@ def test_run_fwd_bwd_with_vschedule(
|
|||
|
||||
|
||||
@pytest.mark.dist
|
||||
# @pytest.mark.parametrize("num_microbatch", [4])
|
||||
# @pytest.mark.parametrize("batch_size", [4])
|
||||
# @pytest.mark.parametrize("num_model_chunk", [2])
|
||||
@pytest.mark.parametrize("num_microbatch", [4])
|
||||
@pytest.mark.parametrize("batch_size", [4])
|
||||
@pytest.mark.parametrize("num_model_chunk", [4])
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_pp():
|
||||
def test_pp(num_microbatch: int, batch_size: int, num_model_chunk: int):
|
||||
spawn(
|
||||
test_run_fwd_bwd_with_vschedule,
|
||||
nprocs=4,
|
||||
num_microbatch=num_microbatch,
|
||||
batch_size=batch_size,
|
||||
num_model_chunk=num_model_chunk,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_pp()
|
||||
test_pp(num_microbatch=4, batch_size=4, num_model_chunk=4)
|
||||
|
|
Loading…
Reference in New Issue