2023-03-07 03:08:11 +00:00
|
|
|
import operator
|
|
|
|
from functools import reduce
|
2023-03-01 08:34:58 +00:00
|
|
|
|
|
|
|
import torch
|
|
|
|
|
|
|
|
from colossalai.device.device_mesh import DeviceMesh
|
2023-03-07 03:08:11 +00:00
|
|
|
|
2023-06-22 03:42:11 +00:00
|
|
|
from .misc import DuplicatedShardingDimensionError, ShardingNotDivisibleError
|
2023-03-07 03:08:11 +00:00
|
|
|
from .sharding_spec import ShardingSpec
|
2023-03-01 08:34:58 +00:00
|
|
|
|
|
|
|
|
|
|
|
class Layout:
    """Layout of a tensor.

    A layout couples a device mesh, a sharding specification and the shape of
    the full (unsharded) tensor. The combination is validated on construction
    by ``_sanity_check``.

    Attributes:
        device_mesh: the device mesh over which the tensor is distributed.
        sharding_spec: the sharding specification describing how the tensor is sharded.
        global_shape: the entire shape of the global tensor.
    """

    def __init__(self, device_mesh: DeviceMesh, sharding_spec: ShardingSpec, global_shape: torch.Size):
        self.device_mesh = device_mesh
        self.sharding_spec = sharding_spec
        self.global_shape = global_shape
        # Fail fast on an inconsistent mesh / spec / shape combination.
        self._sanity_check()

    def __hash__(self) -> int:
        # NOTE: only the string form of the sharding spec participates in the hash.
        # __eq__ is not overridden in this class, so equality remains identity-based.
        return hash(f"{self.sharding_spec}")

    def _get_num_shards(self, shard_list):
        """Return how many partitions a tensor dimension is split into.

        Args:
            shard_list: the device-mesh axes one tensor dimension is sharded over
                (a value of ``sharding_spec.dim_partition_dict``).

        Returns:
            The product of ``device_mesh.shape`` over those axes; 1 for an empty list.
        """
        return reduce(operator.mul, (self.device_mesh.shape[mesh_dim] for mesh_dim in shard_list), 1)

    def get_sharded_shape_per_device(self) -> torch.Size:
        """Return the shape of the local shard held by each device.

        Every tensor dimension listed in ``sharding_spec.dim_partition_dict`` is
        divided by the number of partitions it is split into; all other
        dimensions keep their global size.

        Returns:
            torch.Size: the per-device shard shape.

        Raises:
            ShardingNotDivisibleError: if a sharded dimension is not evenly divisible
                by its partition count. ``_sanity_check`` already rejects this at
                construction, so it can only trigger if the attributes were changed
                after the layout was built.
        """
        sharded_shape = list(self.global_shape)
        for dim, shard_list in self.sharding_spec.dim_partition_dict.items():
            shard_partitions = self._get_num_shards(shard_list)
            # Explicit raise rather than `assert`: asserts are stripped under `python -O`.
            if sharded_shape[dim] % shard_partitions != 0:
                raise ShardingNotDivisibleError(f"Cannot shard dimension {dim} into {shard_partitions} partitions.")
            sharded_shape[dim] //= shard_partitions
        return torch.Size(sharded_shape)

    def _sanity_check(self):
        """Validate the sharding spec against the device mesh and the global shape.

        Raises:
            DuplicatedShardingDimensionError: if a device-mesh axis is used more than
                once across the sharding spec, or is not an axis of the logical mesh.
                This check only runs when ``device_mesh.logical_mesh_id`` is not None.
            ShardingNotDivisibleError: if a sharded tensor dimension is not evenly
                divisible by the number of devices it is sharded over.
        """
        sharding_spec = self.sharding_spec

        # make sure all axes in logical device mesh only be used once
        if self.device_mesh.logical_mesh_id is not None:
            # Axes still available for use; each one is consumed when first seen.
            # An axis outside the logical mesh is never available, so it is
            # rejected the same way as a duplicate.
            dim_check_list = list(range(self.device_mesh.logical_mesh_id.dim()))
            for dim, shard_list in sharding_spec.dim_partition_dict.items():
                for element in shard_list:
                    if element not in dim_check_list:
                        raise DuplicatedShardingDimensionError(
                            f"find an invalid sharding axis {element} in dim_partition_dict in tensor dimension {dim}."
                        )
                    dim_check_list.remove(element)

        # make sure that the sharding for a dimension is divisible by the number of devices
        for dim, shard_list in sharding_spec.dim_partition_dict.items():
            tensor_dim_size = self.global_shape[dim]
            num_devices = self._get_num_shards(shard_list)
            if tensor_dim_size % num_devices != 0:
                raise ShardingNotDivisibleError(
                    f"The size of dimension at index {dim} is {tensor_dim_size}, it cannot be sharded over {num_devices} devices."
                )
|