From 33f0744d51db45d1df061a65c2629f34fd4b18dd Mon Sep 17 00:00:00 2001
From: YuliangLiu0306 <72588413+YuliangLiu0306@users.noreply.github.com>
Date: Wed, 10 Aug 2022 11:29:17 +0800
Subject: [PATCH] [tensor] add shape consistency feature to support auto spec transform (#1418)

* [tensor] add shape consistency feature to support auto sharding spec transform.

* [tensor] remove unused argument in simulator, add doc string for target pair.
---
 colossalai/tensor/shape_consistency.py      | 320 ++++++++++++++++++++
 colossalai/tensor/sharding_spec.py          |  23 +-
 tests/test_tensor/test_shape_consistency.py |  86 ++++++
 3 files changed, 424 insertions(+), 5 deletions(-)
 create mode 100644 colossalai/tensor/shape_consistency.py
 create mode 100644 tests/test_tensor/test_shape_consistency.py

diff --git a/colossalai/tensor/shape_consistency.py b/colossalai/tensor/shape_consistency.py
new file mode 100644
index 000000000..1f4c5f1f3
--- /dev/null
+++ b/colossalai/tensor/shape_consistency.py
@@ -0,0 +1,320 @@
+import torch
+from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
+from enum import Enum
+from copy import deepcopy
+
+
+class CollectiveCommPattern(Enum):
+    ALLGATHER = 'all_gather'
+    ALLTOALL = 'all_to_all'
+    SHARD = 'shard'
+
+
+class ShapeConsistencyManager:
+
+    def __init__(self, consistency_option=None):
+        self.consistency_option = consistency_option
+        self.total_communication_cost = 0
+        self.total_transform_steps = 0
+        self.cached_spec_pairs = {}
+
+    def _all_gather_simulator(self, target_pair):
+        '''
+        Simulate the all-gather operation: analyze the communication cost
+        and simulate the influence on the DimSpec.
+
+        We don't allow non-contiguous layouts, so e.g. all-gather(S012) -> S02 is NOT allowed.
+        Therefore, the all-gather operation just removes the last element of the shard list,
+        e.g.:
+            all-gather(S01) -> S0
+
+        Argument:
+            target_pair(Tuple[int, List[int]]): The first element is the dimension of the tensor to be sharded,
+            and the second element describes which logical axes will be sharded in that dimension.
+        '''
+        _, shard_list = target_pair
+        new_shard_list = shard_list[:-1]
+        # TODO: compute comm cost
+        comm_cost = 0
+        return new_shard_list, comm_cost
+
+    def _all_to_all_simulator(self, f_target_pair, b_target_pair):
+        '''
+        Simulate the all-to-all operation: analyze the communication cost
+        and simulate the influence on the DimSpecs.
+
+        We ban all representations whose shard_list is in decreasing order,
+        such as S10, so all-to-all(S0, S1) -> RS01 is NOT allowed.
+        Therefore, if the back shard_list is not empty, we extend the front shard_list with it,
+        e.g.:
+            all-to-all(S0, S1) -> [S01, R]
+            all-to-all(R, S1) -> [S1, R]
+        Otherwise, the front shard_list is moved to the back,
+        e.g.:
+            all-to-all(S0, R) -> [R, S0]
+
+        Argument:
+            f_target_pair(Tuple[int, List[int]]): The first element is the front dimension of the tensor to be sharded,
+            and the second element describes which logical axes will be sharded in that dimension.
+            b_target_pair(Tuple[int, List[int]]): The first element is the back dimension of the tensor to be sharded,
+            and the second element describes which logical axes will be sharded in that dimension.
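+
+        A minimal sketch of the shard-list bookkeeping performed below (illustrative only,
+        using plain Python lists in place of the DimSpecs):
+            f_shard_list, b_shard_list = [0], [1]    # i.e. (S0, S1)
+            f_shard_list.extend(b_shard_list)        # back list is non-empty, so the front absorbs it -> [0, 1], i.e. S01
+            b_shard_list = []                        # the back dimension becomes R
+            # corresponds to all-to-all(S0, S1) -> [S01, R]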
+        '''
+        _, f_shard_list = f_target_pair
+        _, b_shard_list = b_target_pair
+        if not len(b_shard_list):
+            b_shard_list.extend(f_shard_list)
+            f_shard_list = []
+        else:
+            f_shard_list.extend(b_shard_list)
+            b_shard_list = []
+        # TODO: compute comm cost
+        comm_cost = 0
+        return f_shard_list, b_shard_list, comm_cost
+
+    def _shard_simulator(self, target_pair, legal_sharding_dims):
+        '''
+        Simulate the shard operation: analyze the communication cost (always ZERO)
+        and simulate the influence on the DimSpec.
+
+        We don't allow non-contiguous layouts, so e.g. shard(S0) -> S02 is NOT allowed.
+        In addition, we ban all representations whose shard_list is in decreasing order,
+        such as S10, so shard(S0) -> S10 is NOT allowed.
+        Therefore, for an R dimension, we can append any legal sharding dim to it,
+        e.g.:
+            shard(R) -> S0
+        For an S dimension, we need to make sure the shard_list after sharding is still in ascending order,
+        e.g.:
+            shard(S0) -> S01
+
+        Argument:
+            target_pair(Tuple[int, List[int]]): The first element is the dimension of the tensor to be sharded,
+            and the second element describes which logical axes will be sharded in that dimension.
+        '''
+        _, shard_list = target_pair
+        shard_list_list = []
+        for dim in legal_sharding_dims:
+            if len(shard_list) != 0 and dim <= shard_list[-1]:
+                continue
+            new_shard_list = shard_list + [dim]
+            shard_list_list.append(new_shard_list)
+        comm_cost = 0
+        return shard_list_list, comm_cost
+
+    def get_all_all_gather_spec(self, source_spec, orig_cost):
+        '''
+        Get all valid sharding specs reachable from source_spec with a single all-gather operation, and
+        accumulate the communication cost on top of the original cost, which will finally be used in the auto sharding solver.
+        For the all-gather operation, we just care about the S dimensions.
+
+        Argument:
+            source_spec(ShardingSpec): the ShardingSpec of the source tensor.
+            orig_cost(float): the original communication cost before this operation.
+
+        Return:
+            valid_spec_dict(Dict[ShardingSpec, float]): all valid sharding specs reachable from source_spec with a single all-gather operation.
+
+        Example:
+            dim_partition_dict = {0: [0], 1: [1]}
+            # DistSpec:
+            #     shard_sequence: S0,S1,R
+            #     device_mesh_shape: (4, 4)
+            sharding_spec = ShardingSpec(device_mesh, entire_shape, dim_partition_dict)
+            shape_consistency_manager = ShapeConsistencyManager()
+            rst_dict = shape_consistency_manager.get_all_all_gather_spec(sharding_spec, 0)
+            print(rst_dict)
+
+        Output:
+            {DistSpec:
+                    shard_sequence: R,S1,R
+                    device_mesh_shape: (4, 4): 0, DistSpec:
+                    shard_sequence: S0,R,R
+                    device_mesh_shape: (4, 4): 0}
+        '''
+        valid_spec_dict = {}
+        for target_pair in source_spec.dim_partition_dict.items():
+            shard_list, cost = self._all_gather_simulator(target_pair)
+            index = target_pair[0]
+            new_dim_partition_dict = deepcopy(source_spec.dim_partition_dict)
+            new_dim_partition_dict[index] = shard_list
+            new_sharding_spec = ShardingSpec(source_spec.device_mesh,
+                                             source_spec.entire_shape,
+                                             dim_partition_dict=new_dim_partition_dict)
+            valid_spec_dict[new_sharding_spec] = orig_cost + cost
+        return valid_spec_dict
+
+    def get_all_all_to_all_spec(self, source_spec, orig_cost):
+        '''
+        Get all valid sharding specs reachable from source_spec with a single all-to-all operation, and
+        accumulate the communication cost on top of the original cost, which will finally be used in the auto sharding solver.
+        For the all-to-all operation, we just care about the dimension pairs that contain an S dimension.
+
+        Argument:
+            source_spec(ShardingSpec): the ShardingSpec of the source tensor.
+            orig_cost(float): the original communication cost before this operation.
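+
+        Note (an illustrative sketch of the dimension-pair enumeration used below,
+        assuming a 3-dimensional tensor):
+            tensor_dims = 3
+            pairs = [(f, b) for f in range(tensor_dims - 1) for b in range(f + 1, tensor_dims)]
+            # pairs == [(0, 1), (0, 2), (1, 2)]; pairs where both dimensions are R are skipped.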
+
+        Return:
+            valid_spec_dict(Dict[ShardingSpec, float]): all valid sharding specs reachable from source_spec with a single all-to-all operation.
+
+        Example:
+            dim_partition_dict = {0: [0], 1: [1]}
+            # DistSpec:
+            #     shard_sequence: S0,S1,R
+            #     device_mesh_shape: (4, 4)
+            sharding_spec = ShardingSpec(device_mesh, entire_shape, dim_partition_dict)
+            shape_consistency_manager = ShapeConsistencyManager()
+            rst_dict = shape_consistency_manager.get_all_all_to_all_spec(sharding_spec, 0)
+            print(rst_dict)
+
+        Output:
+            {DistSpec:
+                    shard_sequence: S01,R,R
+                    device_mesh_shape: (4, 4): 0, DistSpec:
+                    shard_sequence: R,S1,S0
+                    device_mesh_shape: (4, 4): 0, DistSpec:
+                    shard_sequence: S0,R,S1
+                    device_mesh_shape: (4, 4): 0}
+        '''
+        valid_spec_dict = {}
+        tensor_dims = len(source_spec.entire_shape)
+        for f_index in range(tensor_dims - 1):
+            for b_index in range(f_index + 1, tensor_dims):
+                # skip (R, R) cases
+                if f_index not in source_spec.dim_partition_dict and b_index not in source_spec.dim_partition_dict:
+                    continue
+                else:
+                    if f_index in source_spec.dim_partition_dict:
+                        f_target_pair = (f_index, deepcopy(source_spec.dim_partition_dict[f_index]))
+                    else:
+                        f_target_pair = (f_index, [])
+                    if b_index in source_spec.dim_partition_dict:
+                        b_target_pair = (b_index, deepcopy(source_spec.dim_partition_dict[b_index]))
+                    else:
+                        b_target_pair = (b_index, [])
+
+                f_shard_list, b_shard_list, cost = self._all_to_all_simulator(f_target_pair, b_target_pair)
+                f_index = f_target_pair[0]
+                b_index = b_target_pair[0]
+                new_dim_partition_dict = deepcopy(source_spec.dim_partition_dict)
+                new_dim_partition_dict[f_index] = f_shard_list
+                new_dim_partition_dict[b_index] = b_shard_list
+                new_sharding_spec = ShardingSpec(source_spec.device_mesh,
+                                                 source_spec.entire_shape,
+                                                 dim_partition_dict=new_dim_partition_dict)
+                valid_spec_dict[new_sharding_spec] = orig_cost + cost
+        return valid_spec_dict
+
+    def get_all_shard_spec(self, source_spec, orig_cost):
+        '''
+        Get all valid sharding specs reachable from source_spec with a single shard operation, and
+        accumulate the communication cost on top of the original cost, which will finally be used in the auto sharding solver.
+        For the shard operation, we just care about the legal sharding dimensions.
+
+        Argument:
+            source_spec(ShardingSpec): the ShardingSpec of the source tensor.
+            orig_cost(float): the original communication cost before this operation.
+
+        Return:
+            valid_spec_dict(Dict[ShardingSpec, float]): all valid sharding specs reachable from source_spec with a single shard operation.
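+
+        Note (an illustrative sketch of how the remaining legal sharding dims are derived below,
+        assuming a 2-axis device mesh and dim_partition_dict = {0: [0]}):
+            mesh_axes = [0, 1]
+            dim_partition_dict = {0: [0]}
+            used = [axis for shard_list in dim_partition_dict.values() for axis in shard_list]
+            legal_sharding_dims = [axis for axis in mesh_axes if axis not in used]
+            # legal_sharding_dims == [1]; only mesh axis 1 is still free to shard with.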
+
+        Example:
+            dim_partition_dict = {0: [0]}
+            # DistSpec:
+            #     shard_sequence: S0,R,R
+            #     device_mesh_shape: (4, 4)
+            sharding_spec = ShardingSpec(device_mesh, entire_shape, dim_partition_dict)
+            shape_consistency_manager = ShapeConsistencyManager()
+            rst_dict = shape_consistency_manager.get_all_shard_spec(sharding_spec, 0)
+            print(rst_dict)
+
+        Output:
+            {DistSpec:
+                    shard_sequence: S01,R,R
+                    device_mesh_shape: (4, 4): 0, DistSpec:
+                    shard_sequence: S0,S1,R
+                    device_mesh_shape: (4, 4): 0, DistSpec:
+                    shard_sequence: S0,R,S1
+                    device_mesh_shape: (4, 4): 0}
+        '''
+        valid_spec_dict = {}
+        legal_sharding_dims = [i for i in range(len(source_spec.device_mesh.mesh_shape))]
+        for dim, shard_list in source_spec.dim_partition_dict.items():
+            for element in shard_list:
+                legal_sharding_dims.remove(element)
+        if len(legal_sharding_dims) == 0:
+            return valid_spec_dict
+
+        tensor_dims = len(source_spec.entire_shape)
+        for index in range(tensor_dims):
+            if index not in source_spec.dim_partition_dict:
+                shard_list_list, cost = self._shard_simulator((index, []), legal_sharding_dims)
+            else:
+                shard_list_list, cost = self._shard_simulator((index, source_spec.dim_partition_dict[index]),
+                                                              legal_sharding_dims)
+            if not shard_list_list:
+                continue
+            for shard_list in shard_list_list:
+                new_dim_partition_dict = deepcopy(source_spec.dim_partition_dict)
+                new_dim_partition_dict[index] = shard_list
+                new_sharding_spec = ShardingSpec(source_spec.device_mesh,
+                                                 source_spec.entire_shape,
+                                                 dim_partition_dict=new_dim_partition_dict)
+                valid_spec_dict[new_sharding_spec] = orig_cost + cost
+        return valid_spec_dict
+
+    def get_all_one_step_transform_spec(self, source_spec, orig_cost):
+        '''
+        Get all valid sharding specs reachable from source_spec with one step of transform, and
+        accumulate the communication cost on top of the original cost, which will finally be used in the auto sharding solver.
+        Note:
+            all-gather will eliminate a sharding dimension, all-to-all will keep the number of sharding dimensions
+            the same as before, and shard will add a sharding dimension. Therefore, the results of the above
+            operations are mutually exclusive, so we can safely put them together.
+
+        Argument:
+            source_spec(ShardingSpec): the ShardingSpec of the source tensor.
+            orig_cost(float): the original communication cost before this operation.
+
+        Return:
+            valid_spec_dict(Dict[ShardingSpec, float]): all valid sharding specs reachable from source_spec with one step of transform.
+        '''
+        valid_spec_dict = {}
+        valid_spec_dict.update(self.get_all_all_gather_spec(source_spec, orig_cost))
+        valid_spec_dict.update(self.get_all_all_to_all_spec(source_spec, orig_cost))
+        valid_spec_dict.update(self.get_all_shard_spec(source_spec, orig_cost))
+        return valid_spec_dict
+
+    def shape_consistency(self, source_spec, target_spec):
+        '''
+        This method finds a path to transform source_spec into target_spec with
+        a greedy algorithm.
+        The basic idea is:
+        Step1:
+            Generate all one-step transform sharding specs from source_spec.
+        Step2:
+            Pick the 'best' sharding spec following the heuristic function.
+        Step3:
+            Repeat the above steps until the source spec transforms into the target spec.
+
+        This function is NOT completed yet, due to the absence of the difference function.
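+
+        One purely hypothetical shape for that heuristic (not part of this patch) would
+        count how many tensor dimensions already match the target spec:
+            def difference(sharding_spec, target_spec):
+                # higher is better: number of dimensions whose DimSpec already matches the target
+                return sum(a == b for a, b in zip(sharding_spec.sharding_sequence, target_spec.sharding_sequence))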
+        '''
+        MAX_TRANSFORM_STEPS = 10
+        total_cost = 0
+        total_steps = 0
+        transform_path = []
+        temp_sharding_spec = deepcopy(source_spec)
+        transform_path.append(temp_sharding_spec)
+        while total_steps <= MAX_TRANSFORM_STEPS:
+            valid_transform_spec_dict = self.get_all_one_step_transform_spec(temp_sharding_spec, 0)
+            best_difference_score = 0
+            # NOTE: no_difference/difference are placeholder heuristics, see the docstring above.
+            for sharding_spec, cost in valid_transform_spec_dict.items():
+                if no_difference(sharding_spec, target_spec):
+                    total_cost += cost
+                    transform_path.append(sharding_spec)
+                    return (transform_path, total_cost)
+                if difference(sharding_spec, target_spec) > best_difference_score:
+                    temp_sharding_spec = deepcopy(sharding_spec)
+                    temp_cost = cost
+            transform_path.append(temp_sharding_spec)
+            total_cost += temp_cost
+            total_steps += 1
+        return (transform_path, total_cost)
diff --git a/colossalai/tensor/sharding_spec.py b/colossalai/tensor/sharding_spec.py
index 7162a4ed0..e4f7f2490 100644
--- a/colossalai/tensor/sharding_spec.py
+++ b/colossalai/tensor/sharding_spec.py
@@ -13,7 +13,7 @@ class _DimSpec:
     '''
 
     def __init__(self, shard_list):
-        self.is_replica = shard_list is None
+        self.is_replica = len(shard_list) == 0
         self.shard_list = shard_list
 
     def __eq__(self, other):
@@ -52,12 +52,16 @@ class ShardingSpec:
         and the value of the key decribe which logical axis will be sharded in that dimension.
     '''
 
-    def __init__(self, device_mesh, entire_shape, dim_partition_dict):
+    def __init__(self, device_mesh, entire_shape, dim_partition_dict=None, sharding_sequence=None):
         self.device_mesh = device_mesh
         self.entire_shape = entire_shape
         self.dim_partition_dict = dim_partition_dict
+        self.sharding_sequence = sharding_sequence
+        if self.sharding_sequence is None:
+            self.convert_dict_to_shard_sequence()
+        elif self.dim_partition_dict is None:
+            self.convert_shard_sequence_to_dict()
         self._sanity_check()
-        self.sharding_sequence = self.convert_dict_to_shard_sequence()
 
     def __repr__(self):
         res_list = ["DistSpec:"]
@@ -80,10 +84,19 @@ class ShardingSpec:
                    f"find an invalid sharding axis {element} in dim_partition_dict in tensor dimension {dim}.")
 
     def convert_dict_to_shard_sequence(self):
-        sharding_sequence = [_DimSpec(None)] * len(self.entire_shape)
+        sharding_sequence = [_DimSpec([])] * len(self.entire_shape)
         for dim, shard_list in self.dim_partition_dict.items():
             sharding_sequence[dim] = _DimSpec(shard_list)
-        return sharding_sequence
+        self.sharding_sequence = sharding_sequence
+
+    def convert_shard_sequence_to_dict(self):
+        new_dim_partition_dict = {}
+        for index, dim_spec in enumerate(self.sharding_sequence):
+            if not dim_spec.is_replica:
+                if index not in new_dim_partition_dict:
+                    new_dim_partition_dict[index] = []
+                new_dim_partition_dict[index].extend(dim_spec.shard_list)
+        self.dim_partition_dict = new_dim_partition_dict
 
     def sharding_sequence_difference(self, other):
         '''
diff --git a/tests/test_tensor/test_shape_consistency.py b/tests/test_tensor/test_shape_consistency.py
new file mode 100644
index 000000000..634c5b85b
--- /dev/null
+++ b/tests/test_tensor/test_shape_consistency.py
@@ -0,0 +1,86 @@
+from colossalai.tensor.shape_consistency import ShapeConsistencyManager
+import torch
+from colossalai.tensor.sharding_spec import _DimSpec, ShardingSpec
+from colossalai.device.device_mesh import DeviceMesh
+
+
+def test_shape_consistency():
+    physical_mesh_id = torch.arange(0, 16).reshape(2, 8)
+    mesh_shape = (4, 4)
+    # [[0, 1, 2, 3],
+    #  [4, 5, 6, 7],
+    #  [8, 9, 10,11],
+    #  [12,13,14,15]]
+    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
+    entire_shape = torch.Size((4, 8, 6))
+
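+    # Usage sketch (an editorial illustration, not an assertion of the original test):
+    # with this patch, an equivalent of the S0,S1,R spec built below can also be
+    # constructed from a sharding_sequence of _DimSpec objects instead of a dim_partition_dict.
+    sharding_spec_from_sequence = ShardingSpec(device_mesh,
+                                               entire_shape,
+                                               sharding_sequence=[_DimSpec([0]), _DimSpec([1]), _DimSpec([])])
+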
+    dim_partition_dict = {0: [0], 1: [1]}
+    # DistSpec:
+    #     shard_sequence: S0,S1,R
+    #     device_mesh_shape: (4, 4)
+    sharding_spec = ShardingSpec(device_mesh, entire_shape, dim_partition_dict)
+    shape_consistency_manager = ShapeConsistencyManager()
+    # {DistSpec:
+    #         shard_sequence: R,S1,R
+    #         device_mesh_shape: (4, 4): 0, DistSpec:
+    #         shard_sequence: S0,R,R
+    #         device_mesh_shape: (4, 4): 0}
+    rst_dict = shape_consistency_manager.get_all_all_gather_spec(sharding_spec, 0)
+
+    assert '[R, S1, R]' in [
+        str(all_gather_sharding_spec.sharding_sequence) for all_gather_sharding_spec in rst_dict.keys()
+    ]
+    assert '[S0, R, R]' in [
+        str(all_gather_sharding_spec.sharding_sequence) for all_gather_sharding_spec in rst_dict.keys()
+    ]
+
+    dim_partition_dict_all2all = {0: [0], 1: [1]}
+    # DistSpec:
+    #     shard_sequence: S0,S1,R
+    #     device_mesh_shape: (4, 4)
+    sharding_spec_all2all = ShardingSpec(device_mesh, entire_shape, dim_partition_dict_all2all)
+    # {DistSpec:
+    #         shard_sequence: S01,R,R
+    #         device_mesh_shape: (4, 4): 0, DistSpec:
+    #         shard_sequence: R,S1,S0
+    #         device_mesh_shape: (4, 4): 0, DistSpec:
+    #         shard_sequence: S0,R,S1
+    #         device_mesh_shape: (4, 4): 0}
+    rst_dict_all2all = shape_consistency_manager.get_all_all_to_all_spec(sharding_spec_all2all, 0)
+
+    assert '[S01, R, R]' in [
+        str(all2all_sharding_spec.sharding_sequence) for all2all_sharding_spec in rst_dict_all2all.keys()
+    ]
+    assert '[R, S1, S0]' in [
+        str(all2all_sharding_spec.sharding_sequence) for all2all_sharding_spec in rst_dict_all2all.keys()
+    ]
+    assert '[S0, R, S1]' in [
+        str(all2all_sharding_spec.sharding_sequence) for all2all_sharding_spec in rst_dict_all2all.keys()
+    ]
+
+    dim_partition_shard = {0: [0]}
+    # DistSpec:
+    #     shard_sequence: S0,R,R
+    #     device_mesh_shape: (4, 4)
+    sharding_spec_shard = ShardingSpec(device_mesh, entire_shape, dim_partition_shard)
+    # {DistSpec:
+    #         shard_sequence: S01,R,R
+    #         device_mesh_shape: (4, 4): 0, DistSpec:
+    #         shard_sequence: S0,S1,R
+    #         device_mesh_shape: (4, 4): 0, DistSpec:
+    #         shard_sequence: S0,R,S1
+    #         device_mesh_shape: (4, 4): 0}
+    rst_dict_shard = shape_consistency_manager.get_all_shard_spec(sharding_spec_shard, 0)
+
+    assert '[S01, R, R]' in [
+        str(shard_sharding_spec.sharding_sequence) for shard_sharding_spec in rst_dict_shard.keys()
+    ]
+    assert '[S0, S1, R]' in [
+        str(shard_sharding_spec.sharding_sequence) for shard_sharding_spec in rst_dict_shard.keys()
+    ]
+    assert '[S0, R, S1]' in [
+        str(shard_sharding_spec.sharding_sequence) for shard_sharding_spec in rst_dict_shard.keys()
+    ]
+
+
+if __name__ == '__main__':
+    test_shape_consistency()