diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/batch_norm_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/batch_norm_generator.py
index 70a2cc9b4..8400a56c8 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/batch_norm_generator.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/batch_norm_generator.py
@@ -22,10 +22,6 @@ class BatchNormStrategyGenerator(StrategyGenerator):
     In this generator, both methods will be considered.
     """

-    @property
-    def has_bias(self):
-        return 'bias' in self.op_data
-
     def validate(self) -> bool:
         '''
         In sanity check, we need make sure the input data having correct dimension size.
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/conv_strategy_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/conv_strategy_generator.py
index 88d363447..fe40cc1a9 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/conv_strategy_generator.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/conv_strategy_generator.py
@@ -17,10 +17,6 @@ class ConvStrategyGenerator(StrategyGenerator):
     The operation data is defined as `output = input x other + bias`.
     """

-    @property
-    def has_bias(self):
-        return 'bias' in self.op_data
-
     def validate(self) -> bool:
         '''
         In sanity check, we need make sure the input data having correct dimension size.
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/getitem_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/getitem_generator.py
index 59e0ee4c8..dae168cbb 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/getitem_generator.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/getitem_generator.py
@@ -19,10 +19,6 @@ class GetItemStrategyGenerator(FollowingStrategyGenerator):
     3. args_0._meta_data: Tuple[torch.Tensor], args_1._meta_data: int
     """

-    @property
-    def has_bias(self):
-        return 'bias' in self.op_data
-
     def validate(self) -> bool:
         return super().validate()

diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/layer_norm_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/layer_norm_generator.py
index 86a70e5d0..cf7530fa6 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/layer_norm_generator.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/layer_norm_generator.py
@@ -18,10 +18,6 @@ class LayerNormGenerator(StrategyGenerator):
     The operation data is defined as `output = input x other + bias`.
     """

-    @property
-    def has_bias(self):
-        return 'bias' in self.op_data
-
     def validate(self) -> bool:
         return super().validate()

diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py
index d36800e29..26fcacc57 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py
@@ -14,10 +14,6 @@ class MatMulStrategyGenerator(StrategyGenerator):
     The operation data is defined as `output = input x other + bias`.
""" - @property - def has_bias(self): - return 'bias' in self.op_data - def update_memory_cost(self, strategy: ShardingStrategy) -> ShardingStrategy: size_mapping = { 'input': self._compute_size_in_bytes(strategy, "input"), @@ -512,11 +508,13 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator): sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict) # get communication actions - bias_comm_spec = self.get_communication_spec( - sharding_spec=sharding_spec_mapping['bias'], - communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, - logical_process_axis=mesh_dim) - communication_action_mapping = {"bias": bias_comm_spec} + communication_action_mapping = {} + if self.has_bias: + bias_comm_spec = self.get_communication_spec( + sharding_spec=sharding_spec_mapping['bias'], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=mesh_dim) + communication_action_mapping['bias'] = bias_comm_spec return self.get_sharding_strategy(name=name, sharding_spec_mapping=sharding_spec_mapping, communication_action_mapping=communication_action_mapping) @@ -538,11 +536,14 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator): sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict) # get communication actions - bias_comm_spec = self.get_communication_spec( - sharding_spec=sharding_spec_mapping['bias'], - communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, - logical_process_axis=[mesh_dim_0, mesh_dim_1]) - communication_action_mapping = {"bias": bias_comm_spec} + communication_action_mapping = {} + if self.has_bias: + bias_comm_spec = self.get_communication_spec( + sharding_spec=sharding_spec_mapping['bias'], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=[mesh_dim_0, mesh_dim_1]) + communication_action_mappingp['bias'] = bias_comm_spec + return self.get_sharding_strategy(name=name, sharding_spec_mapping=sharding_spec_mapping, communication_action_mapping=communication_action_mapping) @@ -566,15 +567,20 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator): sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict) # get communication actions + communication_action_mapping = {} other_comm_spec = self.get_communication_spec( sharding_spec=sharding_spec_mapping['other'], communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, logical_process_axis=mesh_dim_1) - bias_comm_spec = self.get_communication_spec( - sharding_spec=sharding_spec_mapping['bias'], - communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, - logical_process_axis=[mesh_dim_0, mesh_dim_1]) - communication_action_mapping = {'other': other_comm_spec, 'bias': bias_comm_spec} + communication_action_mapping['other'] = other_comm_spec + + if self.has_bias: + bias_comm_spec = self.get_communication_spec( + sharding_spec=sharding_spec_mapping['bias'], + communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, + logical_process_axis=[mesh_dim_0, mesh_dim_1]) + communication_action_mapping['bias'] = bias_comm_spec + return self.get_sharding_strategy(name=name, sharding_spec_mapping=sharding_spec_mapping, communication_action_mapping=communication_action_mapping) @@ -600,15 +606,20 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator): sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict) # get communication actions + communication_action_mapping = {} input_comm_spec = 
             sharding_spec=sharding_spec_mapping['input'],
             communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
             logical_process_axis=mesh_dim_1)
-        bias_comm_spec = self.get_communication_spec(
-            sharding_spec=sharding_spec_mapping['bias'],
-            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
-            logical_process_axis=mesh_dim_0)
-        communication_action_mapping = {'input': input_comm_spec, 'bias': bias_comm_spec}
+        communication_action_mapping['input'] = input_comm_spec
+
+        if self.has_bias:
+            bias_comm_spec = self.get_communication_spec(
+                sharding_spec=sharding_spec_mapping['bias'],
+                communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+                logical_process_axis=mesh_dim_0)
+            communication_action_mapping['bias'] = bias_comm_spec
+
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
                                           communication_action_mapping=communication_action_mapping)
@@ -633,15 +644,20 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
         sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)

         # get communication actions
+        communication_action_mapping = {}
         output_comm_spec = self.get_communication_spec(
             sharding_spec=sharding_spec_mapping['output'],
             communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
             logical_process_axis=mesh_dim_1)
-        bias_comm_spec = self.get_communication_spec(
-            sharding_spec=sharding_spec_mapping['bias'],
-            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
-            logical_process_axis=mesh_dim_0)
-        communication_action_mapping = {'output': output_comm_spec, 'bias': bias_comm_spec}
+        communication_action_mapping['output'] = output_comm_spec
+
+        if self.has_bias:
+            bias_comm_spec = self.get_communication_spec(
+                sharding_spec=sharding_spec_mapping['bias'],
+                communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+                logical_process_axis=mesh_dim_0)
+            communication_action_mapping['bias'] = bias_comm_spec
+
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
                                           communication_action_mapping=communication_action_mapping)
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/strategy_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/strategy_generator.py
index a643968ba..9ec0c0bc4 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/strategy_generator.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/strategy_generator.py
@@ -24,6 +24,13 @@ class StrategyGenerator(ABC):
         self.op_data = operation_data_mapping
         self.device_mesh = device_mesh

+    @property
+    def has_bias(self):
+        """
+        A utility property to check for the existence of the bias operand.
+ """ + return 'bias' in self.op_data + def is_param(self, op_data_name): other_data = self.op_data[op_data_name] return other_data.type == OperationDataType.PARAM diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bmm_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bmm_handler.py index f3612a781..76cbe6bd5 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bmm_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bmm_handler.py @@ -22,7 +22,6 @@ class BMMTorchFunctionModule(nn.Module): return torch.bmm(x1, x2) -@run_on_environment_flag(name='AUTO_PARALLEL') @pytest.mark.parametrize('module', [BMMTensorMethodModule, BMMTorchFunctionModule]) def test_2d_device_mesh(module): @@ -93,7 +92,6 @@ def test_2d_device_mesh(module): assert 'Sb1R = Sb1Sk0 x Sb1Sk0' in strategy_name_list -@run_on_environment_flag(name='AUTO_PARALLEL') @pytest.mark.parametrize('module', [BMMTensorMethodModule, BMMTorchFunctionModule]) def test_1d_device_mesh(module): model = module() diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_norm_pooling_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_norm_pooling_handler.py index 7ff418f25..d47876af2 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_norm_pooling_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_norm_pooling_handler.py @@ -11,7 +11,6 @@ from colossalai.fx.tracer.meta_patch.patched_module import linear from colossalai.testing.pytest_wrapper import run_on_environment_flag -@run_on_environment_flag(name='AUTO_PARALLEL') def test_norm_pool_handler(): model = nn.Sequential(nn.MaxPool2d(4, padding=1).to('meta')) tracer = ColoTracer() @@ -50,7 +49,7 @@ def test_norm_pool_handler(): assert mapping['output'].data.shape == torch.Size([4, 4, 16, 16]) assert mapping['output'].type == OperationDataType.OUTPUT - strategies_vector = handler.register_strategy() + strategies_vector = handler.register_strategy(compute_resharding_cost=False) strategy_name_list = [val.name for val in strategies_vector] assert len(strategy_name_list) == 9