2022-08-08 03:15:57 +00:00
|
|
|
import torch
|
2022-10-19 04:53:06 +00:00
|
|
|
|
2022-08-08 03:15:57 +00:00
|
|
|
from colossalai.device.device_mesh import DeviceMesh
|
2022-10-19 04:53:06 +00:00
|
|
|
from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
|
2022-08-08 03:15:57 +00:00
|
|
|
|
|
|
|
|
|
|
|
def test_sharding_spec():
    """Check the sharding sequence of a tensor split over both mesh axes.

    A 16-device logical mesh of shape (4, 4) is built. A tensor of shape
    (16, 8, 6) is sharded along its dim 0 across mesh axes 0 and 1
    (``dim_partition_dict = {0: [0, 1]}``), with dims 1 and 2 replicated.
    The resulting sharding sequence must print as ``[S01, R, R]``.
    """
    # Pass the physical device ids as a flat tensor and let ``mesh_shape``
    # alone define the logical (4, 4) layout. The previous
    # ``.reshape(2, 8)`` contradicted both ``mesh_shape`` and the layout
    # comment below.
    # NOTE(review): newer DeviceMesh versions reject non-1-D ids; confirm
    # against the installed colossalai version.
    physical_mesh_id = torch.arange(0, 16)
    mesh_shape = (4, 4)
    # [[0, 1, 2, 3],
    #  [4, 5, 6, 7],
    #  [8, 9, 10,11],
    #  [12,13,14,15]]
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
    entire_shape = torch.Size((16, 8, 6))
    # Shard tensor dim 0 over mesh axes 0 and 1; all other dims replicated.
    dim_partition_dict = {0: [0, 1]}
    # DistSpec:
    #     shard_sequence: S01,R,R
    #     device_mesh_shape: (4, 4)
    sharding_spec = ShardingSpec(device_mesh, entire_shape, dim_partition_dict)
    assert str(sharding_spec.sharding_sequence) == "[S01, R, R]"
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: lets this test run directly (without pytest)
    # by invoking the single test function defined in this module.
    test_sharding_spec()
|