import operator
from functools import reduce
from typing import List

from colossalai.tensor.shape_consistency import CollectiveCommPattern

from ..sharding_strategy import MemoryCost, ShardingStrategy_V2, TrainCycleItem
from .strategy_generator import StrategyGenerator_V2


class MatMulStrategyGenerator(StrategyGenerator_V2):
    """
    MatMulStrategyGenerator is a generic class to cover all matrix multiplication cases.
    The operation data is defined as `output = input x other + bias`.
    """

    @property
    def has_bias(self):
        return 'bias' in self.op_data

    def update_memory_cost(self, strategy: ShardingStrategy_V2) -> ShardingStrategy_V2:
        size_mapping = {
            'input': self._compute_size_in_bytes(strategy, "input"),
            'other': self._compute_size_in_bytes(strategy, "other"),
            'output': self._compute_size_in_bytes(strategy, "output")
        }

        if self.has_bias:
            bias_size = self._compute_size_in_bytes(strategy, "bias")
            size_mapping['bias'] = bias_size

        # compute fwd cost incurred
        # fwd_cost = input + other + bias + output
        fwd_activation_cost = sum([v for k, v in size_mapping.items() if not self.is_param(k)])
        fwd_parameter_cost = sum([v for k, v in size_mapping.items() if self.is_param(k)])
        fwd_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=fwd_parameter_cost)

        # compute bwd cost incurred
        # bwd_cost = input_grad + other_grad + bias_grad
        bwd_activation_cost = sum([v for k, v in size_mapping.items() if k in ['input', 'other', 'bias']])
        bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=0)

        # compute total cost
        total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost,
                                    parameter=fwd_parameter_cost)
        memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
        strategy.memory_cost = memory_cost


class DotProductStrategyGenerator(MatMulStrategyGenerator):

    def validate(self) -> bool:
        input_op_data = self.op_data['input']
        other_op_data = self.op_data['other']
        assert input_op_data.data.dim() == 1 and other_op_data.data.dim() == 1

    def update_compute_cost(self, strategy: ShardingStrategy_V2) -> ShardingStrategy_V2:
        sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device()
        fwd_compute_cost = sharded_input_shape[0]
        bwd_compute_cost = fwd_compute_cost * 2
        compute_cost = TrainCycleItem(fwd=fwd_compute_cost,
                                      bwd=bwd_compute_cost,
                                      total=fwd_compute_cost + bwd_compute_cost)
        strategy.compute_cost = compute_cost

    def no_split(self):
        name = 'R = R dot R'
        dim_partition_dict = {"input": {}, "other": {}, "output": {}, 'bias': {}}
        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)
        communication_action_mapping = {}
        return self.get_sharding_strategy(name=name,
                                          sharding_spec_mapping=sharding_spec_mapping,
                                          communication_action_mapping=communication_action_mapping)

    def split_one_dim(self, mesh_dim):
        name = f'R = S{mesh_dim} dot S{mesh_dim}'

        # get sharding spec
        dim_partition_dict = {"input": {0: [mesh_dim]}, "other": {0: [mesh_dim]}, "output": {}, "bias": {0: [mesh_dim]}}
        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)

        # get communication action
        output_comm_spec = self.get_communication_spec(
            sharding_spec=sharding_spec_mapping['output'],
            communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
            logical_process_axis=mesh_dim)
        communication_action_mapping = {"output": output_comm_spec}
        return self.get_sharding_strategy(name=name,
                                          sharding_spec_mapping=sharding_spec_mapping,
                                          communication_action_mapping=communication_action_mapping)

    def generate(self) -> List[ShardingStrategy_V2]:
        strategy_list = []

        # do not split dimensions for dot product
        # R = R dot R
        strategy_list.append(self.no_split())

        # split both tensors along the same mesh axis
        # R = S dot S
        strategy_list.append(self.split_one_dim(0))
        strategy_list.append(self.split_one_dim(1))

        return strategy_list
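
# Note on the CollectiveCommPattern values used by the generators in this file:
# ALLREDUCE_FWD_IDENTITY_BWD is attached to a tensor (typically the output)
# that holds partial sums because the contracted dimension is sharded, so it
# must be all-reduced in the forward pass and needs no backward communication.
# IDENTITY_FWD_ALLREDUCE_BWD is attached to a tensor (typically an input or
# the weight) that needs no forward communication but whose gradient has to be
# all-reduced across the given mesh axis in the backward pass.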


class MatVecStrategyGenerator(MatMulStrategyGenerator):

    def validate(self) -> bool:
        input_op_data = self.op_data['input']
        other_op_data = self.op_data['other']
        assert input_op_data.data.dim() > 1 and other_op_data.data.dim() == 1

    def no_split(self):
        name = "R = R x R"
        dim_partition_dict = {"input": {}, "other": {}, "output": {}, "bias": {}}
        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)
        return self.get_sharding_strategy(name=name,
                                          sharding_spec_mapping=sharding_spec_mapping,
                                          communication_action_mapping={})

    def split_input_batch(self, mesh_dim):
        name = f'S{mesh_dim}R = S{mesh_dim}R x R'

        # get sharding spec
        dim_partition_dict = {"input": {0: [mesh_dim]}, "other": {}, "output": {0: [mesh_dim]}}
        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)

        # get communication action
        other_comm_spec = self.get_communication_spec(
            sharding_spec=sharding_spec_mapping['other'],
            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
            logical_process_axis=mesh_dim)
        communication_action_mapping = {'other': other_comm_spec}
        return self.get_sharding_strategy(name=name,
                                          sharding_spec_mapping=sharding_spec_mapping,
                                          communication_action_mapping=communication_action_mapping)

    def generate(self) -> List[ShardingStrategy_V2]:
        strategy_list = []

        # no split
        strategy_list.append(self.no_split())

        # split the batch dim of the input tensor only, along either mesh axis
        strategy_list.append(self.split_input_batch(0))
        strategy_list.append(self.split_input_batch(1))

        return strategy_list
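
# The strategy names below read as 'output = input x other' sharding layouts.
# For example, 'S0S1 = S0R x RS1' means the output is sharded along mesh axis 0
# in its row dimension and along mesh axis 1 in its column dimension, the input
# is sharded along axis 0 in its row dimension, the other (weight) matrix is
# sharded along axis 1 in its column dimension, and 'R' marks a replicated
# dimension.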


class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):

    def update_compute_cost(self, strategy: ShardingStrategy_V2) -> ShardingStrategy_V2:
        # C = AB
        # C: [M, N], A: [M, P], B: [P, N]
        # fwd cost = MNP (only count mul)
        # bwd: 2 x fwd_cost
        sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device()
        sharded_other_shape = strategy.sharding_specs[self.op_data['other']].get_sharded_shape_per_device()
        dim_m_val = reduce(operator.mul, sharded_input_shape[:-1])
        dim_n_val = sharded_other_shape[-1]
        dim_p_val = sharded_other_shape[0]

        fwd_compute_cost = dim_m_val * dim_n_val * dim_p_val
        bwd_compute_cost = fwd_compute_cost * 2
        compute_cost = TrainCycleItem(fwd=fwd_compute_cost,
                                      bwd=bwd_compute_cost,
                                      total=fwd_compute_cost + bwd_compute_cost)
        strategy.compute_cost = compute_cost

    def generate(self) -> List[ShardingStrategy_V2]:
        strategies = []

        # SS = SR x RS
        strategies.append(self.split_lhs_space_rhs_space(0, 1))
        strategies.append(self.split_lhs_space_rhs_space(1, 0))

        # SR = SS x SR
        strategies.append(self.split_lhs_space_both_contract(0, 1))
        strategies.append(self.split_lhs_space_both_contract(1, 0))

        # RS = RS x SS
        strategies.append(self.split_rhs_space_both_contract(0, 1))
        strategies.append(self.split_rhs_space_both_contract(1, 0))

        # RR = RS x SR
        strategies.append(self.recompute_split_both_contract(0))
        strategies.append(self.recompute_split_both_contract(1))

        # RS = RR x RS
        strategies.append(self.split_rhs_space_only(0))
        strategies.append(self.split_rhs_space_only(1))

        # S01R = S01R x RR
        strategies.append(self.split_lhs_1st_dim_1d(0, 1))

        # RR = RS01 x S01R
        strategies.append(self.split_lhs_2nd_dim_1d(0, 1))

        # RS01 = RR x RS01
        strategies.append(self.split_rhs_2nd_dim_1d(0, 1))

        # update meta info on cost
        for strategy in strategies:
            self.update_communication_cost(strategy)
            self.update_compute_cost(strategy)
            self.update_memory_cost(strategy)

        return strategies

    def split_lhs_space_rhs_space(self, mesh_dim_0, mesh_dim_1):
        # handle case SS = SR x RS
        name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}'

        dim_partition_dict_mapping = {
            "input": {
                0: [mesh_dim_0]
            },
            "other": {
                -1: [mesh_dim_1]
            },
            "bias": {
                -1: [mesh_dim_1]
            },
            "output": {
                0: [mesh_dim_0],
                -1: [mesh_dim_1]
            },
        }
        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)

        # set communication action
        input_comm_spec = self.get_communication_spec(
            sharding_spec=sharding_spec_mapping["input"],
            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
            logical_process_axis=mesh_dim_1)
        other_comm_spec = self.get_communication_spec(
            sharding_spec=sharding_spec_mapping["other"],
            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
            logical_process_axis=mesh_dim_0)
        communication_action_mapping = {"input": input_comm_spec, "other": other_comm_spec}
        return self.get_sharding_strategy(name=name,
                                          sharding_spec_mapping=sharding_spec_mapping,
                                          communication_action_mapping=communication_action_mapping)

    def split_lhs_space_both_contract(self, mesh_dim_0, mesh_dim_1):
        # handle the case SR = SS x SR
        name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R'

        # get sharding spec mapping
        dim_partition_dict_mapping = {
            "input": {
                0: [mesh_dim_0],
                -1: [mesh_dim_1]
            },
            "other": {
                0: [mesh_dim_1]
            },
            "bias": {},
            "output": {
                0: [mesh_dim_0]
            },
        }
        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)

        # get communication action mapping
        input_comm_spec = self.get_communication_spec(
            sharding_spec=sharding_spec_mapping["input"],
            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
            logical_process_axis=mesh_dim_0)
        output_comm_spec = self.get_communication_spec(
            sharding_spec=sharding_spec_mapping["output"],
            communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
            logical_process_axis=mesh_dim_1)
        communication_action_mapping = {"input": input_comm_spec, 'output': output_comm_spec}
        return self.get_sharding_strategy(name=name,
                                          sharding_spec_mapping=sharding_spec_mapping,
                                          communication_action_mapping=communication_action_mapping)

    def split_rhs_space_both_contract(self, mesh_dim_0, mesh_dim_1):
        name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}'

        # get sharding specs
        dim_partition_dict_mapping = {
            "input": {
                -1: [mesh_dim_0]
            },
            "other": {
                0: [mesh_dim_0],
                -1: [mesh_dim_1]
            },
            "bias": {
                -1: [mesh_dim_1]
            },
            "output": {
                -1: [mesh_dim_1]
            },
        }
        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)

        # get communication actions
        output_comm_spec = self.get_communication_spec(
            sharding_spec=sharding_spec_mapping['output'],
            communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
            logical_process_axis=mesh_dim_0)
        input_comm_spec = self.get_communication_spec(
            sharding_spec=sharding_spec_mapping['input'],
            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
            logical_process_axis=mesh_dim_1)
        communication_action_mapping = {"output": output_comm_spec, "input": input_comm_spec}
        return self.get_sharding_strategy(name=name,
                                          sharding_spec_mapping=sharding_spec_mapping,
                                          communication_action_mapping=communication_action_mapping)

    def recompute_split_both_contract(self, mesh_dim):
        name = f'RR = RS{mesh_dim} x S{mesh_dim}R'

        # get sharding spec
        dim_partition_dict_mapping = {
            "input": {
                -1: [mesh_dim]
            },
            "other": {
                0: [mesh_dim]
            },
            "bias": {},
            "output": {},
        }
        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)

        # get communication action
        output_comm_spec = self.get_communication_spec(
            sharding_spec=sharding_spec_mapping['output'],
            communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
            logical_process_axis=mesh_dim)
        communication_action_mapping = {'output': output_comm_spec}
        return self.get_sharding_strategy(name=name,
                                          sharding_spec_mapping=sharding_spec_mapping,
                                          communication_action_mapping=communication_action_mapping)

    def split_rhs_space_only(self, mesh_dim):
        name = f'RS{mesh_dim} = RR x RS{mesh_dim}'

        # get sharding spec
        dim_partition_dict_mapping = {
            "input": {},
            "other": {
                -1: [mesh_dim]
            },
            "bias": {
                -1: [mesh_dim]
            },
            "output": {
                -1: [mesh_dim]
            },
        }
        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)

        # get communication actions
        input_comm_spec = self.get_communication_spec(
            sharding_spec=sharding_spec_mapping['input'],
            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
            logical_process_axis=mesh_dim)
        communication_action_mapping = {'input': input_comm_spec}
        return self.get_sharding_strategy(name=name,
                                          sharding_spec_mapping=sharding_spec_mapping,
                                          communication_action_mapping=communication_action_mapping)
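
    # In the three strategies below, 'S01' denotes a dimension sharded across
    # both mesh axes flattened into a single 1D process group, e.g. a (2, 4)
    # device mesh acts as 8 shards for that dimension.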

    def split_lhs_1st_dim_1d(self, mesh_dim_0, mesh_dim_1):
        name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR'

        # get sharding spec
        dim_partition_dict_mapping = {
            "input": {
                0: [mesh_dim_0, mesh_dim_1]
            },
            "other": {},
            "bias": {},
            "output": {
                0: [mesh_dim_0, mesh_dim_1]
            },
        }
        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)

        # get communication action
        other_comm_spec = self.get_communication_spec(
            sharding_spec=sharding_spec_mapping['other'],
            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
            logical_process_axis=[mesh_dim_0, mesh_dim_1])
        communication_action_mapping = {"other": other_comm_spec}
        return self.get_sharding_strategy(name=name,
                                          sharding_spec_mapping=sharding_spec_mapping,
                                          communication_action_mapping=communication_action_mapping)

    def split_lhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1):
        name = f'RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R'

        # get sharding spec
        dim_partition_dict_mapping = {
            "input": {
                -1: [mesh_dim_0, mesh_dim_1]
            },
            "other": {
                0: [mesh_dim_0, mesh_dim_1]
            },
            "bias": {},
            "output": {},
        }
        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)

        # get communication action
        output_comm_spec = self.get_communication_spec(
            sharding_spec=sharding_spec_mapping['output'],
            communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
            logical_process_axis=[mesh_dim_0, mesh_dim_1])
        communication_action_mapping = {'output': output_comm_spec}
        return self.get_sharding_strategy(name=name,
                                          sharding_spec_mapping=sharding_spec_mapping,
                                          communication_action_mapping=communication_action_mapping)

    def split_rhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1):
        name = f'RS{mesh_dim_0}{mesh_dim_1} = RR x RS{mesh_dim_0}{mesh_dim_1}'

        # get sharding spec
        dim_partition_dict_mapping = {
            "input": {},
            "other": {
                -1: [mesh_dim_0, mesh_dim_1]
            },
            "bias": {
                -1: [mesh_dim_0, mesh_dim_1]
            },
            "output": {
                -1: [mesh_dim_0, mesh_dim_1]
            },
        }
        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)

        # get communication action
        input_comm_spec = self.get_communication_spec(
            sharding_spec=sharding_spec_mapping['input'],
            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
            logical_process_axis=[mesh_dim_0, mesh_dim_1])
        communication_action_mapping = {'input': input_comm_spec}
        return self.get_sharding_strategy(name=name,
                                          sharding_spec_mapping=sharding_spec_mapping,
                                          communication_action_mapping=communication_action_mapping)

    def validate(self) -> bool:
        assert "input" in self.op_data
        assert "other" in self.op_data

        # make sure the other tensor is a 2D matrix
        input_data = self.op_data['input']
        other_data = self.op_data['other']
        assert input_data.data.dim() > 0 and other_data.data.dim() == 2
        assert other_data.logical_shape[0] == input_data.logical_shape[-1]

        # check if the bias has a valid dim
        has_bias = "bias" in self.op_data
        if has_bias:
            bias_data = self.op_data['bias']
            assert bias_data.logical_shape[-1] == other_data.logical_shape[-1]


class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
    """
    Generate sharding strategies for the batched matrix multiplication.

    A batched matrix multiplication can be viewed as
    [b, i, k] x [b, k, j] -> [b, i, j]
    """

    def validate(self) -> bool:
        input_op_data = self.op_data['input']
        other_op_data = self.op_data['other']
        assert input_op_data.data.dim() > 2 or other_op_data.data.dim() > 2

    def update_compute_cost(self, strategy: ShardingStrategy_V2) -> ShardingStrategy_V2:
        # fwd cost = b * i * k * j (only count mul), i.e. the contracted dim
        # times the number of output elements; bwd is counted as twice fwd,
        # consistent with the other generators in this file
        fwd_compute_cost = self.op_data['input'].data.shape[-1] * reduce(operator.mul,
                                                                         self.op_data['output'].data.shape)
        bwd_compute_cost = fwd_compute_cost * 2
        compute_cost = TrainCycleItem(fwd=fwd_compute_cost,
                                      bwd=bwd_compute_cost,
                                      total=fwd_compute_cost + bwd_compute_cost)
        strategy.compute_cost = compute_cost
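
    # Naming convention for the strategies below, following the
    # [b, i, k] x [b, k, j] -> [b, i, j] view in the class docstring:
    #   Sb - shard the batch dimension b
    #   Si - shard the i dimension of the first operand and the output
    #   Sj - shard the j dimension of the second operand and the output
    #   Sk - shard the contracted dimension k, which leaves partial sums that
    #        must be all-reduced on the output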

    def split_one_batch_dim(self):
        device_mesh_is_1d = True
        if len(self.device_mesh.mesh_shape) == 1:
            mesh_dim = 0
        elif 1 in self.device_mesh.mesh_shape:
            mesh_dim = self.device_mesh.mesh_shape.index(1)
        else:
            device_mesh_is_1d = False

        if device_mesh_is_1d:
            name = f'Sb{mesh_dim} = Sb{mesh_dim} x Sb{mesh_dim}'

            # get sharding spec
            dim_partition_dict = {
                "input": {
                    0: [mesh_dim]
                },
                "other": {
                    0: [mesh_dim]
                },
                "bias": {},
                "output": {
                    0: [mesh_dim]
                }
            }
            sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)

            # get communication actions
            communication_action_mapping = {}
            return self.get_sharding_strategy(name=name,
                                              sharding_spec_mapping=sharding_spec_mapping,
                                              communication_action_mapping=communication_action_mapping)
        else:
            return None

    def split_two_batch_dim(self, mesh_dim_0, mesh_dim_1):
        name = f'Sb{mesh_dim_0}{mesh_dim_1} = Sb{mesh_dim_0}{mesh_dim_1} x Sb{mesh_dim_0}{mesh_dim_1}'
        dim_partition_dict = {
            "input": {
                0: [mesh_dim_0, mesh_dim_1]
            },
            "other": {
                0: [mesh_dim_0, mesh_dim_1]
            },
            "bias": {},
            "output": {
                0: [mesh_dim_0, mesh_dim_1]
            }
        }
        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)

        # get communication actions
        communication_action_mapping = {}
        return self.get_sharding_strategy(name=name,
                                          sharding_spec_mapping=sharding_spec_mapping,
                                          communication_action_mapping=communication_action_mapping)

    def split_batch_dim_lhs_space(self, mesh_dim_0, mesh_dim_1):
        name = f'Sb{mesh_dim_0}Si{mesh_dim_1} = Sb{mesh_dim_0}Si{mesh_dim_1} x Sb{mesh_dim_0}'
        dim_partition_dict = {
            "input": {
                0: [mesh_dim_0],
                -2: [mesh_dim_1]
            },
            "other": {
                0: [mesh_dim_0]
            },
            "bias": {},
            "output": {
                0: [mesh_dim_0],
                -2: [mesh_dim_1]
            }
        }
        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)

        # get communication actions
        other_comm_spec = self.get_communication_spec(
            sharding_spec=sharding_spec_mapping['other'],
            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
            logical_process_axis=mesh_dim_1)
        communication_action_mapping = {'other': other_comm_spec}
        return self.get_sharding_strategy(name=name,
                                          sharding_spec_mapping=sharding_spec_mapping,
                                          communication_action_mapping=communication_action_mapping)

    def split_batch_dim_rhs_space(self, mesh_dim_0, mesh_dim_1):
        name = f'Sb{mesh_dim_0}Sj{mesh_dim_1} = Sb{mesh_dim_0}R x Sb{mesh_dim_0}Sj{mesh_dim_1}'
        dim_partition_dict = {
            "input": {
                0: [mesh_dim_0]
            },
            "other": {
                0: [mesh_dim_0],
                -1: [mesh_dim_1]
            },
            "bias": {
                -1: [mesh_dim_1]
            },
            "output": {
                0: [mesh_dim_0],
                -1: [mesh_dim_1]
            }
        }
        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)

        # get communication actions
        input_comm_spec = self.get_communication_spec(
            sharding_spec=sharding_spec_mapping['input'],
            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
            logical_process_axis=mesh_dim_1)
        communication_action_mapping = {'input': input_comm_spec}
        return self.get_sharding_strategy(name=name,
                                          sharding_spec_mapping=sharding_spec_mapping,
                                          communication_action_mapping=communication_action_mapping)

    def split_batch_dim_both_contract(self, mesh_dim_0, mesh_dim_1):
        name = f'Sb{mesh_dim_0}R = Sb{mesh_dim_0}Sk{mesh_dim_1} x Sb{mesh_dim_0}Sk{mesh_dim_1}'
        dim_partition_dict = {
            "input": {
                0: [mesh_dim_0],
                -1: [mesh_dim_1]
            },
            "other": {
                0: [mesh_dim_0],
                -2: [mesh_dim_1]
            },
            "bias": {},
            "output": {
                0: [mesh_dim_0]
            }
        }
        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)

        # get communication actions
        output_comm_spec = self.get_communication_spec(
            sharding_spec=sharding_spec_mapping['output'],
            communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
            logical_process_axis=mesh_dim_1)
        communication_action_mapping = {'output': output_comm_spec}
        return self.get_sharding_strategy(name=name,
                                          sharding_spec_mapping=sharding_spec_mapping,
                                          communication_action_mapping=communication_action_mapping)

    def generate(self) -> List[ShardingStrategy_V2]:
        strategy_list = []

        # split only the batch dimension
        # Sb = Sb x Sb
        # can be None as it is only for 1D device mesh
        strategy = self.split_one_batch_dim()
        if strategy:
            # only for 1D device mesh
            strategy_list.append(strategy)
        else:
            # for 2D device mesh
            # split the batch dim of both inputs and the i dim of the first tensor
            # SbSi = SbSi x Sb
            strategy_list.append(self.split_batch_dim_lhs_space(0, 1))
            strategy_list.append(self.split_batch_dim_lhs_space(1, 0))

            # split the batch dim of both inputs and the j dim of the second tensor
            # SbSj = Sb x SbSj
            strategy_list.append(self.split_batch_dim_rhs_space(0, 1))
            strategy_list.append(self.split_batch_dim_rhs_space(1, 0))

            # split the batch dim of both inputs and the contracted k dim
            # Sb = SbSk x SbSk, the output needs an all-reduce over the k mesh axis
            strategy_list.append(self.split_batch_dim_both_contract(0, 1))
            strategy_list.append(self.split_batch_dim_both_contract(1, 0))

            # split the batch dim across both mesh axes
            strategy_list.append(self.split_two_batch_dim(0, 1))
            strategy_list.append(self.split_two_batch_dim(1, 0))

        return strategy_list
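

# Usage sketch (hypothetical wiring; the real call sites live in the operator
# handlers that consume these generators). A handler owning the operation data
# and the device mesh is expected to instantiate the generator matching its
# case and call `generate()`, then let the solver choose among the returned
# strategies based on the attached compute, communication and memory costs:
#
#     generator = LinearProjectionStrategyGenerator(operation_data_mapping, device_mesh)  # args are assumed
#     strategies = generator.generate()
#
# The constructor arguments shown above are an assumption for illustration;
# see StrategyGenerator_V2 for the actual interface.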