import operator
from copy import deepcopy
from enum import Enum
from functools import reduce

import torch

from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.utils import all_gather_simulator, all_to_all_simulator, shard_simulator

ALLGATHER_COST = 20
SHARD_COST = 5
STEP_PENALTY = 6
NAN = 'nan'


class _DimSpec:
    '''
    Sharding spec for a single dimension of the sharded tensor. It describes which
    logical device mesh axes this tensor dimension is sharded along, and provides a
    method to compute the difference (conversion cost) between two dim specs.

    This class is used internally in ShardingSpec.

    Argument:
        shard_list(List[int]): if shard_list is None or empty, the dim spec will be 'R'
            (replicated) type. Otherwise, each element in shard_list is a logical mesh
            axis along which the data is sharded in this dimension.
    '''

    def __init__(self, shard_list):
        # The documented contract allows None for a replicated dim; normalise it so
        # len() below does not raise TypeError.
        if shard_list is None:
            shard_list = []
        self.is_replica = len(shard_list) == 0
        self.shard_list = shard_list
        self.build_difference_2d_dict()

    def __eq__(self, other):
        return str(self) == str(other)

    def __hash__(self):
        # __eq__ compares string forms, so hash the same string to stay consistent
        # (defining __eq__ alone makes instances unhashable).
        return hash(str(self))

    def __repr__(self):
        if self.is_replica:
            return 'R'
        return 'S' + ''.join(str(dim) for dim in self.shard_list)

    def _convert_str_to_shard_list(self, str_spec):
        '''
        Convert str_spec into shard_list.

        Argument:
            str_spec(str): dim spec in str type, one of 'R', 'S0', 'S1', 'S01'.

        Return:
            shard_list(List[int]): the shard list for a recognised spec; None for any
                other string.
        '''
        if str_spec == 'R':
            return []
        if str_spec == 'S0':
            return [0]
        if str_spec == 'S1':
            return [1]
        if str_spec == 'S01':
            return [0, 1]
        return None

    def build_difference_2d_dict(self):
        '''
        Build a difference mapping for the 2D device mesh case. It will be used to
        compute the difference between DimSpec pairs.

        The result is stored in ``self.difference_dict``. It maps a
        ``(source_spec_str, target_spec_str)`` pair, drawn from 'R', 'S0', 'S1' and
        'S01', to the cost of converting source into target.
        '''
        spec_list = ['R', 'S0', 'S1', 'S01']
        difference_dict = {}
        for source_spec in spec_list:
            for target_spec in spec_list:
                source_shard_list = self._convert_str_to_shard_list(source_spec)
                target_shard_list = self._convert_str_to_shard_list(target_spec)
                source_len = len(source_shard_list)
                target_len = len(target_shard_list)

                # source same as target
                if source_shard_list == target_shard_list:
                    difference = 0
                # all_gather(source) -> target: drop the last sharded axis
                elif source_len == target_len + 1 and source_shard_list[:-1] == target_shard_list:
                    difference = ALLGATHER_COST
                # shard(source) -> target: append one new, unused axis
                elif (source_len == target_len - 1 and source_shard_list == target_shard_list[:-1]
                      and target_shard_list[-1] not in source_shard_list):
                    difference = SHARD_COST
                # S1 -> S0 or S0 -> S1
                elif source_len == target_len:
                    # source -> R -> target
                    difference = ALLGATHER_COST + STEP_PENALTY + SHARD_COST
                # R -> S01
                elif source_len == target_len - 2:
                    difference = SHARD_COST + STEP_PENALTY + SHARD_COST
                # S01 -> R
                elif source_len == target_len + 2:
                    difference = ALLGATHER_COST + STEP_PENALTY + ALLGATHER_COST
                # S1 -> S01: axis order differs, so gather to R and shard twice
                elif source_len == target_len - 1:
                    difference = ALLGATHER_COST + STEP_PENALTY + SHARD_COST + STEP_PENALTY + SHARD_COST
                # S01 -> S1: gather twice to R, then shard once
                elif source_len == target_len + 1:
                    difference = ALLGATHER_COST + STEP_PENALTY + ALLGATHER_COST + STEP_PENALTY + SHARD_COST
                else:
                    difference = NAN
                difference_dict[(source_spec, target_spec)] = difference

        self.difference_dict = difference_dict

    def difference(self, other):
        '''
        The difference between two _DimSpec.

        Argument:
            other(_DimSpec): the dim spec to compare with.

        Return:
            difference(int): the difference between two _DimSpec.

        Raises:
            KeyError: if either spec is not one of 'R', 'S0', 'S1', 'S01'
                (only the 2D mesh table is built).

        Example:
            dim_spec = _DimSpec([0])
            other_dim_spec = _DimSpec([0, 1])
            print(dim_spec.difference(other_dim_spec))

        Output:
            5
        '''
        return self.difference_dict[(str(self), str(other))]


class ShardingSpec:
    '''
    Sharding spec for a tensor. It contains the logical device mesh this tensor
    belongs to, the entire shape of the tensor before sharding, and the sharding
    sequence, which looks like [R, R, S0, S1].

    Argument:
        device_mesh(DeviceMesh): A logical view of a physical mesh.
        entire_shape(torch.Size): The entire shape of tensor before sharded.
        dim_partition_dict(Dict[int, List[int]], optional): The key is the dimension of tensor to be sharded,
            and the value of the key describes which logical axes will be sharded in that dimension.
        sharding_sequence(List[_DimSpec], optional): A straight view of ShardingSpec looks like [R, R, S0, S1].

    If neither dim_partition_dict nor sharding_sequence is given, the tensor is
    treated as fully replicated.
    '''

    def __init__(self, device_mesh, entire_shape, dim_partition_dict=None, sharding_sequence=None):
        self.device_mesh = device_mesh
        self.entire_shape = entire_shape
        self.dim_partition_dict = dim_partition_dict
        self.sharding_sequence = sharding_sequence
        if self.sharding_sequence is None:
            # With neither description supplied, default to fully replicated instead
            # of failing on None.items().
            if self.dim_partition_dict is None:
                self.dim_partition_dict = {}
            self.convert_dict_to_shard_sequence()
        elif self.dim_partition_dict is None:
            self.convert_shard_sequence_to_dict()
        self._sanity_check()

    def __repr__(self):
        res_list = ["DistSpec:"]
        res_list.append("\n\tshard_sequence: " + ",".join(str(dimspec) for dimspec in self.sharding_sequence))
        res_list.append(f"\n\tdevice_mesh_shape: {self.device_mesh.mesh_shape}")
        return ' '.join(res_list)

    def _sanity_check(self):
        '''
        Make sure every axis in the logical device mesh is used at most once across
        dim_partition_dict.

        Raises:
            ValueError: if an axis is out of range of ``device_mesh.logical_mesh_id.dim()``
                or is used more than once.
        '''
        dim_check_list = list(range(self.device_mesh.logical_mesh_id.dim()))
        for dim, shard_list in self.dim_partition_dict.items():
            for element in shard_list:
                if element in dim_check_list:
                    dim_check_list.remove(element)
                else:
                    raise ValueError(
                        f"find an invalid sharding axis {element} in dim_partition_dict in tensor dimension {dim}.")

    def convert_dict_to_shard_sequence(self):
        '''
        Convert dim_partition_dict into a list of _DimSpec (one per tensor dimension),
        and assign it to sharding_sequence.
        '''
        # Unsharded dims share one replicated _DimSpec instance.
        sharding_sequence = [_DimSpec([])] * len(self.entire_shape)
        for dim, shard_list in self.dim_partition_dict.items():
            sharding_sequence[dim] = _DimSpec(shard_list)
        self.sharding_sequence = sharding_sequence

    def convert_shard_sequence_to_dict(self):
        '''
        Convert sharding_sequence into dim_partition_dict, keeping only sharded
        (non-'R') dimensions.
        '''
        self.dim_partition_dict = {
            index: list(dim_spec.shard_list)
            for index, dim_spec in enumerate(self.sharding_sequence)
            if not dim_spec.is_replica
        }

    def sharding_sequence_difference(self, other):
        '''
        This function is a naive version of difference computation. It simply
        accumulates the difference of every dimension between the pair of sharding
        sequences.

        Example:
            dim_partition_dict = {0: [0, 1]}
            # DistSpec:
            #     shard_sequence: S01,R,R
            #     device_mesh_shape: (4, 4)
            sharding_spec = ShardingSpec(device_mesh, entire_shape, dim_partition_dict)
            dim_partition_dict_to_compare = {0: [0], 1: [1]}
            # DistSpec:
            #     shard_sequence: S0,S1,R
            #     device_mesh_shape: (4, 4)
            sharding_spec_to_compare = ShardingSpec(device_mesh, entire_shape, dim_partition_dict_to_compare)
            print(sharding_spec.sharding_sequence_difference(sharding_spec_to_compare))

        Output:
            25

        Argument:
            other(ShardingSpec): The ShardingSpec to compared with.

        Return:
            difference(int): Difference between two ShardingSpec.
        '''
        assert len(self.sharding_sequence) == len(
            other.sharding_sequence), f'Cannot compare difference for two sharding specs with different length.'
        return sum(
            orig_dim_spec.difference(other_dim_spec)
            for orig_dim_spec, other_dim_spec in zip(self.sharding_sequence, other.sharding_sequence))

    def get_sharded_shape_per_device(self):
        '''
        Compute the local shape of the tensor on each device.

        Each sharded dimension is divided by the product of
        ``device_mesh.mesh_shape[axis]`` over the mesh axes it is sharded along.

        Return:
            torch.Size: the per-device shape.

        Raises:
            AssertionError: if a sharded dimension is not evenly divisible.
        '''
        sharded_shape = list(self.entire_shape)
        for dim, shard_list in self.dim_partition_dict.items():
            mesh_list = [self.device_mesh.mesh_shape[mesh_dim] for mesh_dim in shard_list]
            shard_partitions = reduce(operator.mul, mesh_list, 1)
            assert sharded_shape[
                dim] % shard_partitions == 0, f'Cannot shard dimension {dim} into {shard_partitions} partitions.'
            sharded_shape[dim] //= shard_partitions
        return torch.Size(sharded_shape)