ColossalAI/tests/test_device/test_init_logical_pg.py

50 lines
1.6 KiB
Python
Raw Normal View History

import torch
from functools import partial
import pytest
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.distributed import ReduceOp
from colossalai.core import global_context as gpc
from colossalai.initialize import launch
from colossalai.utils import free_port
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.device.device_mesh import DeviceMesh
def check_layer(rank, world_size, port):
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
physical_mesh_id = torch.arange(0, 4)
assert rank == gpc.get_global_rank()
tensor_to_check = torch.tensor([2, 2, 2, 2]).cuda()
mesh_shape = (2, 2)
# [[0, 1,
# [2, 3]]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
logical_pg_dict = {0: [[0, 2], [1, 3]], 1: [[0, 1], [2, 3]]}
logical_process_groups = device_mesh.process_groups_dict
for mesh_dim, pgs in logical_pg_dict.items():
for index, pg in enumerate(pgs):
if rank in pg:
tensor = torch.ones(4).cuda()
group = logical_process_groups[mesh_dim][index][1]
dist.all_reduce(tensor, op=ReduceOp.SUM, group=group)
assert tensor.equal(tensor_to_check)
gpc.destroy()
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_logical_pg():
world_size = 4
run_func = partial(check_layer, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_logical_pg()