From 247a9dbca96672306bfb2a13467c161453d8c33b Mon Sep 17 00:00:00 2001
From: Frank Lee
Date: Thu, 29 Sep 2022 11:08:05 +0800
Subject: [PATCH] [autoparallel] added bias comm spec to matmul strategy
 (#1664)

---
 .../strategy/matmul_strategy_generator.py     | 115 +++++++++++-------
 1 file changed, 68 insertions(+), 47 deletions(-)

diff --git a/colossalai/auto_parallel/solver/strategy/matmul_strategy_generator.py b/colossalai/auto_parallel/solver/strategy/matmul_strategy_generator.py
index b64152d0e..a5a9ec58a 100644
--- a/colossalai/auto_parallel/solver/strategy/matmul_strategy_generator.py
+++ b/colossalai/auto_parallel/solver/strategy/matmul_strategy_generator.py
@@ -1,3 +1,4 @@
+from audioop import bias
 import operator
 from functools import reduce
 from ..sharding_strategy import ShardingStrategy_V2, TrainCycleItem, MemoryCost
@@ -121,7 +122,7 @@ class MatVecStrategyGenerator(MatMulStrategyGenerator):
         name = f'S{mesh_dim}R = S{mesh_dim}R x R'
 
         # get sharding spec
-        dim_partition_dict = {"input": {0: [mesh_dim]}, "other": {}, "output": {0: [mesh_dim]}}
+        dim_partition_dict = {"input": {0: [mesh_dim]}, "other": {}, "output": {0: [mesh_dim]}, "bias": {}}
         sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)
 
         # get communication action
@@ -129,7 +130,11 @@ class MatVecStrategyGenerator(MatMulStrategyGenerator):
             sharding_spec=sharding_spec_mapping['other'],
             communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
             logical_process_axis=mesh_dim)
-        communication_action_mapping = {'other': other_comm_spec}
+        bias_comm_spec = self.get_communication_spec(
+            sharding_spec=sharding_spec_mapping['bias'],
+            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+            logical_process_axis=mesh_dim)
+        communication_action_mapping = {'other': other_comm_spec, 'bias': bias_comm_spec}
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
                                           communication_action_mapping=communication_action_mapping)
@@ -236,8 +241,12 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
             sharding_spec_mapping["output"],
             communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
             logical_process_axis=mesh_dim_0)
+        bias_comm_spec = self.get_communication_spec(
+            sharding_spec_mapping["bias"],
+            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+            logical_process_axis=mesh_dim_0)
 
-        communication_action_mapping = {"input": input_comm_spec, "other": other_comm_spec}
+        communication_action_mapping = {"input": input_comm_spec, "other": other_comm_spec, "bias": bias_comm_spec}
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
                                           communication_action_mapping=communication_action_mapping)
@@ -272,8 +281,12 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
             sharding_spec=sharding_spec_mapping["output"],
             communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
             logical_process_axis=mesh_dim_1)
+        bias_comm_spec = self.get_communication_spec(
+            sharding_spec=sharding_spec_mapping["bias"],
+            communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
+            logical_process_axis=mesh_dim_1)
 
-        communication_action_mapping = {"input": input_comm_spec, 'output': output_comm_spec}
+        communication_action_mapping = {"input": input_comm_spec, 'output': output_comm_spec, 'bias': bias_comm_spec}
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
                                           communication_action_mapping=communication_action_mapping)
@@ -390,8 +403,12 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
             sharding_spec=sharding_spec_mapping['other'],
             communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
             logical_process_axis=[mesh_dim_0, mesh_dim_1])
+        bias_comm_spec = self.get_communication_spec(
+            sharding_spec=sharding_spec_mapping['bias'],
+            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+            logical_process_axis=[mesh_dim_0, mesh_dim_1])
 
-        communcation_action_mapping = {"other": other_comm_spec}
+        communcation_action_mapping = {"other": other_comm_spec, "bias": bias_comm_spec}
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
                                           communication_action_mapping=communcation_action_mapping)
@@ -486,40 +503,22 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
     def update_compute_cost(self, strategy: ShardingStrategy_V2) -> ShardingStrategy_V2:
         return self.op_data['input'].data.shape[-1] * reduce(operator.mul, self.op_data['output'].data.shape)
 
-    def split_one_batch_dim(self):
-        device_mesh_is_1d = True
-        if len(self.device_mesh.mesh_shape) == 1:
-            mesh_dim = 0
-        elif 1 in self.device_mesh.mesh_shape:
-            mesh_dim = self.device_mesh.mesh_shape.index(1)
-        else:
-            device_mesh_is_1d = False
+    def split_one_batch_dim(self, mesh_dim):
+        name = f'Sb{mesh_dim} = Sb{mesh_dim} x Sb{mesh_dim}'
 
-        if device_mesh_is_1d:
-            name = f'Sb{mesh_dim} = Sb{mesh_dim} x Sb{mesh_dim}'
+        # get sharding_spec
+        dim_partition_dict = {"input": {0: [mesh_dim]}, "other": {0: [mesh_dim]}, "bias": {}, "output": {0: [mesh_dim]}}
+        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)
 
-            # get sharding_spec
-            dim_partition_dict = {
-                "input": {
-                    0: [mesh_dim]
-                },
-                "other": {
-                    0: [mesh_dim]
-                },
-                "bias": {},
-                "output": {
-                    0: [mesh_dim]
-                }
-            }
-            sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)
-
-            # get communication actions
-            communication_action_mapping = {}
-            return self.get_sharding_strategy(name=name,
-                                              sharding_spec_mapping=sharding_spec_mapping,
-                                              communication_action_mapping=communication_action_mapping)
-        else:
-            return None
+        # get communication actions
+        bias_comm_spec = self.get_communication_spec(
+            sharding_spec=sharding_spec_mapping['bias'],
+            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+            logical_process_axis=mesh_dim)
+        communication_action_mapping = {"bias": bias_comm_spec}
+        return self.get_sharding_strategy(name=name,
+                                          sharding_spec_mapping=sharding_spec_mapping,
+                                          communication_action_mapping=communication_action_mapping)
 
     def split_two_batch_dim(self, mesh_dim_0, mesh_dim_1):
         name = f'Sb{mesh_dim_0}{mesh_dim_1} = Sb{mesh_dim_0}{mesh_dim_1} x Sb{mesh_dim_0}{mesh_dim_1}'
@@ -538,7 +537,11 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
         sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)
 
         # get communication actions
-        communication_action_mapping = {}
+        bias_comm_spec = self.get_communication_spec(
+            sharding_spec=sharding_spec_mapping['bias'],
+            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+            logical_process_axis=[mesh_dim_0, mesh_dim_1])
+        communication_action_mapping = {"bias": bias_comm_spec}
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
                                           communication_action_mapping=communication_action_mapping)
@@ -566,7 +569,11 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
             sharding_spec=sharding_spec_mapping['other'],
             communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
             logical_process_axis=mesh_dim_1)
-        communication_action_mapping = {'other': other_comm_spec}
+        bias_comm_spec = self.get_communication_spec(
+            sharding_spec=sharding_spec_mapping['bias'],
+            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+            logical_process_axis=[mesh_dim_0, mesh_dim_1])
+        communication_action_mapping = {'other': other_comm_spec, 'bias': bias_comm_spec}
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
                                           communication_action_mapping=communication_action_mapping)
@@ -596,7 +603,11 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
             sharding_spec=sharding_spec_mapping['input'],
             communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
             logical_process_axis=mesh_dim_1)
-        communication_action_mapping = {'input': input_comm_spec}
+        bias_comm_spec = self.get_communication_spec(
+            sharding_spec=sharding_spec_mapping['bias'],
+            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+            logical_process_axis=mesh_dim_0)
+        communication_action_mapping = {'input': input_comm_spec, 'bias': bias_comm_spec}
        return self.get_sharding_strategy(name=name,
                                          sharding_spec_mapping=sharding_spec_mapping,
                                          communication_action_mapping=communication_action_mapping)
@@ -625,21 +636,31 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
             sharding_spec=sharding_spec_mapping['output'],
             communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
             logical_process_axis=mesh_dim_1)
-        communication_action_mapping = {'output': output_comm_spec}
+        bias_comm_spec = self.get_communication_spec(
+            sharding_spec=sharding_spec_mapping['bias'],
+            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+            logical_process_axis=mesh_dim_0)
+        communication_action_mapping = {'output': output_comm_spec, 'bias': bias_comm_spec}
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
                                           communication_action_mapping=communication_action_mapping)
 
     def generate(self) -> List[ShardingStrategy_V2]:
         strategy_list = []
+        device_mesh_is_1d = True
+        if len(self.device_mesh.mesh_shape) == 2 and 1 not in self.device_mesh.mesh_shape:
+            device_mesh_is_1d = False
 
-        # split only the batch dimension
-        # Sb = Sb x Sb
-        # can be None as it is only for 1D device mesh
-        strategy = self.split_one_batch_dim()
-        if strategy:
+        if device_mesh_is_1d:
+            # split only the batch dimension
+            # Sb = Sb x Sb
+            # can be None as it is only for 1D device mesh
             # only for 1D device mesh
-            strategy_list.append(strategy)
+            if len(self.device_mesh.mesh_shape) == 1:
+                mesh_dim = 0
+            else:
+                mesh_dim = self.device_mesh.mesh_shape.index(1)
+            strategy_list.append(self.split_one_batch_dim(mesh_dim))
         else:
             # for 2D device mesh
            # split batch dim of two inputs and the i dim of the first tensor
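
Reviewer notes (illustrative, not part of the patch):

The `from audioop import bias` line added in the first hunk appears to be an accidental IDE auto-import: the stdlib `audioop` module handles audio fragments, and the imported name is never used in this file.

On the communication pattern: `IDENTITY_FWD_ALLREDUCE_BWD` is a no-op in the forward pass and an all-reduce of the gradient in the backward pass. The bias is replicated across a mesh axis while the output it is added to is sharded along that axis, so each device computes only the partial bias gradient for its own shard; summing those partials over the axis recovers the full gradient. The snippet below is a minimal sketch of these semantics in plain PyTorch; the `IdentityFwdAllReduceBwd` class is hypothetical and is not ColossalAI's implementation.

    import torch
    import torch.distributed as dist

    class IdentityFwdAllReduceBwd(torch.autograd.Function):
        """Sketch of IDENTITY_FWD_ALLREDUCE_BWD: identity forward, all-reduce backward."""

        @staticmethod
        def forward(ctx, bias, process_group):
            # forward is the identity: the replicated bias is used as-is
            ctx.process_group = process_group
            return bias

        @staticmethod
        def backward(ctx, grad_output):
            # each rank holds only the partial bias gradient for its output
            # shard; summing over the mesh axis yields the full gradient
            grad = grad_output.clone()
            dist.all_reduce(grad, op=dist.ReduceOp.SUM, group=ctx.process_group)
            return grad, None

With this wiring, `output = input @ weight + IdentityFwdAllReduceBwd.apply(bias, group)` leaves the forward math unchanged while ensuring the bias gradient is reduced across the sharded axis, which is what the `bias_comm_spec` entries above request. `ALLREDUCE_FWD_IDENTITY_BWD`, used for the partial outputs in the contraction-splitting strategies, is the mirror image: partial results are summed in the forward pass and the gradient passes through untouched.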
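On the reworked `generate`: a mesh is treated as effectively one-dimensional when its shape is `(n,)` or when one axis has size 1, and only then is `split_one_batch_dim` called with the chosen axis. A standalone sketch of that selection, with a plain tuple standing in for `DeviceMesh.mesh_shape` (the helper name is made up for illustration):

    def pick_batch_mesh_dim(mesh_shape):
        # mirrors the selection logic in the new generate()
        if len(mesh_shape) == 1:
            return 0                       # e.g. (4,)   -> shard batch on axis 0
        if 1 in mesh_shape:
            return mesh_shape.index(1)     # e.g. (1, 4) -> the size-1 axis
        return None                        # true 2D mesh, e.g. (2, 4)

    assert pick_batch_mesh_dim((4,)) == 0
    assert pick_batch_mesh_dim((1, 4)) == 0
    assert pick_batch_mesh_dim((2, 4)) is None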