2022-08-02 11:23:48 +00:00
|
|
|
import torch
|
|
|
|
|
2023-06-08 02:18:17 +00:00
|
|
|
from colossalai.device.device_mesh import DeviceMesh
|
|
|
|
|
2022-08-02 11:23:48 +00:00
|
|
|
|
|
|
|
def test_device_mesh():
    """Check global/local rank bookkeeping of a 4x4 DeviceMesh over 16 devices.

    The physical device ids 0..15 are arranged as:

        [[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11],
         [12, 13, 14, 15]]
    """
    device_ids = torch.arange(16)
    mesh = DeviceMesh(device_ids, (4, 4))

    # (global rank, expected [row, col] position in the layout above)
    for rank, coord in ((5, [1, 1]), (11, [2, 3])):
        assert mesh.global_rank_to_local_rank(rank) == coord

    # The expected axis-1 group for rank 2 is row 0 of the layout above.
    assert mesh.get_ranks_in_process_group(axis=1, global_rank=2) == [0, 1, 2, 3]
|
2022-08-02 11:23:48 +00:00
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Allow running this check directly as a script, without pytest.
    test_device_mesh()
|