ColossalAI/tests/test_device/test_device_mesh.py

21 lines
571 B
Python
Raw Normal View History

import torch
2023-06-08 02:18:17 +00:00
from colossalai.device.device_mesh import DeviceMesh
def test_device_mesh() -> None:
    """Check rank lookups on a 4x4 ``DeviceMesh`` built from ranks 0..15.

    Verifies that ``global_rank_to_local_rank`` returns the expected
    per-axis coordinates for a global rank, and that
    ``get_ranks_in_process_group`` returns the expected ranks sharing a
    process group with a given rank along one mesh axis.
    """
    physical_mesh_id = torch.arange(0, 16)
    mesh_shape = (4, 4)
    # Logical layout of the 16 ranks on the mesh:
    # [[ 0,  1,  2,  3],
    #  [ 4,  5,  6,  7],
    #  [ 8,  9, 10, 11],
    #  [12, 13, 14, 15]]
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)

    # Rank 5 sits at row 1, column 1 of the layout above.
    assert device_mesh.global_rank_to_local_rank(5) == [1, 1]
    # Rank 11 sits at row 2, column 3.
    assert device_mesh.global_rank_to_local_rank(11) == [2, 3]
    # Rank 2 is in row 0, so its group along axis 1 is that whole row.
    assert device_mesh.get_ranks_in_process_group(axis=1, global_rank=2) == [0, 1, 2, 3]
if __name__ == '__main__':
    # Allow running this test module directly as a script, without pytest.
    test_device_mesh()