diff --git a/colossalai/tensor/shape_consistency.py b/colossalai/tensor/shape_consistency.py
index 1f4c5f1f3..cd7fa4d25 100644
--- a/colossalai/tensor/shape_consistency.py
+++ b/colossalai/tensor/shape_consistency.py
@@ -1,7 +1,11 @@
 import torch
 from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
+from colossalai.tensor.utils import all_gather_simulator, all_to_all_simulator, shard_simulator
 from enum import Enum
 from copy import deepcopy
+import math
+from functools import reduce
+import operator


 class CollectiveCommPattern(Enum):
@@ -10,96 +14,71 @@ class CollectiveCommPattern(Enum):
     SHARD = 'shard'


+class CommSpec:
+    '''
+    Communication spec is used to record the communication action. It has two main functions:
+    1. Compute the communication cost which will be used in the auto parallel solver.
+    2. Convert the communication spec into a real action at runtime.
+    It contains comm_pattern to determine the communication method, sharding_spec to determine
+    the communication size, gather_dim and shard_dim to determine the buffer shape, and
+    logical_process_axis to determine the device mesh dimension along which the communication happens.
+
+    Argument:
+        comm_pattern(CollectiveCommPattern): describes the communication method used in this spec.
+        sharding_spec(ShardingSpec): the sharding spec of the tensor which joins the communication action.
+        gather_dim(int, optional): the dimension of the tensor along which data will be gathered.
+        shard_dim(int, optional): the dimension of the tensor along which data will be sharded.
+        logical_process_axis(int, optional): the mesh_dim along which the communication action is performed.
+    '''
+
+    def __init__(self, comm_pattern, sharding_spec, gather_dim=None, shard_dim=None, logical_process_axis=None):
+        self.comm_pattern = comm_pattern
+        self.sharding_spec = sharding_spec
+        self.gather_dim = gather_dim
+        self.shard_dim = shard_dim
+        self.logical_process_axis = logical_process_axis
+
+    def __repr__(self):
+        res_list = ["CommSpec:("]
+        if self.comm_pattern == CollectiveCommPattern.ALLGATHER:
+            res_list.append(f"comm_pattern:allgather, ")
+            res_list.append(f"gather_dim:{self.gather_dim}, ")
+            res_list.append(f"logical_process_axis:{self.logical_process_axis})")
+        elif self.comm_pattern == CollectiveCommPattern.ALLTOALL:
+            res_list.append(f"comm_pattern:all2all, ")
+            res_list.append(f"gather_dim:{self.gather_dim}, ")
+            res_list.append(f"shard_dim:{self.shard_dim}, ")
+            res_list.append(f"logical_process_axis: {self.logical_process_axis})")
+        else:
+            res_list.append(f"comm_pattern:shard, ")
+            res_list.append(f"shard_dim:{self.shard_dim}, ")
+            res_list.append(f"logical_process_axis:{self.logical_process_axis})")
+        return ''.join(res_list)
+
+    def get_comm_cost(self):
+        '''
+        For all-gather and all-to-all operations, the alpha-beta model provided by DeviceMesh is used to
+        compute the communication cost.
+        The shard operation is an on-chip operation, so its communication cost is zero.
+ ''' + comm_size = reduce(operator.mul, self.sharding_spec.get_sharded_shape_per_device(), 1) + if self.comm_pattern == CollectiveCommPattern.ALLGATHER: + return self.sharding_spec.device_mesh.all_gather_cost(comm_size, self.logical_process_axis) + if self.comm_pattern == CollectiveCommPattern.ALLTOALL: + return self.sharding_spec.device_mesh.all_to_all_cost(comm_size, self.logical_process_axis) + return 0 + + def covert_spec_to_action(self): + pass + + class ShapeConsistencyManager: def __init__(self, consistency_option=None): self.consistency_option = consistency_option self.total_communication_cost = 0 self.total_transform_steps = 0 - self.cached_spec_pairs = {} - - def _all_gather_simulator(self, target_pair): - ''' - Simulating all-gather operation, analyze the communication cost - and simulate the influence of the DimSpec. - - We don't allow uncontiguous layout, such as all-gather(S012)->S02 is NOT allowed. - Therefore, all gather operation just remove the last element in shard list, - e.g.: - all-gather(S01) -> S0 - - Argument: - target_pair(Tuple[int, List[int]]): The first element is the dimension of tensor to be sharded, - and the second element decribes which logical axis will be sharded in that dimension. - ''' - _, shard_list = target_pair - new_shard_list = shard_list[:-1] - # TODO: compute comm cost - comm_cost = 0 - return new_shard_list, comm_cost - - def _all_to_all_simulator(self, f_target_pair, b_target_pair): - ''' - Simulating all-to-all operation, analyze the communication cost - and simulate the influence of the DimSpec. - - We BANNED all representations which shard_list in decreasing order, - such as S10, so all-to-all(S0, S1) -> RS01 is NOT allowed. - Therefore, if the behind shard_list is not None, we just extend it to the front shard_list. - Argument: - target_pair(Tuple[int, List[int]]): The first element is the dimension of tensor to be sharded, - and the second element decribes which logical axis will be sharded in that dimension. - e.g.: - all-to-all(S0, S1) -> [S01, R] - all-to-all(S0, R) -> [R, S0] - Otherwise, we extend the front shard_list to behind. - e.g.: - all-to-all(R, S1) -> [S1, R] - - Argument: - target_pair(Tuple[int, List[int]]): The first element is the dimension of tensor to be sharded, - and the second element decribes which logical axis will be sharded in that dimension. - ''' - _, f_shard_list = f_target_pair - _, b_shard_list = b_target_pair - if not len(b_shard_list): - b_shard_list.extend(f_shard_list) - f_shard_list = [] - else: - f_shard_list.extend(b_shard_list) - b_shard_list = [] - # TODO: compute comm cost - comm_cost = 0 - return f_shard_list, b_shard_list, comm_cost - - def _shard_simulator(self, target_pair, legal_sharding_dims): - ''' - Simulating shard operation, analyze the communication cost(always ZERO) - and simulate the influence of the DimSpec. - - We don't allow uncontiguous layout, such as shard(S0)->S02 is NOT allowed. - In addition, We BANNED all representations which shard_list in decreasing order, - such as S10, so shard(S0) -> S10 is NOT allowed. - Therefore, for the R dimension, we could just append any legal sharding dim on it. - e.g.: - shard(R) -> S0 - For the S dimension, we need to make sure the shard_list after sharding still keep rising order. - e.g: - shard(S0) -> S01 - - Argument: - target_pair(Tuple[int, List[int]]): The first element is the dimension of tensor to be sharded, - and the second element decribes which logical axis will be sharded in that dimension. 
- ''' - _, shard_list = target_pair - shard_list_list = [] - for dim in legal_sharding_dims: - if len(shard_list) != 0 and dim <= shard_list[-1]: - continue - new_shard_list = shard_list + [dim] - shard_list_list.append(new_shard_list) - comm_cost = 0 - return shard_list_list, comm_cost + self.cached_spec_pairs_transform_path = {} def get_all_all_gather_spec(self, source_spec, orig_cost): ''' @@ -132,15 +111,35 @@ class ShapeConsistencyManager: device_mesh_shape: (4, 4): 0} ''' valid_spec_dict = {} + comm_pattern = CollectiveCommPattern.ALLGATHER for target_pair in source_spec.dim_partition_dict.items(): - shard_list, cost = self._all_gather_simulator(target_pair) + shard_list = all_gather_simulator(target_pair) index = target_pair[0] new_dim_partition_dict = deepcopy(source_spec.dim_partition_dict) - new_dim_partition_dict[index] = shard_list + + # We won't add empty list into dim_partition_dict + # The key will be popped if the related shard_list is empty + if shard_list: + new_dim_partition_dict[index] = shard_list + else: + new_dim_partition_dict.pop(index) + + # generate the CommSpec to record the action of source_sharding_spec->new_sharding_spec + gather_dim = index + logical_process_axis = target_pair[1][-1] + comm_spec = CommSpec(comm_pattern, + sharding_spec=source_spec, + gather_dim=gather_dim, + logical_process_axis=logical_process_axis) + + # compute the communication cost with CommSpec + cost = comm_spec.get_comm_cost() + + # generate new sharding spec new_sharding_spec = ShardingSpec(source_spec.device_mesh, source_spec.entire_shape, dim_partition_dict=new_dim_partition_dict) - valid_spec_dict[new_sharding_spec] = orig_cost + cost + valid_spec_dict[new_sharding_spec] = (comm_spec, orig_cost + cost) return valid_spec_dict def get_all_all_to_all_spec(self, source_spec, orig_cost): @@ -176,6 +175,7 @@ class ShapeConsistencyManager: device_mesh_shape: (4, 4): 0} ''' valid_spec_dict = {} + comm_pattern = CollectiveCommPattern.ALLTOALL tensor_dims = len(source_spec.entire_shape) for f_index in range(tensor_dims - 1): for b_index in range(f_index + 1, tensor_dims): @@ -184,24 +184,62 @@ class ShapeConsistencyManager: continue else: if f_index in source_spec.dim_partition_dict: + # skip (S01, R) -> (R, S01) is NOT allowed + if len(source_spec.dim_partition_dict[f_index]) >= 2: + continue f_target_pair = (f_index, deepcopy(source_spec.dim_partition_dict[f_index])) else: f_target_pair = (f_index, []) if b_index in source_spec.dim_partition_dict: + # skip (R, R) -> (R, S01) is NOT allowed + if len(source_spec.dim_partition_dict[b_index]) >= 2: + continue b_target_pair = (b_index, deepcopy(source_spec.dim_partition_dict[b_index])) else: b_target_pair = (b_index, []) - f_shard_list, b_shard_list, cost = self._all_to_all_simulator(f_target_pair, b_target_pair) + # skip (S1, S0) -> S10 + if f_target_pair[1] and b_target_pair[1] and f_target_pair[1][0] >= b_target_pair[1][0]: + continue + f_shard_list, b_shard_list = all_to_all_simulator(f_target_pair, b_target_pair) f_index = f_target_pair[0] b_index = b_target_pair[0] + + # generate the CommSpec to record the action of source_sharding_spec->new_sharding_spec + if len(f_shard_list) < len(f_target_pair[1]): + gather_dim = f_index + shard_dim = b_index + logical_process_axis = f_target_pair[1][-1] + else: + gather_dim = b_index + shard_dim = f_index + logical_process_axis = b_target_pair[1][-1] + comm_spec = CommSpec(comm_pattern, + sharding_spec=source_spec, + gather_dim=gather_dim, + shard_dim=shard_dim, + 
logical_process_axis=logical_process_axis) + + # compute the communication cost with CommSpec + cost = comm_spec.get_comm_cost() new_dim_partition_dict = deepcopy(source_spec.dim_partition_dict) - new_dim_partition_dict[f_index] = f_shard_list - new_dim_partition_dict[b_index] = b_shard_list + + # We won't add empty list into dim_partition_dict + # The key will be popped if the related shard_list is empty + if f_shard_list: + new_dim_partition_dict[f_index] = f_shard_list + else: + new_dim_partition_dict.pop(f_index) + if b_shard_list: + new_dim_partition_dict[b_index] = b_shard_list + else: + new_dim_partition_dict.pop(b_index) + + # generate new sharding spec new_sharding_spec = ShardingSpec(source_spec.device_mesh, source_spec.entire_shape, dim_partition_dict=new_dim_partition_dict) - valid_spec_dict[new_sharding_spec] = orig_cost + cost + valid_spec_dict[new_sharding_spec] = (comm_spec, orig_cost + cost) return valid_spec_dict def get_all_shard_spec(self, source_spec, orig_cost): @@ -237,6 +275,9 @@ class ShapeConsistencyManager: device_mesh_shape: (4, 4): 0} ''' valid_spec_dict = {} + comm_pattern = CollectiveCommPattern.SHARD + + # legal sharding dims means the mesh_id is still available to use. legal_sharding_dims = [i for i in range(len(source_spec.device_mesh.mesh_shape))] for dim, shard_list in source_spec.dim_partition_dict.items(): for element in shard_list: @@ -247,19 +288,31 @@ class ShapeConsistencyManager: tensor_dims = len(source_spec.entire_shape) for index in range(tensor_dims): if index not in source_spec.dim_partition_dict: - shard_list_list, cost = self._shard_simulator((index, []), legal_sharding_dims) + shard_list_list = shard_simulator((index, []), legal_sharding_dims) else: - shard_list_list, cost = self._shard_simulator((index, source_spec.dim_partition_dict[index]), - legal_sharding_dims) + shard_list_list = shard_simulator((index, source_spec.dim_partition_dict[index]), legal_sharding_dims) if not shard_list_list: continue for shard_list in shard_list_list: new_dim_partition_dict = deepcopy(source_spec.dim_partition_dict) new_dim_partition_dict[index] = shard_list + + # generate the CommSpec to record the action of source_sharding_spec->new_sharding_spec + shard_dim = index + logical_process_axis = shard_list[-1] + comm_spec = CommSpec(comm_pattern, + sharding_spec=source_spec, + shard_dim=shard_dim, + logical_process_axis=logical_process_axis) + + # compute the communication cost with CommSpec + cost = comm_spec.get_comm_cost() + + # generate new sharding spec new_sharding_spec = ShardingSpec(source_spec.device_mesh, source_spec.entire_shape, dim_partition_dict=new_dim_partition_dict) - valid_spec_dict[new_sharding_spec] = orig_cost + cost + valid_spec_dict[new_sharding_spec] = (comm_spec, orig_cost + cost) return valid_spec_dict def get_all_one_step_transform_spec(self, source_spec, orig_cost): @@ -296,25 +349,93 @@ class ShapeConsistencyManager: Step3: Repeat above steps until the source spec transform to target spec. - This function is NOT completed, due to absense of difference function. + During finding the transform path, commucation cost will be accumulated, and it + will be finally used in auto parallel solver. + + Additionally, to avoid repeating the path search in runtime, we cached all solved path + in auto parallel strategy building time, which could handle most of cases in runtime. + + Argument: + source_spec(ShardingSpec): ShardingSpec of the source activation. + target_spec(ShardingSpec): ShardingSpec of the target activation. 
+ + Return: + transform_path(List[ShardingSpec]): The transform path from source_spec to target_spec, + it contains the source_spec and target_spec. + comm_action_sequence(List[CommSpec]): Keep the communication operations to complete the shape consistency in order. + total_cost(float): total cost to complete shape consistency transform. + + Example: + dim_partition_source = {1: [0, 1]} + dim_partition_target = {0: [0, 1]} + # DistSpec: + # shard_sequence: R,S01,R + # device_mesh_shape: (4, 4) + sharding_spec_source = ShardingSpec(device_mesh, entire_shape, dim_partition_source) + # DistSpec: + # shard_sequence: S01,R,R + # device_mesh_shape: (4, 4) + sharding_spec_target = ShardingSpec(device_mesh, entire_shape, dim_partition_target) + transform_path, comm_action_sequence, total_cost = shape_consistency_manager.shape_consistency(sharding_spec_source, sharding_spec_target) + print(f'transform_path: {transform_path}') + print(f'comm_action_sequence: {comm_action_sequence}') + print(f'total_cost: {total_cost}') + + output: + transform_path: [DistSpec: + shard_sequence: R,S01,R + device_mesh_shape: (4, 4), DistSpec: + shard_sequence: R,S0,R + device_mesh_shape: (4, 4), DistSpec: + shard_sequence: S0,R,R + device_mesh_shape: (4, 4), DistSpec: + shard_sequence: S01,R,R + device_mesh_shape: (4, 4)] + comm_action_sequence: [CommSpec:(comm_pattern:allgather, gather_dim:1, logical_process_axis:1), + CommSpec:(comm_pattern:all2all, gather_dim:1, shard_dim:0, logical_process_axis: 0), + CommSpec:(comm_pattern:shard, shard_dim:0, logical_process_axis:1)] + total_cost: 12294.402000000002 ''' - MAX_TRANSFORM_STEPS = 10 + MAX_TRANSFORM_STEPS = 20 total_cost = 0 total_steps = 0 transform_path = [] + comm_action_sequence = [] + spec_pairs = (str(source_spec.sharding_sequence), str(target_spec.sharding_sequence)) + self.cached_spec_pairs_transform_path[spec_pairs] = (None, None) + + # We do nothing if the sharding spec is all the same. 
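+        # For example, transforming [R, S01, R] into [R, S01, R] needs no communication at all,
+        # so the empty transform path and comm action sequence are cached and returned with zero cost.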
+ if source_spec.sharding_sequence_difference(target_spec) == 0: + self.cached_spec_pairs_transform_path[spec_pairs] = (transform_path, comm_action_sequence) + return (transform_path, comm_action_sequence, total_cost) + temp_sharding_spec = deepcopy(source_spec) transform_path.append(temp_sharding_spec) + # To avoid dead loop, the loop will break after MAX_TRANSFORM_STEPS transforms while total_steps <= MAX_TRANSFORM_STEPS: - valid_transform_spec_dict = get_all_one_step_transform_spec(temp_sharding_spec) - best_difference_score = 0 - for sharding_spec, cost in valid_transform_spec_dict.items(): - if no_difference(sharding_spec, target_spec): + valid_transform_spec_dict = self.get_all_one_step_transform_spec(temp_sharding_spec, total_cost) + best_difference_score = math.inf + + for sharding_spec, info_pairs in valid_transform_spec_dict.items(): + comm_spec, cost = info_pairs + spec_difference = sharding_spec.sharding_sequence_difference(target_spec) + + if spec_difference == 0: total_cost += cost transform_path.append(sharding_spec) - return (transform_path, total_cost) - if difference(sharding_spec, target_spec) > best_difference_score: + comm_action_sequence.append(comm_spec) + self.cached_spec_pairs_transform_path[spec_pairs] = (transform_path, comm_action_sequence) + return (transform_path, comm_action_sequence, total_cost) + + if spec_difference < best_difference_score: temp_sharding_spec = deepcopy(sharding_spec) temp_cost = cost + temp_comm_spec = deepcopy(comm_spec) + best_difference_score = spec_difference + transform_path.append(temp_sharding_spec) + comm_action_sequence.append(temp_comm_spec) total_cost += temp_cost - return (transform_path, total_cost) + total_steps += 1 + + raise RuntimeError(f"Could not find a valid transform path with in {MAX_TRANSFORM_STEPS} steps.") diff --git a/colossalai/tensor/sharding_spec.py b/colossalai/tensor/sharding_spec.py index 7fa68b05b..15643b5a2 100644 --- a/colossalai/tensor/sharding_spec.py +++ b/colossalai/tensor/sharding_spec.py @@ -1,4 +1,15 @@ +import torch from colossalai.device.device_mesh import DeviceMesh +from colossalai.tensor.utils import all_gather_simulator, all_to_all_simulator, shard_simulator +from copy import deepcopy +from enum import Enum +from functools import reduce +import operator + +ALLGATHER_COST = 20 +SHARD_COST = 5 +STEP_PENALTY = 6 +NAN = 'nan' class _DimSpec: @@ -15,6 +26,7 @@ class _DimSpec: def __init__(self, shard_list): self.is_replica = len(shard_list) == 0 self.shard_list = shard_list + self.build_difference_2d_dict() def __eq__(self, other): return str(self) == str(other) @@ -27,11 +39,101 @@ class _DimSpec: target += str(dim) return target + def _convert_str_to_shard_list(self, str_spec): + ''' + Conver str_spec into shard_list. + + Argument: + str_spec(str): dim spec in str type. + ''' + + if str_spec == 'R': + return [] + if str_spec == 'S0': + return [0] + if str_spec == 'S1': + return [1] + if str_spec == 'S01': + return [0, 1] + + def build_difference_2d_dict(self): + ''' + Build a difference maping for 2D device mesh case. It will be used to + compute the difference between DimSpec pairs. 
+ ''' + + source_spec_list = ['R', 'S0', 'S1', 'S01'] + target_spec_list = ['R', 'S0', 'S1', 'S01'] + difference_dict = {} + for source_spec in source_spec_list: + for target_spec in target_spec_list: + legal_sharding_dims = [] + spec_pair = (deepcopy(source_spec), deepcopy(target_spec)) + source_shard_list = self._convert_str_to_shard_list(source_spec) + target_shard_list = self._convert_str_to_shard_list(target_spec) + + # source same as target + if source_shard_list == target_shard_list: + difference = 0 + + # all_gather(source) -> target + elif len(source_shard_list + ) == len(target_shard_list) + 1 and source_shard_list[:-1] == target_shard_list: + difference = ALLGATHER_COST + + # shard(source) -> target + elif len(source_shard_list) == len( + target_shard_list) - 1 and source_shard_list == target_shard_list[:-1] and target_shard_list[ + -1] not in source_shard_list: + difference = SHARD_COST + + # S1 -> S0 or S0 -> S1 + elif len(source_shard_list) == len(target_shard_list): + # source -> R -> target + difference = ALLGATHER_COST + STEP_PENALTY + SHARD_COST + + # R -> S01 + elif len(source_shard_list) == len(target_shard_list) - 2: + difference = SHARD_COST + STEP_PENALTY + SHARD_COST + + # S01 -> R + elif len(source_shard_list) == len(target_shard_list) + 2: + difference = ALLGATHER_COST + STEP_PENALTY + ALLGATHER_COST + + # S1 -> S01 + elif len(source_shard_list) == len(target_shard_list) - 1: + difference = ALLGATHER_COST + STEP_PENALTY + SHARD_COST + STEP_PENALTY + SHARD_COST + + # S01 -> S1 + elif len(source_shard_list) == len(target_shard_list) + 1: + difference = ALLGATHER_COST + STEP_PENALTY + ALLGATHER_COST + STEP_PENALTY + SHARD_COST + + else: + difference = NAN + difference_dict[spec_pair] = difference + + self.difference_dict = difference_dict + def difference(self, other): ''' - This function is temporarily NOT implemented, it will be codesigned with ShapeConsistency feature. + The difference between two _DimSpec. + + Argument: + other(_DimSpec): the dim spec to compare with. + + Return: + difference(int): the difference between two _DimSpec. + + Example: + dim_spec = _DimSpec([0]) + other_dim_spec = _DimSpec([0, 1]) + print(dim_spec.difference(other_dim_spec)) + + Output: + 5 ''' - pass + difference = self.difference_dict[(str(self), str(other))] + return difference class ShardingSpec: @@ -43,8 +145,9 @@ class ShardingSpec: Argument: device_mesh(DeviceMesh): A logical view of a physical mesh. entire_shape(torch.Size): The entire shape of tensor before sharded. - dim_partition_dict(Dict[int, List[int]]): The key is the dimension of tensor to be sharded, + dim_partition_dict(Dict[int, List[int]], optional): The key is the dimension of tensor to be sharded, and the value of the key decribe which logical axis will be sharded in that dimension. + sharding_sequence(List[_DimSpec], optional): A straight view of ShardingSpec looks like [R, R, S0, S1]. ''' def __init__(self, device_mesh, entire_shape, dim_partition_dict=None, sharding_sequence=None): @@ -79,12 +182,18 @@ class ShardingSpec: f"find an invalid sharding axis {element} in dim_partition_dict in tensor dimension {dim}.") def convert_dict_to_shard_sequence(self): + ''' + Convert dim_partition_dict into list of _DimSpec, and assign it to sharding_sequence. 
+ ''' sharding_sequence = [_DimSpec([])] * len(self.entire_shape) for dim, shard_list in self.dim_partition_dict.items(): sharding_sequence[dim] = _DimSpec(shard_list) self.sharding_sequence = sharding_sequence def convert_shard_sequence_to_dict(self): + ''' + Convert sharding_sequence into dim_partition_dict. + ''' new_dim_partition_dict = {} for index, dim_spec in enumerate(self.sharding_sequence): if not dim_spec.is_replica: @@ -95,6 +204,45 @@ class ShardingSpec: def sharding_sequence_difference(self, other): ''' - This function is temporarily NOT implemented, it will be codesigned with ShapeConsistency feature. + This function is a naive version of difference computation. It just simply accumulates difference every dimension between the + pair of sharding sequence. + + Example: + dim_partition_dict = {0: [0, 1]} + # DistSpec: + # shard_sequence: S01,R,R + # device_mesh_shape: (4, 4) + sharding_spec = ShardingSpec(device_mesh, entire_shape, dim_partition_dict) + dim_partition_dict_to_compare = {0: [0], 1: [1]} + # DistSpec: + # shard_sequence: S0,S1,R + # device_mesh_shape: (4, 4) + sharding_spec_to_compare = ShardingSpec(device_mesh, entire_shape, dim_partition_dict_to_compare) + print(sharding_spec.sharding_sequence_difference(sharding_spec_to_compare)) + + Output: + 25 + + Argument: + other(ShardingSpec): The ShardingSpec to compared with. + + Return: + difference(int): Difference between two ShardingSpec. ''' - pass + assert len(self.sharding_sequence) == len( + other.sharding_sequence), f'Cannot compare difference for two sharding specs with different length.' + difference = 0 + for orig_dim_spec, other_dim_spec in zip(self.sharding_sequence, other.sharding_sequence): + difference += orig_dim_spec.difference(other_dim_spec) + return difference + + def get_sharded_shape_per_device(self): + + sharded_shape = list(self.entire_shape) + for dim, shard_list in self.dim_partition_dict.items(): + mesh_list = [self.device_mesh.mesh_shape[mesh_dim] for mesh_dim in shard_list] + shard_partitions = reduce(operator.mul, mesh_list, 1) + assert sharded_shape[ + dim] % shard_partitions == 0, f'Cannot shard dimension {dim} into {shard_partitions} partitions.' + sharded_shape[dim] //= shard_partitions + return torch.Size(sharded_shape) diff --git a/colossalai/tensor/utils.py b/colossalai/tensor/utils.py index d154ec7a7..b2eda5a8d 100644 --- a/colossalai/tensor/utils.py +++ b/colossalai/tensor/utils.py @@ -5,6 +5,90 @@ import torch.nn as nn from colossalai.tensor.colo_tensor import ColoTensor +def all_gather_simulator(target_pair): + ''' + Simulating all-gather operation, analyze the communication cost + and simulate the influence of the DimSpec. + + We don't allow uncontiguous layout, such as all-gather(S012)->S02 is NOT allowed. + Therefore, all gather operation just remove the last element in shard list, + e.g.: + all-gather(S01) -> S0 + + Argument: + target_pair(Tuple[int, List[int]]): The first element is the dimension of tensor to be sharded, + and the second element decribes which logical axis will be sharded in that dimension. + ''' + _, shard_list = target_pair + new_shard_list = shard_list[:-1] + + return new_shard_list + + +def all_to_all_simulator(f_target_pair, b_target_pair): + ''' + Simulating all-to-all operation, analyze the communication cost + and simulate the influence of the DimSpec. + + We BANNED all representations which shard_list in decreasing order, + such as S10, so all-to-all(S0, S1) -> RS01 is NOT allowed. 
+ Therefore, if the behind shard_list is not None, we just extend it to the front shard_list. + Argument: + target_pair(Tuple[int, List[int]]): The first element is the dimension of tensor to be sharded, + and the second element decribes which logical axis will be sharded in that dimension. + e.g.: + all-to-all(S0, S1) -> [S01, R] + all-to-all(S0, R) -> [R, S0] + Otherwise, we extend the front shard_list to behind. + e.g.: + all-to-all(R, S1) -> [S1, R] + + Argument: + target_pair(Tuple[int, List[int]]): The first element is the dimension of tensor to be sharded, + and the second element decribes which logical axis will be sharded in that dimension. + ''' + _, f_shard_list = f_target_pair + _, b_shard_list = b_target_pair + if not len(b_shard_list): + b_shard_list.extend(f_shard_list) + f_shard_list = [] + else: + f_shard_list.extend(b_shard_list) + b_shard_list = [] + + return f_shard_list, b_shard_list + + +def shard_simulator(target_pair, legal_sharding_dims): + ''' + Simulating shard operation, analyze the communication cost(always ZERO) + and simulate the influence of the DimSpec. + + We don't allow uncontiguous layout, such as shard(S0)->S02 is NOT allowed. + In addition, We BANNED all representations which shard_list in decreasing order, + such as S10, so shard(S0) -> S10 is NOT allowed. + Therefore, for the R dimension, we could just append any legal sharding dim on it. + e.g.: + shard(R) -> S0 + For the S dimension, we need to make sure the shard_list after sharding still keep rising order. + e.g: + shard(S0) -> S01 + + Argument: + target_pair(Tuple[int, List[int]]): The first element is the dimension of tensor to be sharded, + and the second element decribes which logical axis will be sharded in that dimension. + ''' + _, shard_list = target_pair + shard_list_list = [] + for dim in legal_sharding_dims: + if len(shard_list) != 0 and dim <= shard_list[-1]: + continue + new_shard_list = shard_list + [dim] + shard_list_list.append(new_shard_list) + + return shard_list_list + + # The function is credited to PyTorch Team def named_params_with_colotensor( module: nn.Module, diff --git a/tests/test_tensor/test_shape_consistency.py b/tests/test_tensor/test_shape_consistency.py index 634c5b85b..8b1578165 100644 --- a/tests/test_tensor/test_shape_consistency.py +++ b/tests/test_tensor/test_shape_consistency.py @@ -1,29 +1,32 @@ -from colossalai.tensor.shape_consistency import ShapeConsistencyManager +from colossalai.tensor.shape_consistency import ShapeConsistencyManager, CollectiveCommPattern import torch from colossalai.tensor.sharding_spec import _DimSpec, ShardingSpec from colossalai.device.device_mesh import DeviceMesh +physical_mesh_id = torch.arange(0, 16).reshape(2, 8) +mesh_shape = (4, 4) +# [[0, 1, 2, 3], +# [4, 5, 6, 7], +# [8, 9, 10,11], +# [12,13,14,15]] +device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) +entire_shape = torch.Size((64, 32, 16)) +shape_consistency_manager = ShapeConsistencyManager() + + +def test_one_step_transform(): -def test_shape_consistency(): - physical_mesh_id = torch.arange(0, 16).reshape(2, 8) - mesh_shape = (4, 4) - # [[0, 1, 2, 3], - # [4, 5, 6, 7], - # [8, 9, 10,11], - # [12,13,14,15]] - device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) - entire_shape = torch.Size((4, 8, 6)) dim_partition_dict = {0: [0], 1: [1]} # DistSpec: # shard_sequence: S0,S1,R # device_mesh_shape: (4, 4) sharding_spec = ShardingSpec(device_mesh, entire_shape, dim_partition_dict) - shape_consistency_manager = ShapeConsistencyManager() + # {DistSpec: # shard_sequence: 
R,S1,R - # device_mesh_shape: (4, 4): 0, DistSpec: + # device_mesh_shape: (4, 4): (CommSpec:(comm_pattern:allgather, gather_dim:0, logical_process_axis:0), 0), DistSpec: # shard_sequence: S0,R,R - # device_mesh_shape: (4, 4): 0} + # device_mesh_shape: (4, 4): (CommSpec:(comm_pattern:allgather, gather_dim:1, logical_process_axis:1), 0)} rst_dict = shape_consistency_manager.get_all_all_gather_spec(sharding_spec, 0) assert '[R, S1, R]' in [ @@ -39,12 +42,12 @@ def test_shape_consistency(): # device_mesh_shape: (4, 4) sharding_spec_all2all = ShardingSpec(device_mesh, entire_shape, dim_partition_dict_all2all) # {DistSpec: - # shard_sequence: S01,R,R - # device_mesh_shape: (4, 4): 0, DistSpec: - # shard_sequence: R,S1,S0 - # device_mesh_shape: (4, 4): 0, DistSpec: - # shard_sequence: S0,R,S1 - # device_mesh_shape: (4, 4): 0} + # shard_sequence: S01,R,R + # device_mesh_shape: (4, 4): (CommSpec:(comm_pattern:all2all, gather_dim:1, shard_dim:0, logical_process_axis: 1), 0), DistSpec: + # shard_sequence: R,S1,S0 + # device_mesh_shape: (4, 4): (CommSpec:(comm_pattern:all2all, gather_dim:0, shard_dim:2, logical_process_axis: 0), 0), DistSpec: + # shard_sequence: S0,R,S1 + # device_mesh_shape: (4, 4): (CommSpec:(comm_pattern:all2all, gather_dim:1, shard_dim:2, logical_process_axis: 1), 0)} rst_dict_all2all = shape_consistency_manager.get_all_all_to_all_spec(sharding_spec_all2all, 0) assert '[S01, R, R]' in [ @@ -63,12 +66,12 @@ def test_shape_consistency(): # device_mesh_shape: (4, 4) sharding_spec_shard = ShardingSpec(device_mesh, entire_shape, dim_partition_shard) # {DistSpec: - # shard_sequence: S01,R,R - # device_mesh_shape: (4, 4): 0, DistSpec: - # shard_sequence: S0,S1,R - # device_mesh_shape: (4, 4): 0, DistSpec: - # shard_sequence: S0,R,S1 - # device_mesh_shape: (4, 4): 0} + # shard_sequence: S01,R,R + # device_mesh_shape: (4, 4): (CommSpec:(comm_pattern:shard, shard_dim:0, logical_process_axis:1), 0), DistSpec: + # shard_sequence: S0,S1,R + # device_mesh_shape: (4, 4): (CommSpec:(comm_pattern:shard, shard_dim:1, logical_process_axis:1), 0), DistSpec: + # shard_sequence: S0,R,S1 + # device_mesh_shape: (4, 4): (CommSpec:(comm_pattern:shard, shard_dim:2, logical_process_axis:1), 0)} rst_dict_shard = shape_consistency_manager.get_all_shard_spec(sharding_spec_shard, 0) assert '[S01, R, R]' in [ @@ -82,5 +85,48 @@ def test_shape_consistency(): ] +def test_shape_consistency(): + dim_partition_source = {1: [0, 1]} + dim_partition_target = {0: [0, 1]} + + # DistSpec: + # shard_sequence: R,S01,R + # device_mesh_shape: (4, 4) + sharding_spec_source = ShardingSpec(device_mesh, entire_shape, dim_partition_source) + + # DistSpec: + # shard_sequence: S01,R,R + # device_mesh_shape: (4, 4) + sharding_spec_target = ShardingSpec(device_mesh, entire_shape, dim_partition_target) + + transform_path, comm_action_sequence, total_cost = shape_consistency_manager.shape_consistency( + sharding_spec_source, sharding_spec_target) + + transform_path_str = '->'.join([str(sharding_spec.sharding_sequence) for sharding_spec in transform_path]) + assert transform_path_str == '[R, S01, R]->[R, S0, R]->[S0, R, R]->[S01, R, R]' + + # all-gather(S01) -> S0 + assert comm_action_sequence[0].comm_pattern == CollectiveCommPattern.ALLGATHER + assert comm_action_sequence[0].gather_dim == 1 + assert comm_action_sequence[0].logical_process_axis == 1 + + # all-to-all(R, S0) -> [S0, R] + assert comm_action_sequence[1].comm_pattern == CollectiveCommPattern.ALLTOALL + assert comm_action_sequence[1].gather_dim == 1 + assert 
comm_action_sequence[1].shard_dim == 0 + assert comm_action_sequence[1].logical_process_axis == 0 + + # shard(S0) -> [S01] + assert comm_action_sequence[2].comm_pattern == CollectiveCommPattern.SHARD + assert comm_action_sequence[2].shard_dim == 0 + assert comm_action_sequence[2].logical_process_axis == 1 + + assert shape_consistency_manager.cached_spec_pairs_transform_path[('[R, S01, R]', + '[S01, R, R]')][0] == transform_path + assert shape_consistency_manager.cached_spec_pairs_transform_path[('[R, S01, R]', + '[S01, R, R]')][1] == comm_action_sequence + + if __name__ == '__main__': + test_one_step_transform() test_shape_consistency()
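
As a quick illustration of the simulator helpers now exposed from colossalai/tensor/utils.py, the calls below (with this patch applied) reproduce the behaviours described in their docstrings:

    from colossalai.tensor.utils import all_gather_simulator, all_to_all_simulator, shard_simulator

    # all-gather(S01) -> S0: the last mesh dim is dropped from the shard list
    print(all_gather_simulator((1, [0, 1])))            # [0]

    # all-to-all(S0, S1) -> (S01, R): the back shard list is appended to the front one
    print(all_to_all_simulator((0, [0]), (1, [1])))     # ([0, 1], [])

    # all-to-all(S0, R) -> (R, S0): the front shard list moves to the back
    print(all_to_all_simulator((0, [0]), (1, [])))      # ([], [0])

    # shard(S0) -> S01: only mesh dims larger than the current last one may be appended
    print(shard_simulator((0, [0]), [0, 1]))            # [[0, 1]]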
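
The per-dimension difference used by ShardingSpec.sharding_sequence_difference comes from the lookup table built in _DimSpec.build_difference_2d_dict. The standalone sketch below (plain Python, not the library API) mirrors that rule table and reproduces the value 25 from the docstring example:

    ALLGATHER_COST, SHARD_COST, STEP_PENALTY = 20, 5, 6

    def dim_difference(source, target):
        # source/target are shard lists: [] is R, [0] is S0, [1] is S1, [0, 1] is S01
        if source == target:
            return 0
        elif len(source) == len(target) + 1 and source[:-1] == target:
            # all-gather(source) -> target
            return ALLGATHER_COST
        elif len(source) == len(target) - 1 and source == target[:-1] and target[-1] not in source:
            # shard(source) -> target
            return SHARD_COST
        elif len(source) == len(target):
            # e.g. S0 <-> S1, which goes through R
            return ALLGATHER_COST + STEP_PENALTY + SHARD_COST
        elif len(source) == len(target) - 2:
            # R -> S01
            return SHARD_COST + STEP_PENALTY + SHARD_COST
        elif len(source) == len(target) + 2:
            # S01 -> R
            return ALLGATHER_COST + STEP_PENALTY + ALLGATHER_COST
        elif len(source) == len(target) - 1:
            # e.g. S1 -> S01
            return ALLGATHER_COST + STEP_PENALTY + SHARD_COST + STEP_PENALTY + SHARD_COST
        elif len(source) == len(target) + 1:
            # e.g. S01 -> S1
            return ALLGATHER_COST + STEP_PENALTY + ALLGATHER_COST + STEP_PENALTY + SHARD_COST
        return float('nan')

    # [S01, R, R] vs [S0, S1, R]: all-gather (20) + shard (5) + identical (0) = 25
    print(sum(dim_difference(s, t) for s, t in zip([[0, 1], [], []], [[0], [1], []])))   # 25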
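
CommSpec.get_comm_cost sizes each communication by the sharded (per-device) shape, which ShardingSpec.get_sharded_shape_per_device obtains by dividing every sharded tensor dimension by the product of the mesh dimensions it is sharded over. A minimal sketch with the shapes used in the tests (the helper name here is illustrative, not part of the patch):

    from functools import reduce
    import operator

    def sharded_shape_per_device(entire_shape, dim_partition_dict, mesh_shape):
        sharded_shape = list(entire_shape)
        for dim, shard_list in dim_partition_dict.items():
            # number of partitions along this tensor dimension
            shard_partitions = reduce(operator.mul, [mesh_shape[mesh_dim] for mesh_dim in shard_list], 1)
            assert sharded_shape[dim] % shard_partitions == 0
            sharded_shape[dim] //= shard_partitions
        return tuple(sharded_shape)

    # [R, S01, R] on a (4, 4) mesh: dim 1 is split across 4 * 4 = 16 devices
    print(sharded_shape_per_device((64, 32, 16), {1: [0, 1]}, (4, 4)))   # (64, 2, 16)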
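
The shape_consistency search itself is a greedy loop: enumerate every one-step transform, return as soon as a candidate matches the target, otherwise keep the candidate with the smallest sharding_sequence_difference and accumulate its cost. The self-contained sketch below captures that control flow with toy stand-ins for ShardingSpec and CommSpec:

    import math

    def greedy_transform(source, target, one_step_transforms, difference, max_steps=20):
        # one_step_transforms(spec) yields (candidate, step_cost) pairs;
        # difference(a, b) == 0 means the two specs are identical
        path, total_cost, current = [source], 0, source
        for _ in range(max_steps):
            best, best_score, best_cost = None, math.inf, 0
            for candidate, cost in one_step_transforms(current):
                score = difference(candidate, target)
                if score == 0:
                    path.append(candidate)
                    return path, total_cost + cost
                if score < best_score:
                    best, best_score, best_cost = candidate, score, cost
            if best is None:
                break
            path.append(best)
            total_cost += best_cost
            current = best
        raise RuntimeError(f"no transform path found within {max_steps} steps")

    # toy usage: "specs" are integers, a one-step transform adds or subtracts 1 at unit cost
    steps = lambda x: [(x + 1, 1), (x - 1, 1)]
    print(greedy_transform(3, 7, steps, lambda a, b: abs(a - b)))   # ([3, 4, 5, 6, 7], 4)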