From cd2b0eaa8dd4a7d8a67ce91b93459e07418bd741 Mon Sep 17 00:00:00 2001 From: YuliangLiu0306 <72588413+YuliangLiu0306@users.noreply.github.com> Date: Tue, 7 Mar 2023 11:08:11 +0800 Subject: [PATCH] [DTensor] refactor sharding spec (#2987) * [autoparallel] refactor sharding spec * rename function name --- colossalai/tensor/d_tensor/__init__.py | 0 colossalai/tensor/d_tensor/layout.py | 58 ++++- colossalai/tensor/d_tensor/misc.py | 14 ++ colossalai/tensor/d_tensor/sharding_spec.py | 237 ++++++++++++++++++ .../{ => test_dtensor}/test_dtensor.py | 5 +- .../test_dtensor/test_sharding_spec.py | 34 +++ 6 files changed, 341 insertions(+), 7 deletions(-) create mode 100644 colossalai/tensor/d_tensor/__init__.py create mode 100644 colossalai/tensor/d_tensor/misc.py create mode 100644 colossalai/tensor/d_tensor/sharding_spec.py rename tests/test_tensor/{ => test_dtensor}/test_dtensor.py (94%) create mode 100644 tests/test_tensor/test_dtensor/test_sharding_spec.py diff --git a/colossalai/tensor/d_tensor/__init__.py b/colossalai/tensor/d_tensor/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/colossalai/tensor/d_tensor/layout.py b/colossalai/tensor/d_tensor/layout.py index 9b72444aa..72a2694a1 100644 --- a/colossalai/tensor/d_tensor/layout.py +++ b/colossalai/tensor/d_tensor/layout.py @@ -1,12 +1,15 @@ +import operator from dataclasses import dataclass +from functools import reduce import torch from colossalai.device.device_mesh import DeviceMesh -from colossalai.tensor.sharding_spec import ShardingSpec + +from .misc import DuplicatedShardingDimensionError, LayoutException, ShardingNotDivisibleError +from .sharding_spec import ShardingSpec -@dataclass class Layout: """Layout of a tensor. @@ -16,7 +19,50 @@ class Layout: sharding_spec: the sharding specification to describe how the tensor is sharded. entire_shape: the entire shape of the global tensor. """ - device_mesh: DeviceMesh - device_type: torch.device - sharding_spec: ShardingSpec - entire_shape: torch.Size = None + + def __init__(self, device_mesh: DeviceMesh, device_type: torch.device, sharding_spec: ShardingSpec, + entire_shape: torch.Size): + self.device_mesh = device_mesh + self.device_type = device_type + self.sharding_spec = sharding_spec + self.entire_shape = entire_shape + self._sanity_check() + + def __hash__(self) -> int: + return hash(f'{self.sharding_spec}') + + def get_sharded_shape_per_device(self): + sharded_shape = list(self.entire_shape) + for dim, shard_list in self.sharding_spec.dim_partition_dict.items(): + mesh_list = [self.device_mesh.mesh_shape[mesh_dim] for mesh_dim in shard_list] + shard_partitions = reduce(operator.mul, mesh_list, 1) + assert sharded_shape[ + dim] % shard_partitions == 0, f'Cannot shard dimension {dim} into {shard_partitions} partitions.' + sharded_shape[dim] //= shard_partitions + return torch.Size(sharded_shape) + + def _sanity_check(self): + sharding_spec = self.sharding_spec + + # make sure all axes in logical device mesh only be used once + dim_check_list = list(range(self.device_mesh.logical_mesh_id.dim())) + for dim, shard_list in sharding_spec.dim_partition_dict.items(): + for element in shard_list: + if element in dim_check_list: + dim_check_list.remove(element) + else: + raise DuplicatedShardingDimensionError( + f"find an invalid sharding axis {element} in dim_partition_dict in tensor dimension {dim}.") + + # make sure that the sharding for a dimension is divisible by the number of devices + for dim, shard_list in sharding_spec.dim_partition_dict.items(): + tensor_dim_size = self.entire_shape[dim] + num_devices = 1 + + for element in shard_list: + num_devices *= self.device_mesh.mesh_shape[element] + + if tensor_dim_size % num_devices != 0: + raise ShardingNotDivisibleError( + f'The size of dimension at index {dim} is {tensor_dim_size}, it cannot be sharded over {num_devices} devices.' + ) diff --git a/colossalai/tensor/d_tensor/misc.py b/colossalai/tensor/d_tensor/misc.py new file mode 100644 index 000000000..3bb3f6f19 --- /dev/null +++ b/colossalai/tensor/d_tensor/misc.py @@ -0,0 +1,14 @@ +class LayoutException(Exception): + pass + + +class DuplicatedShardingDimensionError(LayoutException): + pass + + +class ShardingNotDivisibleError(LayoutException): + pass + + +class ShardingOutOfIndexError(LayoutException): + pass diff --git a/colossalai/tensor/d_tensor/sharding_spec.py b/colossalai/tensor/d_tensor/sharding_spec.py new file mode 100644 index 000000000..b135c46d6 --- /dev/null +++ b/colossalai/tensor/d_tensor/sharding_spec.py @@ -0,0 +1,237 @@ +from copy import deepcopy +from typing import Dict, List + +from ..utils import merge_same_dim_mesh_list +from .misc import ShardingOutOfIndexError + +__all__ = ['DimSpec', 'ShardingException', 'ShardingSpec'] + +ALLGATHER_COST = 20 +SHARD_COST = 5 +STEP_PENALTY = 6 +NAN = 'nan' + + +class DimSpec: + ''' + Sharding spec for single dimension of the sharded tensor decribe the sharding dimension of + logical device mesh and give a method to compute the difference between them. + This class is used internally in ShardingSpec. + + Argument: + shard_list(List[int]): if shard_list is None, the dim spec will be 'R' type. + Otherwise, the element in shard_list means the data will be sharded in that dimension. + ''' + + def __init__(self, shard_list): + self.is_replica = len(shard_list) == 0 + self.shard_list = shard_list + self.build_difference_2d_dict() + + def __eq__(self, other): + return str(self) == str(other) + + def __repr__(self): + if self.is_replica: + return 'R' + target = 'S' + for dim in self.shard_list: + target += str(dim) + return target + + def _convert_str_to_shard_list(self, str_spec): + ''' + Conver str_spec into shard_list. + + Argument: + str_spec(str): dim spec in str type. + ''' + + if str_spec == 'R': + return [] + if str_spec == 'S0': + return [0] + if str_spec == 'S1': + return [1] + if str_spec == 'S01': + return [0, 1] + + def build_difference_2d_dict(self): + ''' + Build a difference maping for 2D device mesh case. It will be used to + compute the difference between DimSpec pairs. + ''' + + source_spec_list = ['R', 'S0', 'S1', 'S01'] + target_spec_list = ['R', 'S0', 'S1', 'S01'] + difference_dict = {} + for source_spec in source_spec_list: + for target_spec in target_spec_list: + legal_sharding_dims = [] + spec_pair = (deepcopy(source_spec), deepcopy(target_spec)) + source_shard_list = self._convert_str_to_shard_list(source_spec) + target_shard_list = self._convert_str_to_shard_list(target_spec) + + # source same as target + if source_shard_list == target_shard_list: + difference = 0 + + # all_gather(source) -> target + elif len(source_shard_list + ) == len(target_shard_list) + 1 and source_shard_list[:-1] == target_shard_list: + difference = ALLGATHER_COST + + # shard(source) -> target + elif len(source_shard_list) == len( + target_shard_list) - 1 and source_shard_list == target_shard_list[:-1] and target_shard_list[ + -1] not in source_shard_list: + difference = SHARD_COST + + # S1 -> S0 or S0 -> S1 + elif len(source_shard_list) == len(target_shard_list): + # source -> R -> target + difference = ALLGATHER_COST + STEP_PENALTY + SHARD_COST + + # R -> S01 + elif len(source_shard_list) == len(target_shard_list) - 2: + difference = SHARD_COST + STEP_PENALTY + SHARD_COST + + # S01 -> R + elif len(source_shard_list) == len(target_shard_list) + 2: + difference = ALLGATHER_COST + STEP_PENALTY + ALLGATHER_COST + + # S1 -> S01 + elif len(source_shard_list) == len(target_shard_list) - 1: + difference = ALLGATHER_COST + STEP_PENALTY + SHARD_COST + STEP_PENALTY + SHARD_COST + + # S01 -> S1 + elif len(source_shard_list) == len(target_shard_list) + 1: + difference = ALLGATHER_COST + STEP_PENALTY + ALLGATHER_COST + STEP_PENALTY + SHARD_COST + + else: + difference = NAN + difference_dict[spec_pair] = difference + + self.difference_dict = difference_dict + + def dim_diff(self, other): + ''' + The difference between two _DimSpec. + + Argument: + other(_DimSpec): the dim spec to compare with. + + Return: + difference(int): the difference between two _DimSpec. + + Example: + dim_spec = _DimSpec([0]) + other_dim_spec = _DimSpec([0, 1]) + print(dim_spec.difference(other_dim_spec)) + + Output: + 5 + ''' + difference = self.difference_dict[(str(self), str(other))] + return difference + + +class ShardingSpec: + ''' + Sharding spec describes how to shard a tensor with dim_size dimensions. The sharding sequence looks like + [R, R, S0, S1], which means + + Argument: + dim_partition_dict(Dict[int, List[int]], optional): The key is the dimension of tensor to be sharded, + and the value of the key decribe which logical axis will be sharded in that dimension. + sharding_sequence(List[DimSpec], optional): A straight view of ShardingSpec looks like [R, R, S0, S1]. + ''' + + def __init__(self, + dim_size: int, + dim_partition_dict: Dict[int, List[int]] = None, + sharding_sequence: List[DimSpec] = None): + self.dims = dim_size + self.dim_partition_dict = dim_partition_dict + self.sharding_sequence = sharding_sequence + if self.sharding_sequence is None: + assert self.dim_partition_dict is not None, f'dim_partition_dict should not be None, if sharding_sequence is NoneType object.' + self.dim_partition_dict = merge_same_dim_mesh_list(dim_size=self.dims, + dim_partition_dict=self.dim_partition_dict) + self.sharding_sequence = self.convert_dict_to_shard_sequence() + + elif self.dim_partition_dict is None: + assert self.sharding_sequence is not None, f'sharding_sequence should not be None, if dim_partition_dict is NoneType object.' + self.dim_partition_dict = self.convert_shard_sequence_to_dict() + + self._sanity_check() + + def _sanity_check(self): + if len(self.sharding_sequence) > self.dims: + raise ShardingOutOfIndexError( + f'sharding_sequence should have {self.dims} elements, but got index {len(self.sharding_sequence)}.') + + if max(list(self.dim_partition_dict.keys())) >= self.dims: + raise ShardingOutOfIndexError( + f'the key of dim_partition_dict should be less than {self.dims}, but got {max(list(self.dim_partition_dict.keys()))}.' + ) + + def __repr__(self): + res_list = ["ShardingSpec:"] + res_list.append(f"\n\tshard_sequence: " + ",".join(str(dimspec) for dimspec in self.sharding_sequence)) + return ' '.join(res_list) + + def convert_dict_to_shard_sequence(self): + ''' + Convert dim_partition_dict into list of DimSpec, and assign it to sharding_sequence. + ''' + sharding_sequence = [DimSpec([])] * self.dims + for dim, shard_list in self.dim_partition_dict.items(): + sharding_sequence[dim] = DimSpec(shard_list) + return sharding_sequence + + def convert_shard_sequence_to_dict(self): + ''' + Convert sharding_sequence into dim_partition_dict. + ''' + new_dim_partition_dict = {} + for index, dim_spec in enumerate(self.sharding_sequence): + if not dim_spec.is_replica: + if index not in new_dim_partition_dict: + new_dim_partition_dict[index] = [] + new_dim_partition_dict[index].extend(dim_spec.shard_list) + return new_dim_partition_dict + + def spec_diff(self, other): + ''' + This function is a naive version of difference computation. It just simply accumulates difference every dimension between the + pair of sharding sequence. + + Example: + dim_partition_dict = {0: [0, 1]} + # DistSpec: + # shard_sequence: S01,R,R + # device_mesh_shape: (4, 4) + sharding_spec = ShardingSpec(device_mesh, entire_shape, dim_partition_dict) + dim_partition_dict_to_compare = {0: [0], 1: [1]} + # DistSpec: + # shard_sequence: S0,S1,R + # device_mesh_shape: (4, 4) + sharding_spec_to_compare = ShardingSpec(device_mesh, entire_shape, dim_partition_dict_to_compare) + print(sharding_spec.sharding_sequence_difference(sharding_spec_to_compare)) + + Output: + 25 + + Argument: + other(ShardingSpec): The ShardingSpec to compared with. + + Return: + difference(int): Difference between two ShardingSpec. + ''' + assert len(self.sharding_sequence) == len( + other.sharding_sequence), f'Cannot compare difference for two sharding specs with different length.' + difference = 0 + for orig_dim_spec, other_dim_spec in zip(self.sharding_sequence, other.sharding_sequence): + difference += orig_dim_spec.dim_diff(other_dim_spec) + return difference diff --git a/tests/test_tensor/test_dtensor.py b/tests/test_tensor/test_dtensor/test_dtensor.py similarity index 94% rename from tests/test_tensor/test_dtensor.py rename to tests/test_tensor/test_dtensor/test_dtensor.py index 1de9563a2..80e275d97 100644 --- a/tests/test_tensor/test_dtensor.py +++ b/tests/test_tensor/test_dtensor/test_dtensor.py @@ -37,7 +37,10 @@ def check_dtensor(rank, world_size, port): target_sharding_spec = ShardingSpec(device_mesh=device_mesh, entire_shape=original_tensor.shape, dim_partition_dict={0: [0]}) - layout = Layout(device_mesh=device_mesh, device_type=torch.device('cuda'), sharding_spec=target_sharding_spec) + layout = Layout(device_mesh=device_mesh, + device_type=torch.device('cuda'), + sharding_spec=target_sharding_spec, + entire_shape=original_tensor.shape) d_tensor = DTensor(original_tensor, layout) assert d_tensor.entire_shape == original_tensor.shape diff --git a/tests/test_tensor/test_dtensor/test_sharding_spec.py b/tests/test_tensor/test_dtensor/test_sharding_spec.py new file mode 100644 index 000000000..e02f71048 --- /dev/null +++ b/tests/test_tensor/test_dtensor/test_sharding_spec.py @@ -0,0 +1,34 @@ +import operator +from functools import reduce + +from colossalai.tensor.d_tensor.sharding_spec import ALLGATHER_COST, SHARD_COST, STEP_PENALTY, ShardingSpec + + +def test_sharding_spec(): + dims = 4 + dim_partition_dict_0 = {0: [0, 1]} + # DistSpec: + # shard_sequence: S01,R,R,R + sharding_spec_0 = ShardingSpec(dims, dim_partition_dict=dim_partition_dict_0) + assert str(sharding_spec_0.sharding_sequence) == "[S01, R, R, R]" + + dim_partition_dict_1 = {1: [0, 1]} + # DistSpec: + # shard_sequence: R,S01,R,R + sharding_spec_1 = ShardingSpec(dims, dim_partition_dict=dim_partition_dict_1) + assert str(sharding_spec_1.sharding_sequence) == "[R, S01, R, R]" + + dim_spec_list_0 = [dim_spec for dim_spec in sharding_spec_0.sharding_sequence] + dim_spec_list_1 = [dim_spec for dim_spec in sharding_spec_1.sharding_sequence] + + assert dim_spec_list_0[0].dim_diff(dim_spec_list_1[0]) == ALLGATHER_COST + STEP_PENALTY + ALLGATHER_COST + assert dim_spec_list_0[1].dim_diff(dim_spec_list_1[1]) == SHARD_COST + STEP_PENALTY + SHARD_COST + assert dim_spec_list_0[2].dim_diff(dim_spec_list_1[2]) == 0 + assert dim_spec_list_0[3].dim_diff(dim_spec_list_1[3]) == 0 + + assert sharding_spec_0.spec_diff(sharding_spec_1) == \ + reduce(operator.add, [dim_spec_list_0[i].dim_diff(dim_spec_list_1[i]) for i in range(dims)], 0) + + +if __name__ == '__main__': + test_sharding_spec()