import torch from dataclasses import dataclass from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec from colossalai.tensor.utils import all_gather_simulator, all_to_all_simulator, shard_simulator from enum import Enum from copy import deepcopy from typing import Dict, List, Optional, Tuple, Union from colossalai.context.singleton_meta import SingletonMeta import torch.distributed as dist import math from functools import reduce import operator from torch.distributed import ReduceOp __all__ = [ 'CollectiveCommPattern', 'CommSpec', 'ShapeConsistencyManager', 'ShapeConsistencyOptions', 'set_shape_consistency_options' ] def _all_gather(tensor, comm_spec): ''' Implement all gather operation on device mesh based on information provided by comm_spec. ''' process_groups_list = comm_spec.device_mesh.process_groups_dict[comm_spec.logical_process_axis] for rank_list, process_group in process_groups_list: if dist.get_rank() in rank_list: tensor_list = [ torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device) for _ in range(comm_spec.device_mesh.mesh_shape[comm_spec.logical_process_axis]) ] tensor = tensor group = process_group dist.all_gather(tensor_list, tensor, group=group) output = torch.cat(tuple(tensor_list), comm_spec.gather_dim).contiguous() return output def _split(tensor, comm_spec): ''' Implement shard operation on device mesh based on information provided by comm_spec. ''' process_groups_list = comm_spec.device_mesh.process_groups_dict[comm_spec.logical_process_axis] for rank_list, _ in process_groups_list: if dist.get_rank() in rank_list: tensor = tensor dim = comm_spec.shard_dim length = tensor.shape[comm_spec.shard_dim] // len(rank_list) start = length * rank_list.index(dist.get_rank()) output = torch.narrow(tensor, dim, start, length) return output def _all_to_all(tensor, comm_spec): ''' Implement all to all operation on device mesh based on information provided by comm_spec. ''' process_groups_list = comm_spec.device_mesh.process_groups_dict[comm_spec.logical_process_axis] for rank_list, process_group in process_groups_list: if dist.get_rank() in rank_list: new_shape = list(tensor.shape) new_shape[comm_spec.shard_dim] = new_shape[comm_spec.shard_dim] // len(rank_list) new_shape = torch.Size(new_shape) output_tensor_list = [ torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device) for _ in range(len(rank_list)) ] dim = comm_spec.shard_dim length = tensor.shape[comm_spec.shard_dim] // len(rank_list) input_tensor_list = [ torch.narrow(tensor, dim, length * i, length).contiguous() for i in range(len(rank_list)) ] group = process_group dist.all_to_all(output_tensor_list, input_tensor_list, group) output = torch.cat(tuple(output_tensor_list), comm_spec.gather_dim).contiguous() return output def _all_reduce(tensor, comm_spec): ''' Implement all reduce operation on device mesh based on information provided by comm_spec. ''' process_groups_list = comm_spec.device_mesh.process_groups_dict[comm_spec.logical_process_axis] for rank_list, process_group in process_groups_list: if dist.get_rank() in rank_list: dist.all_reduce(tensor, op=ReduceOp.SUM, group=process_group) return tensor class _ReduceGrad(torch.autograd.Function): """ A customized communication operation which forward is an identity operation, backward is all_reduce operation. Args: input_: input matrix. comm_spec: comm_spec will give information like process group, rank list, etc. """ @staticmethod def symbolic(graph, input_): return input_ @staticmethod def forward(ctx, input_, comm_spec): ctx.comm_spec = comm_spec return input_ @staticmethod def backward(ctx, grad_output): return _all_reduce(grad_output, ctx.comm_spec), None class _ReduceInput(torch.autograd.Function): """ A customized communication operation which forward is all_reduce operation, backward is an identity operation. Args: input_: input matrix. comm_spec: comm_spec will give information like process group, rank list, etc. """ @staticmethod def symbolic(graph, input_): return _all_reduce(input_) @staticmethod def forward(ctx, input_, comm_spec): return _all_reduce(input_, comm_spec) @staticmethod def backward(ctx, grad_output): return grad_output, None class _SplitForwardGatherBackward(torch.autograd.Function): """ A customized communication operation which forward is split operation, backward is an all gather operation. Args: input_: input matrix. comm_spec: comm_spec will give information like process group, rank list, etc. """ @staticmethod def symbolic(graph, input_): return _split(input_) @staticmethod def forward(ctx, input_, comm_spec): ctx.comm_spec = comm_spec return _split(input_, comm_spec) @staticmethod def backward(ctx, grad_output): return _all_gather(grad_output, ctx.comm_spec), None class _GatherForwardSplitBackward(torch.autograd.Function): """ A customized communication operation which forward is an all gather operation, backward is split operation. Args: input_: input matrix. comm_spec: comm_spec will give information like process group, rank list, etc. """ @staticmethod def symbolic(graph, input_): return _all_gather(input_) @staticmethod def forward(ctx, input_, comm_spec): ctx.comm_spec = comm_spec return _all_gather(input_, comm_spec) @staticmethod def backward(ctx, grad_output): return _split(grad_output, ctx.comm_spec), None class _AllToAll(torch.autograd.Function): """ A customized communication operation which forward is an all to all operation, backward is an all to all operation. Args: input_: input matrix. comm_spec: comm_spec will give information like process group, rank list, etc. """ @staticmethod def symbolic(graph, input_): return _all_to_all(input_) @staticmethod def forward(ctx, input_, comm_spec): output = _all_to_all(input_, comm_spec) comm_spec_for_backward = CommSpec(comm_pattern=comm_spec.comm_pattern, sharding_spec=comm_spec.sharding_spec, gather_dim=comm_spec.shard_dim, shard_dim=comm_spec.gather_dim, logical_process_axis=comm_spec.logical_process_axis) ctx.comm_spec = comm_spec_for_backward return output @staticmethod def backward(ctx, grad_outputs): return _all_to_all(grad_outputs, ctx.comm_spec), None def reduce_grad(input_, comm_spec): return _ReduceGrad.apply(input_, comm_spec) def reduce_input(input_, comm_spec): return _ReduceInput.apply(input_, comm_spec) def split_forward_gather_backward(input_, comm_spec): return _SplitForwardGatherBackward.apply(input_, comm_spec) def gather_forward_split_backward(input_, comm_spec): return _GatherForwardSplitBackward.apply(input_, comm_spec) def all_to_all(input_, comm_spec): return _AllToAll.apply(input_, comm_spec) class CollectiveCommPattern(Enum): GATHER_FWD_SPLIT_BWD = 'gather_fwd_split_bwd' ALL2ALL_FWD_ALL2ALL_BWD = 'all2all_fwd_all2all_bwd' SPLIT_FWD_GATHER_BWD = 'split_fwd_gather_bwd' REDUCE_FWD_IDENTITY_BWD = 'all_reduce_fwd_identity_bwd' IDENTITY_FWD_ALLREDUCE_BWD = 'identity_fwd_all_reduce_bwd' class CommSpec: ''' Communication spec is used to record the communication action. It has two main functions: 1. Compute the communication cost which will be used in auto parallel solver. 2. Convert the communication spec to real action which will be used in runtime. It contains comm_pattern to determine the communication method, sharding_spec to determine the communication size, gather_dim and shard_dim to determine the buffer shape, and logical_process_axis Argument: comm_pattern(CollectiveCommPattern): decribe the communication method used in this spec. sharding_spec(ShardingSpec): This is sharding spec of the tensor which will join the communication action. gather_dim(int, Optional): The gather_dim of the tensor will be gathered. shard_dim(int, Optional): The shard_dim of the tensor will be sharded. logical_process_axis(Union(int, List[int]), Optional): The mesh_dim to implement the communication action. ''' def __init__(self, comm_pattern, sharding_spec, gather_dim=None, shard_dim=None, logical_process_axis=None, forward_only=False): self.comm_pattern = comm_pattern self.sharding_spec = sharding_spec self.gather_dim = gather_dim self.shard_dim = shard_dim self.logical_process_axis = logical_process_axis self.forward_only = forward_only if isinstance(self.logical_process_axis, list): self.device_mesh = self.sharding_spec.device_mesh.flatten_device_mesh self.logical_process_axis = 0 else: self.device_mesh = self.sharding_spec.device_mesh def __repr__(self): res_list = ["CommSpec:("] if self.comm_pattern == CollectiveCommPattern.GATHER_FWD_SPLIT_BWD: res_list.append(f"comm_pattern:GATHER_FWD_SPLIT_BWD, ") res_list.append(f"gather_dim:{self.gather_dim}, ") res_list.append(f"logical_process_axis:{self.logical_process_axis})") elif self.comm_pattern == CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD: res_list.append(f"comm_pattern:ALL2ALL_FWD_ALL2ALL_BWD, ") res_list.append(f"gather_dim:{self.gather_dim}, ") res_list.append(f"shard_dim:{self.shard_dim}, ") res_list.append(f"logical_process_axis: {self.logical_process_axis})") elif self.comm_pattern == CollectiveCommPattern.SPLIT_FWD_GATHER_BWD: res_list.append(f"comm_pattern:SPLIT_FWD_GATHER_BWD, ") res_list.append(f"shard_dim:{self.shard_dim}, ") res_list.append(f"logical_process_axis:{self.logical_process_axis})") elif self.comm_pattern == CollectiveCommPattern.REDUCE_FWD_IDENTITY_BWD: res_list.append(f"comm_pattern:REDUCE_FWD_IDENTITY_BWD, ") res_list.append(f"logical_process_axis:{self.logical_process_axis})") elif self.comm_pattern == CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD: res_list.append(f"comm_pattern:IDENTITY_FWD_ALLREDUCE_BWD, ") res_list.append(f"logical_process_axis:{self.logical_process_axis})") return ''.join(res_list) def get_comm_cost(self): ''' For all_gather, all2all, and all_reduce operation, the formula provided in DeviceMesh with alpha-beta model is used to compute the communication cost. For shard operation, it is an on-chip operation, so the communication cost is zero. ''' comm_size = reduce(operator.mul, self.sharding_spec.get_sharded_shape_per_device(), 1) if self.comm_pattern == CollectiveCommPattern.GATHER_FWD_SPLIT_BWD: forward_communication_cost = self.device_mesh.all_gather_cost(comm_size, self.logical_process_axis) # give a tiny cost to shard backward_communication_cost = 10 if self.comm_pattern == CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD: forward_communication_cost = self.device_mesh.all_to_all_cost(comm_size, self.logical_process_axis) # grad should have same shape as input tensor # all to all operation has same logical process axis as forward. backward_communication_cost = self.device_mesh.all_to_all_cost(comm_size, self.logical_process_axis) if self.comm_pattern == CollectiveCommPattern.REDUCE_FWD_IDENTITY_BWD: forward_communication_cost = self.device_mesh.all_reduce_cost(comm_size, self.logical_process_axis) backward_communication_cost = 0 if self.comm_pattern == CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD: forward_communication_cost = 0 backward_communication_cost = self.device_mesh.all_reduce_cost(comm_size, self.logical_process_axis) if self.comm_pattern == CollectiveCommPattern.SPLIT_FWD_GATHER_BWD: # give a tiny cost to shard forward_communication_cost = 10 backward_communication_cost = self.device_mesh.all_gather_cost(comm_size, self.logical_process_axis) try: if self.forward_only: total_communication_cost = forward_communication_cost else: total_communication_cost = forward_communication_cost + backward_communication_cost except: raise RuntimeError(f"Could not find a matching CollectiveCommPattern for {self.comm_pattern}.") return total_communication_cost def covert_spec_to_action(self, tensor): ''' Convert CommSpec into runtime action, implement real collection communication to target tensor. The collection communication action is directed by the CommSpec. Argument: tensor(torch.Tensor): Tensor stored in each device, which could be different in different ranks. ''' if self.comm_pattern in pattern_to_func_dict: tensor.data = pattern_to_func_dict[self.comm_pattern](tensor, self) else: tensor.data = tensor pattern_to_func_dict = { CollectiveCommPattern.GATHER_FWD_SPLIT_BWD: gather_forward_split_backward, CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD: all_to_all, CollectiveCommPattern.SPLIT_FWD_GATHER_BWD: split_forward_gather_backward, CollectiveCommPattern.REDUCE_FWD_IDENTITY_BWD: reduce_input, CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD: reduce_grad, } @dataclass class ShapeConsistencyOptions: """ ShapeConsistencyOptions is a dataclass which specifies the preferences for shape consistency. """ # TODO: shape consistency option is not implemented yet pass def set_shape_consistency_options(options: ShapeConsistencyOptions): """ Configure the shape consistency manager via function call. """ manager = ShapeConsistencyManager() manager.options = options class ShapeConsistencyManager(metaclass=SingletonMeta): def __init__(self): self._options = None self._forward_only = False self.total_communication_cost = 0 self.total_transform_steps = 0 self.cached_spec_pairs_transform_path = {} @property def options(self): return self._options @options.setter def options(self, options_: ShapeConsistencyOptions): assert isinstance(options_, ShapeConsistencyOptions) self._options = options_ @property def forward_only(self): return self._forward_only @forward_only.setter def forward_only(self, value): assert isinstance(value, bool) self._forward_only = value def get_all_all_gather_spec(self, source_spec, orig_cost): ''' Get all valid sharding specs from source_spec with single all-gather operation, and accumulate commucation cost on origin cost which will finally be used in auto sharding solver. For the all-gather operation, we just care about the S dimension. Argument: source_spec(ShardingSpec): the ShardingSpec of the source_spec. orig_cost(float): the original communication cost before this operation. Return: valid_spec_dict(Dict[ShardingSpec, float]): all valid sharding specs from source_spec with single all-gather operation. Example: dim_partition_dict = {0: [0], 1: [1]} # DistSpec: # shard_sequence: S0,S1,R # device_mesh_shape: (4, 4) sharding_spec = ShardingSpec(device_mesh, entire_shape, dim_partition_dict) shape_consistency_manager = ShapeConsistencyManager() rst_dict = shape_consistency_manager.get_all_all_gather_spec(sharding_spec, 0) print(rst_dict) Output: {DistSpec: shard_sequence: R,S1,R device_mesh_shape: (4, 4): 0, DistSpec: shard_sequence: S0,R,R device_mesh_shape: (4, 4): 0} ''' valid_spec_dict = {} comm_pattern = CollectiveCommPattern.GATHER_FWD_SPLIT_BWD for target_pair in source_spec.dim_partition_dict.items(): shard_list = all_gather_simulator(target_pair) index = target_pair[0] new_dim_partition_dict = deepcopy(source_spec.dim_partition_dict) # We won't add empty list into dim_partition_dict # The key will be popped if the related shard_list is empty if shard_list: new_dim_partition_dict[index] = shard_list else: new_dim_partition_dict.pop(index) # generate the CommSpec to record the action of source_sharding_spec->new_sharding_spec gather_dim = index logical_process_axis = target_pair[1][-1] comm_spec = CommSpec( comm_pattern, sharding_spec=source_spec, gather_dim=gather_dim, # shard_dim will be used during backward shard_dim=gather_dim, logical_process_axis=logical_process_axis, forward_only=self.forward_only) # compute the communication cost with CommSpec cost = comm_spec.get_comm_cost() # generate new sharding spec new_sharding_spec = ShardingSpec(source_spec.device_mesh, source_spec.entire_shape, dim_partition_dict=new_dim_partition_dict) valid_spec_dict[new_sharding_spec] = (comm_spec, orig_cost + cost) return valid_spec_dict def get_all_all_to_all_spec(self, source_spec, orig_cost): ''' Get all valid sharding specs from source_spec with single all-to-all operation, and accumulate commucation cost on origin cost which will finally be used in auto sharding solver. For the all-to-all operation, we just care about the pairs containing S dimension. Argument: source_spec(ShardingSpec): the ShardingSpec of the source_spec. orig_cost(float): the original communication cost before this operation. Return: valid_spec_dict(Dict[ShardingSpec, float]): all valid sharding specs from source_spec with single all-to-all operation. Example: dim_partition_dict = {0: [0], 1: [1]} # DistSpec: # shard_sequence: S0,S1,R # device_mesh_shape: (4, 4) sharding_spec = ShardingSpec(device_mesh, entire_shape, dim_partition_dict) shape_consistency_manager = ShapeConsistencyManager() rst_dict = shape_consistency_manager.get_all_all_to_all_spec(sharding_spec, 0) print(rst_dict) Output: {DistSpec: shard_sequence: S01,R,R device_mesh_shape: (4, 4): 0, DistSpec: shard_sequence: R,S1,S0 device_mesh_shape: (4, 4): 0, DistSpec: shard_sequence: S0,R,S1 device_mesh_shape: (4, 4): 0} ''' valid_spec_dict = {} comm_pattern = CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD tensor_dims = len(source_spec.entire_shape) for f_index in range(tensor_dims - 1): for b_index in range(f_index + 1, tensor_dims): # skip (R, R) cases if f_index not in source_spec.dim_partition_dict and b_index not in source_spec.dim_partition_dict: continue else: if f_index in source_spec.dim_partition_dict: # skip (S01, R) -> (R, S01) is NOT allowed if len(source_spec.dim_partition_dict[f_index]) >= 2: continue f_target_pair = (f_index, deepcopy(source_spec.dim_partition_dict[f_index])) else: f_target_pair = (f_index, []) if b_index in source_spec.dim_partition_dict: # skip (R, S01) -> (S01, R) is NOT allowed if len(source_spec.dim_partition_dict[b_index]) >= 2: continue b_target_pair = (b_index, deepcopy(source_spec.dim_partition_dict[b_index])) else: b_target_pair = (b_index, []) # skip (S1, S0) -> S10 if f_target_pair[1] and b_target_pair[1] and f_target_pair[1][0] >= b_target_pair[1][0]: continue f_shard_list, b_shard_list = all_to_all_simulator(f_target_pair, b_target_pair) f_index = f_target_pair[0] b_index = b_target_pair[0] # generate the CommSpec to record the action of source_sharding_spec->new_sharding_spec if len(f_shard_list) < len(f_target_pair[1]): gather_dim = f_index shard_dim = b_index logical_process_axis = f_target_pair[1][-1] else: gather_dim = b_index shard_dim = f_index logical_process_axis = b_target_pair[1][-1] comm_spec = CommSpec(comm_pattern, sharding_spec=source_spec, gather_dim=gather_dim, shard_dim=shard_dim, logical_process_axis=logical_process_axis, forward_only=self.forward_only) # compute the communication cost with CommSpec cost = comm_spec.get_comm_cost() new_dim_partition_dict = deepcopy(source_spec.dim_partition_dict) # We won't add empty list into dim_partition_dict # The key will be popped if the related shard_list is empty if f_shard_list: new_dim_partition_dict[f_index] = f_shard_list else: new_dim_partition_dict.pop(f_index) if b_shard_list: new_dim_partition_dict[b_index] = b_shard_list else: new_dim_partition_dict.pop(b_index) # generate new sharding spec new_sharding_spec = ShardingSpec(source_spec.device_mesh, source_spec.entire_shape, dim_partition_dict=new_dim_partition_dict) valid_spec_dict[new_sharding_spec] = (comm_spec, orig_cost + cost) return valid_spec_dict def get_all_shard_spec(self, source_spec, orig_cost): ''' Get all valid sharding specs from source_spec with single shard operation, and accumulate commucation cost on origin cost which will finally be used in auto sharding solver. For the sharding operation, we just care about legal sharding dimensions. Argument: source_spec(ShardingSpec): the ShardingSpec of the source_spec. orig_cost(float): the original communication cost before this operation. Return: valid_spec_dict(Dict[ShardingSpec, float]): all valid sharding specs from source_spec with single all-to-all operation. Example: dim_partition_dict = {0: [0]} # DistSpec: # shard_sequence: S0,R,R # device_mesh_shape: (4, 4) sharding_spec = ShardingSpec(device_mesh, entire_shape, dim_partition_dict) shape_consistency_manager = ShapeConsistencyManager() rst_dict = shape_consistency_manager.get_all_shard_spec(sharding_spec, 0) print(rst_dict) Output: {DistSpec: shard_sequence: S01,R,R device_mesh_shape: (4, 4): 0, DistSpec: shard_sequence: S0,S1,R device_mesh_shape: (4, 4): 0, DistSpec: shard_sequence: S0,R,S1 device_mesh_shape: (4, 4): 0} ''' valid_spec_dict = {} comm_pattern = CollectiveCommPattern.SPLIT_FWD_GATHER_BWD # legal sharding dims means the mesh_id is still available to use. legal_sharding_dims = [i for i in range(len(source_spec.device_mesh.mesh_shape))] for dim, shard_list in source_spec.dim_partition_dict.items(): for element in shard_list: legal_sharding_dims.remove(element) if len(legal_sharding_dims) == 0: return valid_spec_dict tensor_dims = len(source_spec.entire_shape) for index in range(tensor_dims): if index not in source_spec.dim_partition_dict: shard_list_list = shard_simulator((index, []), legal_sharding_dims) else: shard_list_list = shard_simulator((index, source_spec.dim_partition_dict[index]), legal_sharding_dims) if not shard_list_list: continue for shard_list in shard_list_list: new_dim_partition_dict = deepcopy(source_spec.dim_partition_dict) new_dim_partition_dict[index] = shard_list # generate the CommSpec to record the action of source_sharding_spec->new_sharding_spec shard_dim = index logical_process_axis = shard_list[-1] comm_spec = CommSpec(comm_pattern, sharding_spec=source_spec, gather_dim=shard_dim, shard_dim=shard_dim, logical_process_axis=logical_process_axis, forward_only=self.forward_only) # compute the communication cost with CommSpec cost = comm_spec.get_comm_cost() # generate new sharding spec new_sharding_spec = ShardingSpec(source_spec.device_mesh, source_spec.entire_shape, dim_partition_dict=new_dim_partition_dict) valid_spec_dict[new_sharding_spec] = (comm_spec, orig_cost + cost) return valid_spec_dict def get_all_one_step_transform_spec(self, source_spec, orig_cost): ''' Get all valid sharding specs from source_spec with one step transform, and accumulate commucation cost on origin cost which will finally be used in auto sharding solver. Note: all-gather will eliminate a sharding dimension, all-to-all will keep sharding dimension same as before, and shard will add a sharding dimension. Therefore, the result of above operations are mutual exclusive, we could safely put them together. Argument: source_spec(ShardingSpec): the ShardingSpec of the source_spec. orig_cost(float): the original communication cost before this operation. Return: valid_spec_dict(Dict[ShardingSpec, float]): all valid sharding specs from source_spec with single all-to-all operation. ''' valid_spec_dict = {} valid_spec_dict.update(self.get_all_all_gather_spec(source_spec, orig_cost)) valid_spec_dict.update(self.get_all_all_to_all_spec(source_spec, orig_cost)) valid_spec_dict.update(self.get_all_shard_spec(source_spec, orig_cost)) return valid_spec_dict def shape_consistency(self, source_spec, target_spec): ''' This method will find a path to transform source_spec to target_spec with a greedy algorithm. The basic idea is: Step1: Generate all one-step transform sequences from source_spec. Step2: Pick the 'best' sharding spec following the heuristic function. Step3: Repeat above steps until the source spec transform to target spec. During finding the transform path, commucation cost will be accumulated, and it will be finally used in auto parallel solver. Additionally, to avoid repeating the path search in runtime, we cached all solved path in auto parallel strategy building time, which could handle most of cases in runtime. Argument: source_spec(ShardingSpec): ShardingSpec of the source activation. target_spec(ShardingSpec): ShardingSpec of the target activation. Return: transform_path(List[ShardingSpec]): The transform path from source_spec to target_spec, it contains the source_spec and target_spec. comm_action_sequence(List[CommSpec]): Keep the communication operations to complete the shape consistency in order. total_cost(float): total cost to complete shape consistency transform. Example: dim_partition_source = {1: [0, 1]} dim_partition_target = {0: [0, 1]} # DistSpec: # shard_sequence: R,S01,R # device_mesh_shape: (4, 4) sharding_spec_source = ShardingSpec(device_mesh, entire_shape, dim_partition_source) # DistSpec: # shard_sequence: S01,R,R # device_mesh_shape: (4, 4) sharding_spec_target = ShardingSpec(device_mesh, entire_shape, dim_partition_target) transform_path, comm_action_sequence, total_cost = shape_consistency_manager.shape_consistency(sharding_spec_source, sharding_spec_target) print(f'transform_path: {transform_path}') print(f'comm_action_sequence: {comm_action_sequence}') print(f'total_cost: {total_cost}') output: transform_path: [DistSpec: shard_sequence: R,S01,R device_mesh_shape: (4, 4), DistSpec: shard_sequence: R,S0,R device_mesh_shape: (4, 4), DistSpec: shard_sequence: S0,R,R device_mesh_shape: (4, 4), DistSpec: shard_sequence: S01,R,R device_mesh_shape: (4, 4)] comm_action_sequence: [CommSpec:(comm_pattern:allgather, gather_dim:1, logical_process_axis:1), CommSpec:(comm_pattern:all2all, gather_dim:1, shard_dim:0, logical_process_axis: 0), CommSpec:(comm_pattern:shard, shard_dim:0, logical_process_axis:1)] total_cost: 12294.402000000002 ''' MAX_TRANSFORM_STEPS = 20 total_cost = 0 total_steps = 0 transform_path = [] comm_action_sequence = [] spec_pairs = (str(source_spec.sharding_sequence), str(target_spec.sharding_sequence)) self.cached_spec_pairs_transform_path[spec_pairs] = (None, None) # We do nothing if the sharding spec is all the same. if source_spec.sharding_sequence_difference(target_spec) == 0: self.cached_spec_pairs_transform_path[spec_pairs] = (transform_path, comm_action_sequence) return (transform_path, comm_action_sequence, total_cost) temp_sharding_spec = source_spec transform_path.append(temp_sharding_spec) # To avoid dead loop, the loop will break after MAX_TRANSFORM_STEPS transforms while total_steps <= MAX_TRANSFORM_STEPS: valid_transform_spec_dict = self.get_all_one_step_transform_spec(temp_sharding_spec, total_cost) best_difference_score = math.inf for sharding_spec, info_pairs in valid_transform_spec_dict.items(): comm_spec, cost = info_pairs spec_difference = sharding_spec.sharding_sequence_difference(target_spec) if spec_difference == 0: total_cost += cost transform_path.append(sharding_spec) comm_action_sequence.append(comm_spec) self.cached_spec_pairs_transform_path[spec_pairs] = (transform_path, comm_action_sequence) return (transform_path, comm_action_sequence, total_cost) if spec_difference < best_difference_score: temp_sharding_spec = sharding_spec temp_cost = cost temp_comm_spec = comm_spec best_difference_score = spec_difference transform_path.append(temp_sharding_spec) comm_action_sequence.append(temp_comm_spec) total_cost += temp_cost total_steps += 1 raise RuntimeError(f"Could not find a valid transform path with in {MAX_TRANSFORM_STEPS} steps.") def apply(self, tensor_with_sharding_spec, target_spec): ''' Apply target_spec to tensor with source sharding spec, the transform path is generated by the shape_consistency method. Argument: tensor_with_sharding_spec (torch.Tensor): a tensor with source sharding spec to be transformed to the target spec. target_spec (ShardingSpec): The tensor transform processes will be directed by the target_spec. Example: physical_mesh_id = torch.arange(0, 4) mesh_shape = (2, 2) # [[0, 1, # [2, 3]] device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True) entire_shape = torch.Size((4, 2)) shape_consistency_manager = ShapeConsistencyManager() dim_partition_source = {0: [0]} dim_partition_target = {1: [0]} # DistSpec: # shard_sequence: S0,R # device_mesh_shape: (2, 2) sharding_spec_source = ShardingSpec(device_mesh, entire_shape, dim_partition_source) # DistSpec: # shard_sequence: R,S0 # device_mesh_shape: (2, 2) sharding_spec_target = ShardingSpec(device_mesh, entire_shape, dim_partition_target) if rank in (0, 1): sharded_tensor_0 = torch.zeros(2, 1) sharded_tensor_1 = torch.ones(2, 1) # tensor([[0., 1.], # [0., 1.]]) tensor_to_comm = torch.cat((sharded_tensor_0, sharded_tensor_1), 1).cuda() if rank in (2, 3): sharded_tensor_0 = torch.ones(2, 1) * 2 sharded_tensor_1 = torch.ones(2, 1) * 3 # tensor([[2., 3.], # [2., 3.]]) tensor_to_comm = torch.cat((sharded_tensor_0, sharded_tensor_1), 1).cuda() tensor_to_comm.sharding_spec = sharding_spec_source shape_consistency_manager.apply(tensor_to_comm, sharding_spec_target) print(tensor_to_comm) Output in rank0 and rank2: tensor([[0.], [0.], [2.], [2.]]) Output in rank1 and rank3: tensor([[1.], [1.], [3.], [3.]]) ''' _, comm_action_sequence, _ = self.shape_consistency(tensor_with_sharding_spec.sharding_spec, target_spec) for comm_spec in comm_action_sequence: comm_spec.covert_spec_to_action(tensor_with_sharding_spec) tensor_with_sharding_spec.sharding_spec = target_spec