import math from copy import deepcopy from dataclasses import dataclass from typing import Dict, List, Tuple import numpy as np import torch from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, TrainCycleItem from colossalai.context.singleton_meta import SingletonMeta from colossalai.tensor.sharding_spec import ShardingSpec, ShardingSpecException from colossalai.tensor.utils import all_gather_simulator, all_to_all_simulator, mix_gather_simulator, shard_simulator from .comm_spec import * __all__ = ['ShapeConsistencyManager', 'ShapeConsistencyOptions', 'set_shape_consistency_options'] @dataclass class ShapeConsistencyOptions: """ ShapeConsistencyOptions is a dataclass which specifies the preferences for shape consistency. """ # TODO: shape consistency option is not implemented yet pass def to_global(distributed_tensor: torch.Tensor, sharding_spec: ShardingSpec) -> torch.Tensor: shape_consistency_manager = ShapeConsistencyManager() global_sharding_spec = ShardingSpec(sharding_spec.device_mesh, sharding_spec.entire_shape, {}) with torch.no_grad(): global_tensor = shape_consistency_manager.apply_for_autoparallel_runtime(distributed_tensor, sharding_spec, global_sharding_spec) return global_tensor def set_shape_consistency_options(options: ShapeConsistencyOptions): """ Configure the shape consistency manager via function call. """ manager = ShapeConsistencyManager() manager.options = options class ShapeConsistencyManager(metaclass=SingletonMeta): def __init__(self): self._options = None self._forward_only = False self.total_communication_cost = 0 self.total_transform_steps = 0 self.cached_spec_pairs_transform_path = {} @property def options(self): return self._options @options.setter def options(self, options_: ShapeConsistencyOptions): assert isinstance(options_, ShapeConsistencyOptions) self._options = options_ @property def forward_only(self): return self._forward_only @forward_only.setter def forward_only(self, value): assert isinstance(value, bool) self._forward_only = value def get_all_all_gather_spec(self, source_spec: ShardingSpec, orig_cost_dict: Dict[str, float]) -> Dict[ShardingSpec, float]: ''' Get all valid sharding specs from source_spec with single all-gather operation, and accumulate commucation cost on origin cost which will finally be used in auto sharding solver. For the all-gather operation, we just care about the S dimension. Argument: source_spec(ShardingSpec): the ShardingSpec of the source_spec. orig_cost(Dict[str, float]): the original communication cost before this operation. Return: valid_spec_dict(Dict[ShardingSpec, float]): all valid sharding specs from source_spec with single all-gather operation. Example: dim_partition_dict = {0: [0], 1: [1]} # DistSpec: # shard_sequence: S0,S1,R # device_mesh_shape: (4, 4) sharding_spec = ShardingSpec(device_mesh, entire_shape, dim_partition_dict) shape_consistency_manager = ShapeConsistencyManager() rst_dict = shape_consistency_manager.get_all_all_gather_spec(sharding_spec, {'forward': 0, 'backward': 0, 'total': 0}) print(rst_dict) Output: {DistSpec: shard_sequence: R,S1,R device_mesh_shape: (4, 4): 0, DistSpec: shard_sequence: S0,R,R device_mesh_shape: (4, 4): 0} ''' valid_spec_dict = {} comm_pattern = CollectiveCommPattern.GATHER_FWD_SPLIT_BWD for target_pair in source_spec.dim_partition_dict.items(): shard_list = all_gather_simulator(target_pair) index = target_pair[0] new_dim_partition_dict = deepcopy(source_spec.dim_partition_dict) # We won't add empty list into dim_partition_dict # The key will be popped if the related shard_list is empty if shard_list: new_dim_partition_dict[index] = shard_list else: new_dim_partition_dict.pop(index) # generate the CommSpec to record the action of source_sharding_spec->new_sharding_spec gather_dim = index logical_process_axis = target_pair[1][-1] comm_spec = CommSpec( comm_pattern, sharding_spec=source_spec, gather_dim=gather_dim, # shard_dim will be used during backward shard_dim=gather_dim, logical_process_axis=logical_process_axis, forward_only=self.forward_only) # compute the communication cost with CommSpec cost_dict = comm_spec.get_comm_cost() # generate new sharding spec try: new_sharding_spec = ShardingSpec(source_spec.device_mesh, source_spec.entire_shape, dim_partition_dict=new_dim_partition_dict) for phase, cost in cost_dict.items(): cost_dict[phase] = cost + orig_cost_dict[phase] valid_spec_dict[new_sharding_spec] = (comm_spec, cost_dict) except ShardingSpecException: pass return valid_spec_dict def get_all_all_to_all_spec(self, source_spec: ShardingSpec, orig_cost_dict: Dict[str, float]) -> Dict[ShardingSpec, float]: ''' Get all valid sharding specs from source_spec with single all-to-all operation, and accumulate commucation cost on origin cost which will finally be used in auto sharding solver. For the all-to-all operation, we just care about the pairs containing S dimension. Argument: source_spec(ShardingSpec): the ShardingSpec of the source_spec. orig_cost(Dict[str, float]): the original communication cost before this operation. Return: valid_spec_dict(Dict[ShardingSpec, float]): all valid sharding specs from source_spec with single all-to-all operation. Example: dim_partition_dict = {0: [0], 1: [1]} # DistSpec: # shard_sequence: S0,S1,R # device_mesh_shape: (4, 4) sharding_spec = ShardingSpec(device_mesh, entire_shape, dim_partition_dict) shape_consistency_manager = ShapeConsistencyManager() rst_dict = shape_consistency_manager.get_all_all_to_all_spec(sharding_spec, {'forward': 0, 'backward': 0, 'total': 0}) print(rst_dict) Output: {DistSpec: shard_sequence: S01,R,R device_mesh_shape: (4, 4): 0, DistSpec: shard_sequence: R,S1,S0 device_mesh_shape: (4, 4): 0, DistSpec: shard_sequence: S0,R,S1 device_mesh_shape: (4, 4): 0} ''' valid_spec_dict = {} comm_pattern = CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD tensor_dims = len(source_spec.entire_shape) for f_index in range(tensor_dims - 1): for b_index in range(f_index + 1, tensor_dims): # skip (R, R) cases if f_index not in source_spec.dim_partition_dict and b_index not in source_spec.dim_partition_dict: continue else: if f_index in source_spec.dim_partition_dict: # skip (S01, R) -> (R, S01) is NOT allowed if len(source_spec.dim_partition_dict[f_index]) >= 2: continue f_target_pair = (f_index, deepcopy(source_spec.dim_partition_dict[f_index])) else: f_target_pair = (f_index, []) if b_index in source_spec.dim_partition_dict: # skip (R, S01) -> (S01, R) is NOT allowed if len(source_spec.dim_partition_dict[b_index]) >= 2: continue b_target_pair = (b_index, deepcopy(source_spec.dim_partition_dict[b_index])) else: b_target_pair = (b_index, []) # skip (S1, S0) -> S10 if f_target_pair[1] and b_target_pair[1] and f_target_pair[1][0] >= b_target_pair[1][0]: continue f_shard_list, b_shard_list = all_to_all_simulator(f_target_pair, b_target_pair) f_index = f_target_pair[0] b_index = b_target_pair[0] # generate the CommSpec to record the action of source_sharding_spec->new_sharding_spec if len(f_shard_list) < len(f_target_pair[1]): gather_dim = f_index shard_dim = b_index logical_process_axis = f_target_pair[1][-1] else: gather_dim = b_index shard_dim = f_index logical_process_axis = b_target_pair[1][-1] comm_spec = CommSpec(comm_pattern, sharding_spec=source_spec, gather_dim=gather_dim, shard_dim=shard_dim, logical_process_axis=logical_process_axis, forward_only=self.forward_only) # compute the communication cost with CommSpec cost_dict = comm_spec.get_comm_cost() new_dim_partition_dict = deepcopy(source_spec.dim_partition_dict) # We won't add empty list into dim_partition_dict # The key will be popped if the related shard_list is empty if f_shard_list: new_dim_partition_dict[f_index] = f_shard_list else: new_dim_partition_dict.pop(f_index) if b_shard_list: new_dim_partition_dict[b_index] = b_shard_list else: new_dim_partition_dict.pop(b_index) # generate new sharding spec try: new_sharding_spec = ShardingSpec(source_spec.device_mesh, source_spec.entire_shape, dim_partition_dict=new_dim_partition_dict) for phase, cost in cost_dict.items(): cost_dict[phase] = cost + orig_cost_dict[phase] valid_spec_dict[new_sharding_spec] = (comm_spec, cost_dict) except ShardingSpecException: pass return valid_spec_dict def get_all_shard_spec(self, source_spec: ShardingSpec, orig_cost_dict): ''' Get all valid sharding specs from source_spec with single shard operation, and accumulate commucation cost on origin cost which will finally be used in auto sharding solver. For the sharding operation, we just care about legal sharding dimensions. Argument: source_spec(ShardingSpec): the ShardingSpec of the source_spec. orig_cost(float): the original communication cost before this operation. Return: valid_spec_dict(Dict[ShardingSpec, float]): all valid sharding specs from source_spec with single all-to-all operation. Example: dim_partition_dict = {0: [0]} # DistSpec: # shard_sequence: S0,R,R # device_mesh_shape: (4, 4) sharding_spec = ShardingSpec(device_mesh, entire_shape, dim_partition_dict) shape_consistency_manager = ShapeConsistencyManager() rst_dict = shape_consistency_manager.get_all_shard_spec(sharding_spec, {'forward': 0, 'backward': 0, 'total': 0}) print(rst_dict) Output: {DistSpec: shard_sequence: S01,R,R device_mesh_shape: (4, 4): 0, DistSpec: shard_sequence: S0,S1,R device_mesh_shape: (4, 4): 0, DistSpec: shard_sequence: S0,R,S1 device_mesh_shape: (4, 4): 0} ''' valid_spec_dict = {} comm_pattern = CollectiveCommPattern.SPLIT_FWD_GATHER_BWD # legal sharding dims means the mesh_id is still available to use. legal_sharding_dims = [i for i in range(len(source_spec.device_mesh.mesh_shape))] for dim, shard_list in source_spec.dim_partition_dict.items(): for element in shard_list: legal_sharding_dims.remove(element) if len(legal_sharding_dims) == 0: return valid_spec_dict tensor_dims = len(source_spec.entire_shape) for index in range(tensor_dims): if index not in source_spec.dim_partition_dict: shard_list_list = shard_simulator((index, []), legal_sharding_dims) else: shard_list_list = shard_simulator((index, source_spec.dim_partition_dict[index]), legal_sharding_dims) if not shard_list_list: continue for shard_list in shard_list_list: new_dim_partition_dict = deepcopy(source_spec.dim_partition_dict) new_dim_partition_dict[index] = shard_list # generate the CommSpec to record the action of source_sharding_spec->new_sharding_spec shard_dim = index logical_process_axis = shard_list[-1] comm_spec = CommSpec(comm_pattern, sharding_spec=source_spec, gather_dim=shard_dim, shard_dim=shard_dim, logical_process_axis=logical_process_axis, forward_only=self.forward_only) # compute the communication cost with CommSpec cost_dict = comm_spec.get_comm_cost() # generate new sharding spec try: new_sharding_spec = ShardingSpec(source_spec.device_mesh, source_spec.entire_shape, dim_partition_dict=new_dim_partition_dict) for phase, cost in cost_dict.items(): cost_dict[phase] = cost + orig_cost_dict[phase] valid_spec_dict[new_sharding_spec] = (comm_spec, cost_dict) except ShardingSpecException: pass return valid_spec_dict def get_all_mix_gather_spec(self, source_spec: ShardingSpec, orig_cost_dict: Dict[str, float]) -> Dict[ShardingSpec, float]: ''' S0S1 -> RR S1S0 -> RR S01R -> RR RS01 -> RR ''' valid_spec_dict = {} comm_pathern = CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD tensor_dims = len(source_spec.entire_shape) for f_index in range(tensor_dims - 1): for b_index in range(f_index + 1, tensor_dims): if (f_index not in source_spec.dim_partition_dict) and (b_index not in source_spec.dim_partition_dict): continue else: if f_index in source_spec.dim_partition_dict: # skip (S10, R) -> (R, R) if len(f_target_pair[1]) == 2 and f_target_pair[1][0] >= f_target_pair[1][1]: continue f_target_pair = (f_index, deepcopy(source_spec.dim_partition_dict[f_index])) else: f_target_pair = (f_index, []) if b_index in source_spec.dim_partition_dict: # skip (R, S10) -> (R, R) if len(b_target_pair[1]) == 2 and b_target_pair[1][0] >= b_target_pair[1][1]: continue b_target_pair = (b_index, deepcopy(source_spec.dim_partition_dict[b_index])) else: b_target_pair = (b_index, []) gather_dim, logical_process_axes = mix_gather_simulator(f_target_pair, b_target_pair) comm_spec = CommSpec(comm_pathern, sharding_spec=source_spec, gather_dim=gather_dim, logical_process_axis=logical_process_axes, forward_only=self.forward_only, mix_gather=True) cost_dict = comm_spec.get_comm_cost() new_dim_partition_dict = {} # generate new sharding spec try: new_sharding_spec = ShardingSpec(source_spec.device_mesh, source_spec.entire_shape, dim_partition_dict=new_dim_partition_dict) for phase, cost in cost_dict.items(): cost_dict[phase] = cost + orig_cost_dict[phase] valid_spec_dict[new_sharding_spec] = (comm_spec, cost_dict) except ShardingSpecException: pass return valid_spec_dict def get_all_one_step_transform_spec(self, source_spec: ShardingSpec, orig_cost_dict) -> Dict[ShardingSpec, float]: ''' Get all valid sharding specs from source_spec with one step transform, and accumulate commucation cost on origin cost which will finally be used in auto sharding solver. Note: all-gather will eliminate a sharding dimension, all-to-all will keep sharding dimension same as before, and shard will add a sharding dimension. Therefore, the result of above operations are mutual exclusive, we could safely put them together. Argument: source_spec(ShardingSpec): the ShardingSpec of the source_spec. orig_cost(float): the original communication cost before this operation. Return: valid_spec_dict(Dict[ShardingSpec, float]): all valid sharding specs from source_spec with single all-to-all operation. ''' valid_spec_dict = {} valid_spec_dict.update(self.get_all_all_gather_spec(source_spec, orig_cost_dict)) valid_spec_dict.update(self.get_all_all_to_all_spec(source_spec, orig_cost_dict)) valid_spec_dict.update(self.get_all_shard_spec(source_spec, orig_cost_dict)) return valid_spec_dict def mem_cost(self, comm_action_sequence: List[CommSpec]) -> TrainCycleItem: """memory cost of the communication action sequence Args: comm_action_sequence (List[CommSpec]): list of communication actions Returns: TrainCycleItem: memory (numel) cost of such comm_action_sequence """ def compute_shape(sharding_spec: ShardingSpec): shape = sharding_spec.entire_shape new_shape = [] for dim, shard in sharding_spec.dim_partition_dict.items(): new_shape.append(shape[dim] // len(shard)) return new_shape def gather_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int, peak_numel: int): """analyze all_gather memory footprint all_gather will allocate memory for the output tensor, and there will be temp memory for all_gather operation, which is twice the size of output tensor Args: comm_spec (CommSpec): input CommSpec discard_input (bool): whether to discard the input tensor alloc_numel (int): current allocated numel peak_numel (int): current peak numel """ input_shape = compute_shape(comm_spec.sharding_spec) input_numel = np.prod(input_shape) output_numel = input_numel * comm_spec.device_mesh.mesh_shape[comm_spec.logical_process_axis] peak_numel = max(peak_numel, alloc_numel + output_numel * 2) alloc_numel += output_numel if discard_input: alloc_numel -= input_numel return alloc_numel, peak_numel def split_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int, peak_numel: int): """analyze split memory footprint split will allocate memory for the output tensor if we don't apply shard on the first dimension of the input tensor. If we apply shard on the first dimension, the `torch.tensor.contiguous()` will not generate new tensor in this case, so no memory will be allocated. Args: comm_spec (CommSpec): input CommSpec discard_input (bool): whether to discard the input tensor alloc_numel (int): current allocated numel peak_numel (int): current peak numel """ shard_dim = comm_spec.shard_dim if shard_dim != 0: # if we don't shard the tensor on the first dimension, the split action will # generate a new tensor input_shape = compute_shape(comm_spec.sharding_spec) input_numel = np.prod(input_shape) output_numel = input_numel // comm_spec.device_mesh.mesh_shape[comm_spec.logical_process_axis] alloc_numel += output_numel peak_numel = max(peak_numel, alloc_numel) if discard_input: alloc_numel -= input_numel else: # if we shard the tensor on the first dimension, the split action will not generate # a new tensor, and as it will preserve a reference to the input tensor, we could # override the discard_input option here # NOTE: this special case might fail in some weird cases, e.g. if we have three split # actions in the comm actions sequence, the first split action operate on the second dimension, # the second split action operate on the first dimension, and the third split action operate, again, # on the second dimension. Therefore, after the first two actions in the sequence, we will allocate # memory the same size as the output of first split action. However, the third split action will discard # the input tensor, and it actually should discard the tensor generated by the first split action, so in # the current memory estimation framework, we will overestimate the memory usage. But the above case is # kind of weird, and I think we could ignore it for now. pass return alloc_numel, peak_numel def reduce_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int, peak_numel: int): """ a dummy function for reduce memory footprint analysis, as the reduce action doesn't allocate extra memory """ return alloc_numel, peak_numel def all2all_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int, peak_numel: int): """analyze all_to_all memory footprint all_to_all will allocate memory for the output tensor, and temp memory of all_to_all action is twice the size of output tensor if we shard input tensor on the first dimension, otherwise the temp memory is three times the size of output tensor Args: comm_spec (CommSpec): input CommSpec discard_input (bool): whether to discard the input tensor alloc_numel (int): current allocated numel peak_numel (int): current peak numel """ input_shape = compute_shape(comm_spec.sharding_spec) input_numel = np.prod(input_shape) output_numel = input_numel shard_dim = comm_spec.shard_dim if shard_dim != 0: peak_numel = max(peak_numel, alloc_numel + output_numel * 3) else: peak_numel = max(peak_numel, alloc_numel + output_numel * 2) alloc_numel += output_numel if discard_input: alloc_numel -= input_numel return alloc_numel, peak_numel def identity_analysis(comm_spec: CommSpec, discard_input: bool, alloc_numel: int, peak_numel: int): """ a dummy function for identity memory footprint analysis, as the identity action doesn't allocate extra memory """ return alloc_numel, peak_numel pattern_to_func_dict = { CollectiveCommPattern.GATHER_FWD_SPLIT_BWD: [gather_analysis, split_analysis], CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD: [all2all_analysis, all2all_analysis], CollectiveCommPattern.SPLIT_FWD_GATHER_BWD: [split_analysis, gather_analysis], CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD: [reduce_analysis, identity_analysis], CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD: [identity_analysis, reduce_analysis], CollectiveCommPattern.MIXGATHER_FWD_SPLIT_BWD: [], } fwd_actions = [] bwd_actions = [] # construct forward and backward comm actions sequence for comm_spec in comm_action_sequence: comm_spec: CommSpec fwd_action, bwd_action = pattern_to_func_dict[comm_spec.comm_pattern] fwd_actions.append(fwd_action) bwd_actions.append(bwd_action) # analyze memory footprint of forward comm actions sequence fwd_alloc_numel = 0 fwd_peak_numel = 0 for idx, action_spec_pair in enumerate(zip(fwd_actions, comm_action_sequence)): # the first forward comm action will not discard input fwd_action, comm_spec = action_spec_pair fwd_alloc_numel, fwd_peak_numel = fwd_action(comm_spec, False, fwd_alloc_numel, fwd_peak_numel) if idx == 0 else fwd_action( comm_spec, True, fwd_alloc_numel, fwd_peak_numel) # analyze memory footprint for backward comm actions sequence bwd_alloc_numel = 0 bwd_peak_numel = 0 for idx, action_spec_pair in enumerate(zip(reversed(bwd_actions), reversed(comm_action_sequence))): bwd_action, comm_spec = action_spec_pair bwd_alloc_numel, bwd_peak_numel = bwd_action(comm_spec, False, bwd_alloc_numel, bwd_peak_numel) if idx == 0 else bwd_action( comm_spec, True, bwd_alloc_numel, bwd_peak_numel) fwd_mem = MemoryCost(activation=fwd_alloc_numel, temp=fwd_peak_numel - fwd_alloc_numel) bwd_mem = MemoryCost(activation=bwd_alloc_numel, temp=bwd_peak_numel - bwd_alloc_numel) total_mem = MemoryCost(activation=fwd_alloc_numel + bwd_alloc_numel) return TrainCycleItem(fwd_mem, bwd_mem, total_mem) def shape_consistency(self, source_spec: ShardingSpec, target_spec: ShardingSpec) -> Tuple[List[ShardingSpec], List[CommSpec], float]: ''' This method will find a path to transform source_spec to target_spec with a greedy algorithm. The basic idea is: Step1: Generate all one-step transform sequences from source_spec. Step2: Pick the 'best' sharding spec following the heuristic function. Step3: Repeat above steps until the source spec transform to target spec. During finding the transform path, commucation cost will be accumulated, and it will be finally used in auto parallel solver. Additionally, to avoid repeating the path search in runtime, we cached all solved path in auto parallel strategy building time, which could handle most of cases in runtime. Argument: source_spec(ShardingSpec): ShardingSpec of the source activation. target_spec(ShardingSpec): ShardingSpec of the target activation. Return: transform_path(List[ShardingSpec]): The transform path from source_spec to target_spec, it contains the source_spec and target_spec. comm_action_sequence(List[CommSpec]): Keep the communication operations to complete the shape consistency in order. total_cost(float): total cost to complete shape consistency transform. Example: dim_partition_source = {1: [0, 1]} dim_partition_target = {0: [0, 1]} # DistSpec: # shard_sequence: R,S01,R # device_mesh_shape: (4, 4) sharding_spec_source = ShardingSpec(device_mesh, entire_shape, dim_partition_source) # DistSpec: # shard_sequence: S01,R,R # device_mesh_shape: (4, 4) sharding_spec_target = ShardingSpec(device_mesh, entire_shape, dim_partition_target) transform_path, comm_action_sequence, total_cost = shape_consistency_manager.shape_consistency(sharding_spec_source, sharding_spec_target) print(f'transform_path: {transform_path}') print(f'comm_action_sequence: {comm_action_sequence}') print(f'total_cost: {total_cost}') output: transform_path: [DistSpec: shard_sequence: R,S01,R device_mesh_shape: (4, 4), DistSpec: shard_sequence: R,S0,R device_mesh_shape: (4, 4), DistSpec: shard_sequence: S0,R,R device_mesh_shape: (4, 4), DistSpec: shard_sequence: S01,R,R device_mesh_shape: (4, 4)] comm_action_sequence: [CommSpec:(comm_pattern:allgather, gather_dim:1, logical_process_axis:1), CommSpec:(comm_pattern:all2all, gather_dim:1, shard_dim:0, logical_process_axis: 0), CommSpec:(comm_pattern:shard, shard_dim:0, logical_process_axis:1)] total_cost: 12294.402000000002 ''' MAX_TRANSFORM_STEPS = 20 total_cost_dict = {'forward': 0, 'backward': 0, 'total': 0} total_steps = 0 transform_path = [] comm_action_sequence = [] spec_pairs = (str(source_spec.sharding_sequence), str(target_spec.sharding_sequence)) self.cached_spec_pairs_transform_path[spec_pairs] = (None, None) # We do nothing if the sharding spec is all the same. if source_spec.sharding_sequence_difference(target_spec) == 0: self.cached_spec_pairs_transform_path[spec_pairs] = (transform_path, comm_action_sequence) return (transform_path, comm_action_sequence, total_cost_dict) temp_sharding_spec = source_spec transform_path.append(temp_sharding_spec) # To avoid dead loop, the loop will break after MAX_TRANSFORM_STEPS transforms while total_steps <= MAX_TRANSFORM_STEPS: valid_transform_spec_dict = self.get_all_one_step_transform_spec(temp_sharding_spec, total_cost_dict) best_difference_score = math.inf for sharding_spec, info_pairs in valid_transform_spec_dict.items(): comm_spec, cost_dict = info_pairs spec_difference = sharding_spec.sharding_sequence_difference(target_spec) if spec_difference == 0: for phase, cost in total_cost_dict.items(): total_cost_dict[phase] = cost + cost_dict[phase] transform_path.append(sharding_spec) comm_action_sequence.append(comm_spec) self.cached_spec_pairs_transform_path[spec_pairs] = (transform_path, comm_action_sequence) return (transform_path, comm_action_sequence, total_cost_dict) if spec_difference < best_difference_score: temp_sharding_spec = sharding_spec temp_cost_dict = cost_dict temp_comm_spec = comm_spec best_difference_score = spec_difference transform_path.append(temp_sharding_spec) comm_action_sequence.append(temp_comm_spec) for phase, cost in total_cost_dict.items(): total_cost_dict[phase] = cost + temp_cost_dict[phase] total_steps += 1 raise RuntimeError(f"Could not find a valid transform path with in {MAX_TRANSFORM_STEPS} steps.") def apply(self, tensor_with_sharding_spec: torch.Tensor, target_spec: ShardingSpec) -> torch.Tensor: ''' Apply target_spec to tensor with source sharding spec, the transform path is generated by the shape_consistency method. Argument: tensor_with_sharding_spec (torch.Tensor): a tensor with source sharding spec to be transformed to the target spec. target_spec (ShardingSpec): The tensor transform processes will be directed by the target_spec. Example: physical_mesh_id = torch.arange(0, 4) mesh_shape = (2, 2) # [[0, 1, # [2, 3]] device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) entire_shape = torch.Size((4, 2)) shape_consistency_manager = ShapeConsistencyManager() dim_partition_source = {0: [0]} dim_partition_target = {1: [0]} # DistSpec: # shard_sequence: S0,R # device_mesh_shape: (2, 2) sharding_spec_source = ShardingSpec(device_mesh, entire_shape, dim_partition_source) # DistSpec: # shard_sequence: R,S0 # device_mesh_shape: (2, 2) sharding_spec_target = ShardingSpec(device_mesh, entire_shape, dim_partition_target) if rank in (0, 1): sharded_tensor_0 = torch.zeros(2, 1) sharded_tensor_1 = torch.ones(2, 1) # tensor([[0., 1.], # [0., 1.]]) tensor_to_comm = torch.cat((sharded_tensor_0, sharded_tensor_1), 1).cuda() if rank in (2, 3): sharded_tensor_0 = torch.ones(2, 1) * 2 sharded_tensor_1 = torch.ones(2, 1) * 3 # tensor([[2., 3.], # [2., 3.]]) tensor_to_comm = torch.cat((sharded_tensor_0, sharded_tensor_1), 1).cuda() tensor_to_comm.sharding_spec = sharding_spec_source shape_consistency_manager.apply(tensor_to_comm, sharding_spec_target) print(tensor_to_comm) Output in rank0 and rank2: tensor([[0.], [0.], [2.], [2.]]) Output in rank1 and rank3: tensor([[1.], [1.], [3.], [3.]]) ''' _, comm_action_sequence, _ = self.shape_consistency(tensor_with_sharding_spec.sharding_spec, target_spec) for comm_spec in comm_action_sequence: tensor_with_sharding_spec = comm_spec.covert_spec_to_action(tensor_with_sharding_spec) tensor_with_sharding_spec.sharding_spec = target_spec return tensor_with_sharding_spec def apply_for_autoparallel_runtime(self, tensor, source_spec, target_spec): _, comm_action_sequence, _ = self.shape_consistency(source_spec, target_spec) for comm_spec in comm_action_sequence: tensor = comm_spec.covert_spec_to_action(tensor) tensor.sharding_spec = target_spec return tensor