mirror of https://github.com/hpcaitech/ColossalAI
aibig-modeldata-parallelismdeep-learningdistributed-computingfoundation-modelsheterogeneous-traininghpcinferencelarge-scalemodel-parallelismpipeline-parallelism
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
167 lines
6.2 KiB
167 lines
6.2 KiB
import pytest |
|
import torch.distributed as dist |
|
|
|
import colossalai |
|
from colossalai.cluster import ProcessGroupMesh |
|
from colossalai.testing import spawn |
|
|
|
|
|
def check_process_group_mesh_with_gpc():
    """Cross-check ``ProcessGroupMesh`` against the legacy global context.

    Builds a ``(1, 2, 2)`` mesh whose axes are indexed as (data, pipeline,
    tensor) and asserts that, for the calling process, the mesh and ``gpc``
    report the same world sizes, local ranks, group members, and previous/next
    global ranks along the tensor and pipeline axes.

    Raises:
        AssertionError: if any mesh query disagrees with ``gpc``.
    """
    # Imported locally: this is the only check in the file that uses the legacy API.
    from colossalai.legacy.context import ParallelMode
    from colossalai.legacy.core import global_context as gpc

    DP_DIM, PP_DIM, TP_DIM = 0, 1, 2
    pg_mesh = ProcessGroupMesh(1, 2, 2)

    def shifted(position, dim, step):
        # Same coordinate tuple with the entry at ``dim`` moved by ``step``.
        return position[:dim] + (position[dim] + step,) + position[dim + 1 :]

    # world size per parallel mode vs. mesh extent along the matching axis
    assert gpc.get_world_size(ParallelMode.TENSOR) == pg_mesh.size(
        TP_DIM
    ), f"{gpc.get_world_size(ParallelMode.TENSOR)} != {pg_mesh.size(TP_DIM)}"
    assert gpc.get_world_size(ParallelMode.PIPELINE) == pg_mesh.size(PP_DIM)
    assert gpc.get_world_size(ParallelMode.DATA) == pg_mesh.size(DP_DIM)

    # local rank per parallel mode vs. mesh coordinate along the matching axis
    assert gpc.get_local_rank(ParallelMode.TENSOR) == pg_mesh.coordinate(
        TP_DIM
    ), f"{gpc.get_local_rank(ParallelMode.TENSOR)} != {pg_mesh.coordinate(TP_DIM)}"
    assert gpc.get_local_rank(ParallelMode.PIPELINE) == pg_mesh.coordinate(PP_DIM)
    assert gpc.get_local_rank(ParallelMode.DATA) == pg_mesh.coordinate(DP_DIM)

    # members of the group along each axis
    for mode, dim in (
        (ParallelMode.TENSOR, TP_DIM),
        (ParallelMode.PIPELINE, PP_DIM),
        (ParallelMode.DATA, DP_DIM),
    ):
        axis_group = pg_mesh.get_group_along_axis(dim)
        assert gpc.get_ranks_in_group(mode) == pg_mesh.get_ranks_in_group(axis_group)

    # Neighbour checks only cover the tensor and pipeline axes.
    neighbour_axes = ((ParallelMode.TENSOR, TP_DIM), (ParallelMode.PIPELINE, PP_DIM))
    coord = pg_mesh.coordinate()

    # previous rank: one step back along the axis, skipped on the first rank
    for mode, dim in neighbour_axes:
        if gpc.is_first_rank(mode):
            continue
        assert coord[dim] != 0
        prev_coord = shifted(coord, dim, -1)
        assert gpc.get_prev_global_rank(mode) == pg_mesh.ravel(prev_coord, pg_mesh.shape)

    # next rank: one step forward along the axis, skipped on the last rank
    for mode, dim in neighbour_axes:
        if gpc.is_last_rank(mode):
            continue
        assert coord[dim] != pg_mesh.size(dim) - 1
        next_coord = shifted(coord, dim, 1)
        assert gpc.get_next_global_rank(mode) == pg_mesh.ravel(next_coord, pg_mesh.shape)
|
|
|
|
|
def check_process_group_mesh_with_cases():
    """Check a ``(1, 2, 2)`` ``ProcessGroupMesh`` against hand-written tables.

    The tables below list, for each of the four global ranks, the expected
    (data, pipeline, tensor) coordinate and the expected members of its
    tensor, pipeline and data groups. The calling process looks up its own
    row via ``dist.get_rank()`` and asserts that the mesh reports the same
    rank, sizes, coordinate, group members, and that ``ravel`` maps the
    neighbouring coordinates to the neighbouring ranks in those groups.

    Raises:
        AssertionError: if the mesh disagrees with any expected value.
    """
    DP_DIM, PP_DIM, TP_DIM = 0, 1, 2
    DP_SIZE, PP_SIZE, TP_SIZE = 1, 2, 2

    # Expected values, keyed by global rank.
    RANK_TO_COORDINATE = {
        0: (0, 0, 0),
        1: (0, 0, 1),
        2: (0, 1, 0),
        3: (0, 1, 1),
    }
    TP_RANKS_IN_GROUP = {0: [0, 1], 1: [0, 1], 2: [2, 3], 3: [2, 3]}
    PP_RANKS_IN_GROUP = {0: [0, 2], 1: [1, 3], 2: [0, 2], 3: [1, 3]}
    DP_RANKS_IN_GROUP = {0: [0], 1: [1], 2: [2], 3: [3]}

    def shifted(position, dim, step):
        # Same coordinate tuple with the entry at ``dim`` moved by ``step``.
        return position[:dim] + (position[dim] + step,) + position[dim + 1 :]

    pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE, TP_SIZE)

    rank = dist.get_rank()
    assert rank == pg_mesh.rank

    # extent of the mesh along each axis
    assert pg_mesh.size(TP_DIM) == TP_SIZE
    assert pg_mesh.size(PP_DIM) == PP_SIZE
    assert pg_mesh.size(DP_DIM) == DP_SIZE

    # coordinate of this rank along each axis
    expected_coord = RANK_TO_COORDINATE[rank]
    for dim in (TP_DIM, PP_DIM, DP_DIM):
        assert pg_mesh.coordinate(dim) == expected_coord[dim]

    # members of the group along each axis
    tp_peers = TP_RANKS_IN_GROUP[rank]
    pp_peers = PP_RANKS_IN_GROUP[rank]
    dp_peers = DP_RANKS_IN_GROUP[rank]
    for dim, peers in ((TP_DIM, tp_peers), (PP_DIM, pp_peers), (DP_DIM, dp_peers)):
        axis_group = pg_mesh.get_group_along_axis(dim)
        assert pg_mesh.get_ranks_in_group(axis_group) == peers

    # Neighbour checks only cover the tensor and pipeline axes.
    neighbour_axes = ((TP_DIM, TP_SIZE, tp_peers), (PP_DIM, PP_SIZE, pp_peers))

    # previous rank: skipped when this rank sits at index 0 of the axis
    for dim, _, peers in neighbour_axes:
        if expected_coord[dim] == 0:
            continue
        prev_coord = shifted(expected_coord, dim, -1)
        prev_rank = peers[peers.index(rank) - 1]
        assert pg_mesh.ravel(prev_coord, pg_mesh.shape) == prev_rank

    # next rank: skipped when this rank sits at the last index of the axis
    for dim, extent, peers in neighbour_axes:
        if expected_coord[dim] == extent - 1:
            continue
        next_coord = shifted(expected_coord, dim, 1)
        next_rank = peers[peers.index(rank) + 1]
        assert pg_mesh.ravel(next_coord, pg_mesh.shape) == next_rank
|
|
|
|
|
def run_dist(rank, world_size, port):
    """Launch ColossalAI for one process, then run the mesh checks.

    Args:
        rank: forwarded to ``colossalai.launch`` as ``rank``.
        world_size: forwarded to ``colossalai.launch`` as ``world_size``.
        port: forwarded to ``colossalai.launch`` as ``port``; the host is
            fixed to ``"localhost"``.
    """
    # Parallel layout: data=1, pipeline=2, 1d tensor parallelism of size 2.
    launch_config = {
        "parallel": {
            "data": 1,
            "pipeline": 2,
            "tensor": {"mode": "1d", "size": 2},
        }
    }
    colossalai.launch(
        config=launch_config,
        rank=rank,
        world_size=world_size,
        port=port,
        host="localhost",
    )
    # TODO(ver217): this function should be removed when gpc is removed
    # check_process_group_mesh_with_gpc()
    check_process_group_mesh_with_cases()
|
|
|
|
|
@pytest.mark.dist
def test_process_group_mesh():
    """Run ``run_dist`` through ``spawn`` to exercise the process group mesh checks."""
    # Presumably the number of processes to spawn; 4 equals the 1 * 2 * 2
    # layout that run_dist configures -- confirm against ``spawn``'s signature.
    num_procs = 4
    spawn(run_dist, num_procs)
|
|
|
|
|
# Allow running this test file directly as a script, without pytest.
if __name__ == "__main__":
    test_process_group_mesh()
|
|
|