import torch

from colossalai.device.device_mesh import DeviceMesh


def test_device_mesh():
    """Exercise rank translation and process-group lookup on a 4x4 mesh.

    The 16 physical device ids ``0..15`` are arranged as:

        [[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11],
         [12, 13, 14, 15]]
    """
    device_ids = torch.arange(0, 16)
    shape = (4, 4)
    mesh = DeviceMesh(device_ids, shape)

    # (global rank, expected [row, column] position in the layout above),
    # checked in this order.
    local_rank_cases = (
        (5, [1, 1]),
        (11, [2, 3]),
    )
    for global_rank, expected_local in local_rank_cases:
        assert mesh.global_rank_to_local_rank(global_rank) == expected_local

    # Rank 2 lives in row 0, so its group along axis 1 is that entire row.
    row_group = mesh.get_ranks_in_process_group(axis=1, global_rank=2)
    assert row_group == [0, 1, 2, 3]


if __name__ == '__main__':
    test_device_mesh()