# ColossalAI/tests/test_context/test_2p5d_init.py
#
# Repository-viewer metadata captured with this file:
# 122 lines, 3.2 KiB, Python

#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from functools import partial
from pathlib import Path
import pytest
import torch
import torch.multiprocessing as mp
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.initialize import launch
# Absolute path of the config file passed to `launch`; it is resolved relative
# to the directory containing this test module.
_TEST_DIR = Path(__file__).parent
CONFIG_PATH = (_TEST_DIR / 'configs/parallel_2p5d_init.py').absolute()
def check_data_parallel_rank(rank):
    """Assert that this process has the expected data-parallel local rank.

    Global ranks 0-15 are expected to have data-parallel rank 0 and global
    ranks 16-31 data-parallel rank 1. Any other ``rank`` is not checked.

    Args:
        rank: global rank of the calling process.
    """
    local_rank = gpc.get_local_rank(ParallelMode.DATA)

    # Each entry pairs a span of global ranks with the local rank expected there.
    expectations = (
        (range(16), 0),
        (range(16, 32), 1),
    )
    for span, expected in expectations:
        if rank in span:
            assert local_rank == expected
            break
def check_pipeline_parallel_rank(rank):
    """Assert that this process has the expected pipeline-parallel local rank.

    The expected value alternates every 8 global ranks: 0-7 -> 0, 8-15 -> 1,
    16-23 -> 0, 24-31 -> 1. Any other ``rank`` is not checked.

    Args:
        rank: global rank of the calling process.
    """
    stage = gpc.get_local_rank(ParallelMode.PIPELINE)

    # Each entry pairs a span of global ranks with the pipeline rank expected there.
    expectations = (
        (range(8), 0),
        (range(8, 16), 1),
        (range(16, 24), 0),
        (range(24, 32), 1),
    )
    for span, expected in expectations:
        if rank in span:
            assert stage == expected
            break
def check_tensor_parallel_rank(rank):
    """Assert that this process has the expected tensor-parallel local rank.

    Global ranks ``i, i + 8, i + 16, i + 24`` (for ``i`` in 0-7) are expected
    to have tensor-parallel rank ``i``. On mismatch the assertion message is
    ``"<rank>:<tp_rank>"``.

    Args:
        rank: global rank of the calling process.
    """
    tp_rank = gpc.get_local_rank(ParallelMode.TENSOR)

    for expected in range(8):
        # Every 8th global rank, starting at `expected`, shares this local rank.
        if rank in range(expected, 32, 8):
            assert tp_rank == expected, f'{rank}:{tp_rank}'
def check_2p5d_parallel_rank(rank):
    """Assert the row / column / depth / XZ local ranks of the 2.5D groups.

    For each of the four 2.5D parallel modes, and for each expected local rank
    ``i`` in ``{0, 1}``, the set of global ranks (within 0-31) that should hold
    local rank ``i`` is enumerated; if ``rank`` is in that set, the local rank
    reported by ``gpc`` must equal ``i``. Global ranks that fall in none of the
    enumerated sets for a mode are not checked for that mode.

    Args:
        rank: global rank of the calling process.
    """
    row_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
    col_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
    dep_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)
    xz_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_XZ)

    def members(first, stride, width):
        # Runs of `width` consecutive global ranks, one run starting at each of
        # first, first + stride, first + 2 * stride, ... below 32.
        return [
            start + offset
            for start in range(first, 32, stride)
            for offset in range(width)
        ]

    # (actual local rank, multiplier applied to i for the first run, stride, width)
    layout = (
        (row_rank, 1, 2, 1),    # row:   i, i + 2, i + 4, ...
        (col_rank, 2, 4, 2),    # col:   pairs starting at 2i, 2i + 4, ...
        (dep_rank, 4, 8, 4),    # depth: runs of four starting at 4i, 4i + 8, ...
        (xz_rank, 2, 8, 2),     # xz:    pairs starting at 2i, 2i + 8, ...
    )
    for actual, first_scale, stride, width in layout:
        for expected in range(2):
            if rank in members(expected * first_scale, stride, width):
                assert actual == expected
def init_2halfd(rank, world_size, backend, port, host):
    """Per-process body of the 2.5D initialisation test.

    Launches the distributed context from ``CONFIG_PATH``, runs every rank
    check for this process, then destroys the global context and empties the
    CUDA cache.

    Args:
        rank: global rank of this process.
        world_size: total number of processes.
        backend: distributed backend name forwarded to ``launch``.
        port: rendezvous port forwarded to ``launch``.
        host: rendezvous host forwarded to ``launch``.
    """
    launch(
        config=CONFIG_PATH,
        rank=rank,
        world_size=world_size,
        backend=backend,
        port=port,
        host=host,
        verbose=True,
    )

    # Validate the local rank of every parallel group this process belongs to.
    check_data_parallel_rank(rank)
    check_pipeline_parallel_rank(rank)
    check_tensor_parallel_rank(rank)
    check_2p5d_parallel_rank(rank)

    # Tear down so nothing from this run lingers in the process.
    gpc.destroy()
    torch.cuda.empty_cache()
@pytest.mark.cpu
def test_2halfd_init():
    """
    Spawn 32 processes that each initialise the 2.5D context and check ranks.

    The test performs no computation or communication, so it can run on CPU.
    """
    num_procs = 32
    launch_kwargs = dict(
        world_size=num_procs,
        backend='gloo',
        port='29901',
        host='localhost',
    )
    # `rank` is left unbound here; it is supplied per process by `mp.spawn`.
    worker = partial(init_2halfd, **launch_kwargs)
    mp.spawn(worker, nprocs=num_procs)
if __name__ == '__main__':
    # Allow running the test directly as a script, without going through pytest.
    test_2halfd_init()