You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
ColossalAI/tests/test_context/test_hybrid_parallel.py

165 lines
5.9 KiB

#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from functools import partial
from pathlib import Path
import pytest
import torch
import torch.multiprocessing as mp
from colossalai import launch
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.utils import free_port
from colossalai.context import reset_seeds
from colossalai.global_variables import tensor_parallel_env as tp_env
from colossalai.testing import rerun_if_address_is_in_use
# Config files exercised by this test; each one gets its own ``launch``.
# Sorted because ``Path.glob`` makes no ordering guarantee, and every spawned
# process may evaluate this list independently while pairing it positionally
# with the same ``port_list`` -- all ranks must agree on which config uses
# which port.
CONFIG_PATH_LIST = sorted(Path(__file__).parent.glob('configs/*.py'))
def check_data_parallel_rank(rank):
    """Assert the DATA-parallel world size and this process's DATA local rank.

    The expected number of data-parallel replicas is the GLOBAL world size
    divided by the MODEL world size. Global ranks are expected to come in
    consecutive blocks of MODEL-world-size ranks, and the ranks in block ``b``
    must report DATA local rank ``b``.

    Args:
        rank: global rank of the calling process.
    """
    world = gpc.get_world_size(ParallelMode.GLOBAL)
    model_world = gpc.get_world_size(ParallelMode.MODEL)
    replicas = world // model_world
    observed = gpc.get_local_rank(ParallelMode.DATA)
    assert gpc.get_world_size(ParallelMode.DATA) == replicas
    # Only ranks that fall inside one of the ``replicas`` blocks are checked.
    if 0 <= rank < replicas * model_world:
        assert observed == rank // model_world
def check_pipeline_parallel_rank(rank):
    """Assert the PIPELINE world size and this process's PIPELINE local rank.

    Each model-parallel group is a block of ``mp_world_size`` consecutive
    global ranks (consistent with the MODEL local rank being
    ``rank % mp_world_size``). Inside that block, each pipeline stage owns
    ``tp_world_size`` consecutive ranks, and the stage index must equal the
    PIPELINE local rank.

    Args:
        rank: global rank of the calling process.
    """
    mp_world_size = gpc.get_world_size(ParallelMode.MODEL)
    tp_world_size = gpc.get_world_size(ParallelMode.TENSOR)
    num_pipeline_stage = mp_world_size // tp_world_size
    pipeline_local_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
    # Mirror the DATA check: the pipeline group must span every stage.
    assert gpc.get_world_size(ParallelMode.PIPELINE) == num_pipeline_stage
    # Compare the rank *within* its model-parallel group. A global rank only
    # lies in [0, mp_world_size) for the first data-parallel replica, so using
    # it directly would leave every other replica unchecked.
    rank_in_mp_group = rank % mp_world_size
    for stage_idx in range(num_pipeline_stage):
        ranks_in_current_stage = range(stage_idx * tp_world_size, (stage_idx + 1) * tp_world_size)
        if rank_in_mp_group in ranks_in_current_stage:
            assert stage_idx == pipeline_local_rank
def check_model_parallel_rank(rank):
    """Assert that the MODEL local rank is ``rank`` modulo the MODEL world size.

    Args:
        rank: global rank of the calling process.
    """
    expected_local = rank % gpc.get_world_size(ParallelMode.MODEL)
    assert expected_local == gpc.get_local_rank(ParallelMode.MODEL)
def check_tensor_parallel_rank(rank):
    """Run the tensor-parallel rank check that matches ``tp_env.mode``.

    Only the '2d', '2.5d' and '3d' modes have a check; any other mode is a no-op.

    Args:
        rank: global rank of the calling process.
    """
    # Compare ``tp_env.mode`` in every branch. Comparing the ``tp_env`` object
    # itself to a string never matches, so that check would be silently skipped.
    mode = tp_env.mode
    if mode == '2d':
        check_2d_tensor_parallel_rank(rank)
    elif mode == '2.5d':
        check_2p5d_tensor_parallel_rank(rank)
    elif mode == '3d':
        check_3d_tensor_parallel_rank(rank)


def get_tp_info():
    """Return ``(tp_local_rank, tp_world_size, num_tp_groups)`` for this process.

    ``num_tp_groups`` is the GLOBAL world size divided by the TENSOR world size.
    """
    global_world_size = gpc.get_world_size(ParallelMode.GLOBAL)
    tp_world_size = gpc.get_world_size(ParallelMode.TENSOR)
    num_tp_groups = global_world_size // tp_world_size
    tp_local_rank = gpc.get_local_rank(ParallelMode.TENSOR)
    return tp_local_rank, tp_world_size, num_tp_groups


def _in_any_tp_group(rank, tp_world_size, num_tp_groups):
    """Return True if ``rank`` lies in one of the ``num_tp_groups`` blocks of ``tp_world_size`` consecutive ranks."""
    return any(rank in range(g * tp_world_size, (g + 1) * tp_world_size) for g in range(num_tp_groups))


def check_2d_tensor_parallel_rank(rank):
    """Check 2D column/row local ranks against the TENSOR local rank.

    With ``s = tp_env.summa_dim``, the TENSOR local rank decomposes as
    ``tp_local_rank = col * s + row``.

    Args:
        rank: global rank of the calling process.
    """
    tp_local_rank, tp_world_size, num_tp_groups = get_tp_info()
    if not _in_any_tp_group(rank, tp_world_size, num_tp_groups):
        return
    col_local_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)
    row_local_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
    assert col_local_rank == tp_local_rank // tp_env.summa_dim
    assert row_local_rank == tp_local_rank % tp_env.summa_dim


def check_2p5d_tensor_parallel_rank(rank):
    """Check 2.5D row/col/dep/xz local ranks against the TENSOR local rank.

    2.5D is parameterized by ``tp_env.tesseract_dim`` (``summa_dim`` is the
    2D parameter). With ``t = tesseract_dim``, the TENSOR local rank is
    expected to decompose as ``tp_local_rank = row + t * (col + t * dep)``,
    and the xz rank is ``col + t * dep``.
    NOTE(review): this layout mirrors ColossalAI's 2.5D process-group
    initializers -- confirm against the installed version.

    Args:
        rank: global rank of the calling process.
    """
    tp_local_rank, tp_world_size, num_tp_groups = get_tp_info()
    if not _in_any_tp_group(rank, tp_world_size, num_tp_groups):
        return
    tesseract_dim = tp_env.tesseract_dim
    rp_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
    cp_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
    dp_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)
    xp_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_XZ)
    assert rp_rank == tp_local_rank % tesseract_dim
    # The column index repeats for every depth slice, hence the modulo.
    assert cp_rank == (tp_local_rank // tesseract_dim) % tesseract_dim
    assert dp_rank == tp_local_rank // (tesseract_dim**2)
    assert xp_rank == tp_local_rank // tesseract_dim


def check_3d_tensor_parallel_rank(rank):
    """Check 3D input/weight/output local ranks against the TENSOR local rank.

    With ``d = tp_env.depth_3d``, the TENSOR local rank is expected to
    decompose as ``tp_local_rank = i + d * (j + d * k)``, where the WEIGHT
    local rank is ``i``, the INPUT local rank is ``j`` and the OUTPUT local
    rank is ``k``. Each of these lies in ``[0, d)``.
    NOTE(review): this layout mirrors ColossalAI's 3D process-group
    initializers -- confirm against the installed version.

    Args:
        rank: global rank of the calling process.
    """
    tp_local_rank, tp_world_size, num_tp_groups = get_tp_info()
    if not _in_any_tp_group(rank, tp_world_size, num_tp_groups):
        return
    depth = tp_env.depth_3d
    ip_rank = gpc.get_local_rank(ParallelMode.PARALLEL_3D_INPUT)
    wp_rank = gpc.get_local_rank(ParallelMode.PARALLEL_3D_WEIGHT)
    op_rank = gpc.get_local_rank(ParallelMode.PARALLEL_3D_OUTPUT)
    assert wp_rank == tp_local_rank % depth
    assert ip_rank == (tp_local_rank // depth) % depth
    assert op_rank == tp_local_rank // (depth**2)
def init_context(config_path, rank, world_size, backend, port, host):
    """Launch ColossalAI with one config, run every rank check, then tear down.

    Args:
        config_path: config passed to ``launch`` as ``config``.
        rank: global rank of this process.
        world_size: total number of processes.
        backend: distributed backend name (the test uses ``'gloo'``).
        port: rendezvous port for this launch.
        host: rendezvous host.
    """
    launch(config=config_path,
           rank=rank,
           world_size=world_size,
           backend=backend,
           port=port,
           host=host,
           verbose=True)
    checks = (
        check_tensor_parallel_rank,
        check_data_parallel_rank,
        check_pipeline_parallel_rank,
        check_model_parallel_rank,
    )
    for check in checks:
        check(rank)
    gpc.destroy()
    torch.cuda.empty_cache()
def run_dist(rank, world_size, backend, port_list, host):
    """Per-process entry point: launch and check every config in turn.

    Configs are paired with ports by position; extra entries on either side
    are ignored. ``reset_seeds()`` runs after each config, presumably to clear
    seed state registered by the previous ``launch`` -- verify.
    """
    shared = dict(rank=rank, world_size=world_size, backend=backend, host=host)
    for cfg, cfg_port in zip(CONFIG_PATH_LIST, port_list):
        init_context(config_path=cfg, port=cfg_port, **shared)
        reset_seeds()
@pytest.mark.cpu
@rerun_if_address_is_in_use()
def test_context():
    """Check the parallel context of every config with 32 gloo processes.

    As no computation or communication is done, we can run this test on CPU.
    """
    world_size = 32
    # Draw one distinct port per config so each launch has its own address.
    port_list = []
    while len(port_list) < len(CONFIG_PATH_LIST):
        candidate = free_port()
        if candidate not in port_list:
            port_list.append(candidate)
    test_fn = partial(run_dist, world_size=world_size, backend='gloo', port_list=port_list, host='localhost')
    mp.spawn(test_fn, nprocs=world_size)
# Allow running this module directly, outside pytest.
if __name__ == '__main__':
    test_context()