import pytest
import torch.distributed as dist

import colossalai
from colossalai.cluster import ProcessGroupMesh
from colossalai.testing import spawn


def check_process_group_mesh_with_gpc():
    """Cross-check a ``ProcessGroupMesh(1, 2, 2)`` against the legacy global context (gpc).

    Mesh axes are (data, pipeline, tensor). For every axis, the world size,
    local rank and ranks-in-group reported by gpc must agree with the mesh.
    For the tensor and pipeline axes, gpc's previous/next global rank must
    equal the mesh rank at the neighbouring coordinate along that axis.

    Requires ColossalAI to have been launched with a matching parallel config.
    """
    from colossalai.legacy.context import ParallelMode
    from colossalai.legacy.core import global_context as gpc

    DP_DIM, PP_DIM, TP_DIM = 0, 1, 2
    pg_mesh = ProcessGroupMesh(1, 2, 2)

    # (gpc parallel mode, mesh axis) pairs, checked in the order TP, PP, DP.
    all_axes = (
        (ParallelMode.TENSOR, TP_DIM),
        (ParallelMode.PIPELINE, PP_DIM),
        (ParallelMode.DATA, DP_DIM),
    )
    # Only the tensor and pipeline axes get prev/next neighbour checks.
    neighbour_axes = all_axes[:2]

    # check world size
    for mode, dim in all_axes:
        gpc_size = gpc.get_world_size(mode)
        mesh_size = pg_mesh.size(dim)
        assert gpc_size == mesh_size, f"{mode} world size: {gpc_size} != {mesh_size}"

    # check local rank (coordinate along each axis)
    for mode, dim in all_axes:
        gpc_rank = gpc.get_local_rank(mode)
        mesh_coord = pg_mesh.coordinate(dim)
        assert gpc_rank == mesh_coord, f"{mode} local rank: {gpc_rank} != {mesh_coord}"

    # check ranks in group
    # NOTE(review): the axis order is kept identical on every rank in case
    # get_group_along_axis creates groups collectively — confirm.
    for mode, dim in all_axes:
        group = pg_mesh.get_group_along_axis(dim)
        gpc_ranks = gpc.get_ranks_in_group(mode)
        mesh_ranks = pg_mesh.get_ranks_in_group(group)
        assert gpc_ranks == mesh_ranks, f"{mode} ranks in group: {gpc_ranks} != {mesh_ranks}"

    coord = pg_mesh.coordinate()

    def _neighbour_rank(dim, delta):
        # Mesh rank of the process whose coordinate differs from ours by `delta` along `dim`.
        shifted = coord[:dim] + (coord[dim] + delta,) + coord[dim + 1 :]
        return pg_mesh.ravel(shifted, pg_mesh.shape)

    # check prev rank: a rank gpc does not treat as first must have a predecessor on the mesh
    for mode, dim in neighbour_axes:
        if not gpc.is_first_rank(mode):
            assert coord[dim] != 0, f"{mode}: gpc says not first rank, but mesh coordinate is 0"
            gpc_prev = gpc.get_prev_global_rank(mode)
            mesh_prev = _neighbour_rank(dim, -1)
            assert gpc_prev == mesh_prev, f"{mode} prev rank: {gpc_prev} != {mesh_prev}"

    # check next rank: a rank gpc does not treat as last must have a successor on the mesh
    for mode, dim in neighbour_axes:
        if not gpc.is_last_rank(mode):
            last = pg_mesh.size(dim) - 1
            assert coord[dim] != last, f"{mode}: gpc says not last rank, but mesh coordinate is {last}"
            gpc_next = gpc.get_next_global_rank(mode)
            mesh_next = _neighbour_rank(dim, +1)
            assert gpc_next == mesh_next, f"{mode} next rank: {gpc_next} != {mesh_next}"


def check_process_group_mesh_with_cases():
    """Check a 1 x 2 x 2 ``ProcessGroupMesh`` on 4 ranks against hand-written expectations.

    Mesh axes are (data, pipeline, tensor) with sizes (DP_SIZE, PP_SIZE, TP_SIZE).
    The expected coordinate of every global rank, and the ranks sharing each
    axis group, are spelled out explicitly below. The test therefore does not
    rely on the mesh's own arithmetic. Neighbour (prev/next) ranks along the
    tensor and pipeline axes are verified by comparing ``pg_mesh.ravel`` of the
    shifted coordinate with the adjacent entry of the expected group.
    """
    DP_DIM, PP_DIM, TP_DIM = 0, 1, 2
    DP_SIZE, PP_SIZE, TP_SIZE = 1, 2, 2
    RANK_TO_COORDINATE = {
        0: (0, 0, 0),
        1: (0, 0, 1),
        2: (0, 1, 0),
        3: (0, 1, 1),
    }
    TP_RANKS_IN_GROUP = {
        0: [0, 1],
        1: [0, 1],
        2: [2, 3],
        3: [2, 3],
    }
    PP_RANKS_IN_GROUP = {
        0: [0, 2],
        1: [1, 3],
        2: [0, 2],
        3: [1, 3],
    }
    DP_RANKS_IN_GROUP = {
        0: [0],
        1: [1],
        2: [2],
        3: [3],
    }

    pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE, TP_SIZE)

    rank = dist.get_rank()
    assert rank == pg_mesh.rank, f"dist rank {rank} != mesh rank {pg_mesh.rank}"
    expected_coord = RANK_TO_COORDINATE[rank]

    # check world size against the declared sizes (not literals), so these
    # checks follow any change to the mesh shape above
    assert pg_mesh.size(TP_DIM) == TP_SIZE, f"rank {rank}: TP size {pg_mesh.size(TP_DIM)} != {TP_SIZE}"
    assert pg_mesh.size(PP_DIM) == PP_SIZE, f"rank {rank}: PP size {pg_mesh.size(PP_DIM)} != {PP_SIZE}"
    assert pg_mesh.size(DP_DIM) == DP_SIZE, f"rank {rank}: DP size {pg_mesh.size(DP_DIM)} != {DP_SIZE}"

    # check coordinate along each axis
    for dim in (TP_DIM, PP_DIM, DP_DIM):
        actual = pg_mesh.coordinate(dim)
        assert actual == expected_coord[dim], f"rank {rank}, axis {dim}: coordinate {actual} != {expected_coord[dim]}"

    # check ranks in group; the axis order (TP, PP, DP) is kept the same on every rank
    for dim, ranks_in_group in (
        (TP_DIM, TP_RANKS_IN_GROUP),
        (PP_DIM, PP_RANKS_IN_GROUP),
        (DP_DIM, DP_RANKS_IN_GROUP),
    ):
        group = pg_mesh.get_group_along_axis(dim)
        actual = pg_mesh.get_ranks_in_group(group)
        assert actual == ranks_in_group[rank], f"rank {rank}, axis {dim}: group {actual} != {ranks_in_group[rank]}"

    def _check_neighbour(dim, ranks_in_group, delta):
        # The rank `delta` steps away along `dim` must sit `delta` positions
        # away from us in the expected group for that axis.
        shifted = expected_coord[:dim] + (expected_coord[dim] + delta,) + expected_coord[dim + 1 :]
        group_ranks = ranks_in_group[rank]
        expected_rank = group_ranks[group_ranks.index(rank) + delta]
        actual_rank = pg_mesh.ravel(shifted, pg_mesh.shape)
        assert actual_rank == expected_rank, (
            f"rank {rank}, axis {dim}, delta {delta}: ravel({shifted}) = {actual_rank} != {expected_rank}"
        )

    # (axis, axis size, expected groups) for the axes with neighbour checks
    neighbour_axes = (
        (TP_DIM, TP_SIZE, TP_RANKS_IN_GROUP),
        (PP_DIM, PP_SIZE, PP_RANKS_IN_GROUP),
    )

    # check prev rank (skipped for ranks at the start of the axis)
    for dim, _, ranks_in_group in neighbour_axes:
        if expected_coord[dim] != 0:
            _check_neighbour(dim, ranks_in_group, -1)

    # check next rank (skipped for ranks at the end of the axis)
    for dim, size, ranks_in_group in neighbour_axes:
        if expected_coord[dim] != size - 1:
            _check_neighbour(dim, ranks_in_group, +1)


def run_dist(rank, world_size, port):
    """Per-rank worker: launch ColossalAI on localhost, then run the mesh checks.

    The parallel config (data=1, pipeline=2, 1D tensor of size 2) matches the
    1 x 2 x 2 mesh that the check functions build.
    """
    parallel_cfg = dict(data=1, pipeline=2, tensor=dict(mode="1d", size=2))
    colossalai.launch(
        config=dict(parallel=parallel_cfg),
        rank=rank,
        world_size=world_size,
        host="localhost",
        port=port,
    )
    # TODO(ver217): check_process_group_mesh_with_gpc should be deleted once
    # gpc is removed; until then it stays disabled.
    # check_process_group_mesh_with_gpc()
    check_process_group_mesh_with_cases()


@pytest.mark.dist
def test_process_group_mesh():
    """Run ``run_dist`` on 4 ranks via ``colossalai.testing.spawn``."""
    num_procs = 4
    spawn(run_dist, num_procs)


if __name__ == "__main__":
    # Allow running this check directly, outside of pytest.
    test_process_group_mesh()