From 27de252334adcfef44f5adfef2a287927501cdf9 Mon Sep 17 00:00:00 2001
From: YuliangLiu0306 <72588413+YuliangLiu0306@users.noreply.github.com>
Date: Tue, 1 Nov 2022 10:43:44 +0800
Subject: [PATCH] [autoparallel] fix conv handler numerical test (#1771)

---
 .../strategy/conv_strategy_generator.py      | 109 ++++++++++++++----
 .../test_node_handler/test_conv_handler.py   |   2 -
 2 files changed, 87 insertions(+), 24 deletions(-)

diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/conv_strategy_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/conv_strategy_generator.py
index f7e4543f8..c2154b310 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/conv_strategy_generator.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/conv_strategy_generator.py
@@ -141,14 +141,31 @@ class ConvStrategyGenerator(StrategyGenerator):
                 communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
                 logical_process_axis=mesh_dim_0,
                 comm_type=CommType.HOOK)
-            communication_action_mapping["other"] = other_comm_action
-        if self.has_bias and self.is_param("bias"):
-            bias_comm_action = self.get_communication_action(
-                sharding_spec_mapping["bias"],
+        else:
+            other_comm_action = self.get_communication_action(
+                sharding_spec_mapping["other"],
                 communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
                 logical_process_axis=mesh_dim_0,
-                comm_type=CommType.HOOK)
+                comm_type=CommType.BEFORE,
+                arg_index=1)
+
+        communication_action_mapping["other"] = other_comm_action
+
+        if self.has_bias:
+            if self.is_param('bias'):
+                bias_comm_action = self.get_communication_action(
+                    sharding_spec_mapping["bias"],
+                    communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+                    logical_process_axis=mesh_dim_0,
+                    comm_type=CommType.HOOK)
+            else:
+                bias_comm_action = self.get_communication_action(
+                    sharding_spec_mapping["bias"],
+                    communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+                    logical_process_axis=mesh_dim_0,
+                    comm_type=CommType.BEFORE,
+                    key_for_kwarg='bias')
             communication_action_mapping["bias"] = bias_comm_action
 
         return self.get_sharding_strategy(name=name,
@@ -180,14 +197,31 @@ class ConvStrategyGenerator(StrategyGenerator):
                 communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
                 logical_process_axis=mesh_dim_0,
                 comm_type=CommType.HOOK)
-            communication_action_mapping["other"] = other_comm_action
-        if self.has_bias and self.is_param("bias"):
-            bias_comm_action = self.get_communication_action(
-                sharding_spec_mapping["bias"],
+        else:
+            other_comm_action = self.get_communication_action(
+                sharding_spec_mapping["other"],
                 communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
                 logical_process_axis=mesh_dim_0,
-                comm_type=CommType.HOOK)
+                comm_type=CommType.BEFORE,
+                arg_index=1)
+
+        communication_action_mapping["other"] = other_comm_action
+
+        if self.has_bias:
+            if self.is_param('bias'):
+                bias_comm_action = self.get_communication_action(
+                    sharding_spec_mapping["bias"],
+                    communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+                    logical_process_axis=mesh_dim_0,
+                    comm_type=CommType.HOOK)
+            else:
+                bias_comm_action = self.get_communication_action(
+                    sharding_spec_mapping["bias"],
+                    communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+                    logical_process_axis=mesh_dim_0,
+                    comm_type=CommType.BEFORE,
+                    key_for_kwarg='bias')
             communication_action_mapping["bias"] = bias_comm_action
 
         return self.get_sharding_strategy(name=name,
@@ -230,14 +264,29 @@ class ConvStrategyGenerator(StrategyGenerator):
                 communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
                 logical_process_axis=mesh_dim_0,
                 comm_type=CommType.HOOK)
-            communication_action_mapping["other"] = other_comm_action
-        if self.has_bias and self.is_param("bias"):
-            bias_comm_action = self.get_communication_action(
-                sharding_spec_mapping["bias"],
+        else:
+            other_comm_action = self.get_communication_action(
+                sharding_spec_mapping["other"],
                 communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
                 logical_process_axis=mesh_dim_0,
-                comm_type=CommType.HOOK)
+                comm_type=CommType.BEFORE,
+                arg_index=1)
+        communication_action_mapping["other"] = other_comm_action
+        if self.has_bias:
+            if self.is_param("bias"):
+                bias_comm_action = self.get_communication_action(
+                    sharding_spec_mapping["bias"],
+                    communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+                    logical_process_axis=mesh_dim_0,
+                    comm_type=CommType.HOOK)
+            else:
+                bias_comm_action = self.get_communication_action(
+                    sharding_spec_mapping["bias"],
+                    communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+                    logical_process_axis=mesh_dim_0,
+                    comm_type=CommType.BEFORE,
+                    key_for_kwarg='bias')
             communication_action_mapping["bias"] = bias_comm_action
 
         return self.get_sharding_strategy(name=name,
@@ -277,7 +326,7 @@ class ConvStrategyGenerator(StrategyGenerator):
         input_comm_action = self.get_communication_action(
             sharding_spec_mapping["input"],
             communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
-            logical_process_axis=mesh_dim_0,
+            logical_process_axis=mesh_dim_1,
             comm_type=CommType.BEFORE,
             arg_index=0)
 
@@ -399,14 +448,30 @@ class ConvStrategyGenerator(StrategyGenerator):
                 communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
                 logical_process_axis=[mesh_dim_0, mesh_dim_1],
                 comm_type=CommType.HOOK)
-            communication_action_mapping["other"] = other_comm_action
-
-        if self.has_bias and self.is_param("bias"):
-            bias_comm_action = self.get_communication_action(
-                sharding_spec_mapping["bias"],
+        else:
+            other_comm_action = self.get_communication_action(
+                sharding_spec_mapping["other"],
                 communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
                 logical_process_axis=[mesh_dim_0, mesh_dim_1],
-                comm_type=CommType.HOOK)
+                comm_type=CommType.BEFORE,
+                arg_index=1)
+
+        communication_action_mapping["other"] = other_comm_action
+
+        if self.has_bias:
+            if self.is_param("bias"):
+                bias_comm_action = self.get_communication_action(
+                    sharding_spec_mapping["bias"],
+                    communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+                    logical_process_axis=[mesh_dim_0, mesh_dim_1],
+                    comm_type=CommType.HOOK)
+            else:
+                bias_comm_action = self.get_communication_action(
+                    sharding_spec_mapping["bias"],
+                    communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+                    logical_process_axis=[mesh_dim_0, mesh_dim_1],
+                    comm_type=CommType.BEFORE,
+                    key_for_kwarg='bias')
             communication_action_mapping["bias"] = bias_comm_action
 
         return self.get_sharding_strategy(name=name,
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_conv_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_conv_handler.py
index dbacb5ec4..2acd015c8 100644
--- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_conv_handler.py
+++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_conv_handler.py
@@ -290,7 +290,6 @@ def check_conv_function_handler(rank, bias, world_size, port):
         assert bias_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[1]
 
 
-@pytest.mark.skip("some cases need to be fixed")
 @run_on_environment_flag(name='AUTO_PARALLEL')
 @pytest.mark.dist
 # We temporarily ban the bias option before doing bias add
@@ -303,7 +302,6 @@ def test_conv_module_handler(bias=False):
     mp.spawn(run_func, nprocs=world_size)
 
 
-@pytest.mark.skip("some cases need to be fixed")
 @run_on_environment_flag(name='AUTO_PARALLEL')
 @pytest.mark.dist
 # We temporarily ban the bias option before doing bias add