mirror of https://github.com/hpcaitech/ColossalAI
[autoparallel] fix conv handler numerical test (#1771)
parent 1e88811c7a
commit 27de252334
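This commit has two parts. In ConvStrategyGenerator, each strategy previously created a communication action for the weight ("other") and bias operands only when they were parameters, attaching the all-reduce as a gradient hook; operands passed directly to the conv op received no communication action at all, which is the likely source of the numerical-test failures. Each hunk below adds an else branch that schedules the collective before the op instead (CommType.BEFORE with arg_index=1 for the weight and key_for_kwarg='bias' for the bias), one hunk corrects the mesh axis used for the input's communication action, and the test file re-enables two previously skipped tests.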
@@ -141,14 +141,31 @@ class ConvStrategyGenerator(StrategyGenerator):
                 communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
                 logical_process_axis=mesh_dim_0,
                 comm_type=CommType.HOOK)
-            communication_action_mapping["other"] = other_comm_action
-
-        if self.has_bias and self.is_param("bias"):
-            bias_comm_action = self.get_communication_action(
-                sharding_spec_mapping["bias"],
+        else:
+            other_comm_action = self.get_communication_action(
+                sharding_spec_mapping["other"],
                 communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
                 logical_process_axis=mesh_dim_0,
-                comm_type=CommType.HOOK)
+                comm_type=CommType.BEFORE,
+                arg_index=1)
+
+        communication_action_mapping["other"] = other_comm_action
+
+        if self.has_bias:
+            if self.is_param('bias'):
+                bias_comm_action = self.get_communication_action(
+                    sharding_spec_mapping["bias"],
+                    communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+                    logical_process_axis=mesh_dim_0,
+                    comm_type=CommType.HOOK)
+            else:
+                bias_comm_action = self.get_communication_action(
+                    sharding_spec_mapping["bias"],
+                    communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+                    logical_process_axis=mesh_dim_0,
+                    comm_type=CommType.BEFORE,
+                    key_for_kwarg='bias')
             communication_action_mapping["bias"] = bias_comm_action
 
         return self.get_sharding_strategy(name=name,
@@ -180,14 +197,31 @@ class ConvStrategyGenerator(StrategyGenerator):
                 communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
                 logical_process_axis=mesh_dim_0,
                 comm_type=CommType.HOOK)
-            communication_action_mapping["other"] = other_comm_action
-
-        if self.has_bias and self.is_param("bias"):
-            bias_comm_action = self.get_communication_action(
-                sharding_spec_mapping["bias"],
+        else:
+            other_comm_action = self.get_communication_action(
+                sharding_spec_mapping["other"],
                 communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
                 logical_process_axis=mesh_dim_0,
-                comm_type=CommType.HOOK)
+                comm_type=CommType.BEFORE,
+                arg_index=1)
+
+        communication_action_mapping["other"] = other_comm_action
+
+        if self.has_bias:
+            if self.is_param('bias'):
+                bias_comm_action = self.get_communication_action(
+                    sharding_spec_mapping["bias"],
+                    communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+                    logical_process_axis=mesh_dim_0,
+                    comm_type=CommType.HOOK)
+            else:
+                bias_comm_action = self.get_communication_action(
+                    sharding_spec_mapping["bias"],
+                    communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+                    logical_process_axis=mesh_dim_0,
+                    comm_type=CommType.BEFORE,
+                    key_for_kwarg='bias')
             communication_action_mapping["bias"] = bias_comm_action
 
         return self.get_sharding_strategy(name=name,
@@ -230,14 +264,29 @@ class ConvStrategyGenerator(StrategyGenerator):
                 communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
                 logical_process_axis=mesh_dim_0,
                 comm_type=CommType.HOOK)
-            communication_action_mapping["other"] = other_comm_action
-
-        if self.has_bias and self.is_param("bias"):
-            bias_comm_action = self.get_communication_action(
-                sharding_spec_mapping["bias"],
+        else:
+            other_comm_action = self.get_communication_action(
+                sharding_spec_mapping["other"],
                 communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
                 logical_process_axis=mesh_dim_0,
-                comm_type=CommType.HOOK)
+                comm_type=CommType.BEFORE,
+                arg_index=1)
+        communication_action_mapping["other"] = other_comm_action
+        if self.has_bias:
+            if self.is_param("bias"):
+                bias_comm_action = self.get_communication_action(
+                    sharding_spec_mapping["bias"],
+                    communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+                    logical_process_axis=mesh_dim_0,
+                    comm_type=CommType.HOOK)
+            else:
+                bias_comm_action = self.get_communication_action(
+                    sharding_spec_mapping["bias"],
+                    communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+                    logical_process_axis=mesh_dim_0,
+                    comm_type=CommType.BEFORE,
+                    key_for_kwarg='bias')
             communication_action_mapping["bias"] = bias_comm_action
 
         return self.get_sharding_strategy(name=name,
@@ -277,7 +326,7 @@ class ConvStrategyGenerator(StrategyGenerator):
         input_comm_action = self.get_communication_action(
             sharding_spec_mapping["input"],
             communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
-            logical_process_axis=mesh_dim_0,
+            logical_process_axis=mesh_dim_1,
             comm_type=CommType.BEFORE,
             arg_index=0)
 
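The one-line hunk above changes the logical process axis of the input's pre-op communication action from mesh_dim_0 to mesh_dim_1, presumably because this strategy shards along the other mesh axis; an all-reduce grouped over the wrong axis keeps every shape intact and only shows up as a numerical mismatch. A standalone illustration of why the axis matters (plain NumPy on a hypothetical 2x2 rank layout, not ColossalAI's DeviceMesh API):

import numpy as np

# Hypothetical 2x2 logical device mesh: ranks [[0, 1], [2, 3]].
mesh = np.arange(4).reshape(2, 2)

# A collective over mesh_dim_0 groups ranks down the columns, one over
# mesh_dim_1 groups them along the rows. Both choices produce tensors of
# the same shape, so only a numerical check catches the wrong one.
groups_dim_0 = [mesh[:, j].tolist() for j in range(2)]  # [[0, 2], [1, 3]]
groups_dim_1 = [mesh[i, :].tolist() for i in range(2)]  # [[0, 1], [2, 3]]
print(groups_dim_0, groups_dim_1)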
@@ -399,14 +448,30 @@ class ConvStrategyGenerator(StrategyGenerator):
                 communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
                 logical_process_axis=[mesh_dim_0, mesh_dim_1],
                 comm_type=CommType.HOOK)
-            communication_action_mapping["other"] = other_comm_action
-
-        if self.has_bias and self.is_param("bias"):
-            bias_comm_action = self.get_communication_action(
-                sharding_spec_mapping["bias"],
+        else:
+            other_comm_action = self.get_communication_action(
+                sharding_spec_mapping["other"],
                 communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
                 logical_process_axis=[mesh_dim_0, mesh_dim_1],
-                comm_type=CommType.HOOK)
+                comm_type=CommType.BEFORE,
+                arg_index=1)
+
+        communication_action_mapping["other"] = other_comm_action
+
+        if self.has_bias:
+            if self.is_param("bias"):
+                bias_comm_action = self.get_communication_action(
+                    sharding_spec_mapping["bias"],
+                    communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+                    logical_process_axis=[mesh_dim_0, mesh_dim_1],
+                    comm_type=CommType.HOOK)
+            else:
+                bias_comm_action = self.get_communication_action(
+                    sharding_spec_mapping["bias"],
+                    communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+                    logical_process_axis=[mesh_dim_0, mesh_dim_1],
+                    comm_type=CommType.BEFORE,
+                    key_for_kwarg='bias')
             communication_action_mapping["bias"] = bias_comm_action
 
         return self.get_sharding_strategy(name=name,
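Every ConvStrategyGenerator hunk above applies the same fix: a gradient hook (CommType.HOOK) only exists for parameter operands, so non-parameter weights and biases, e.g. tensors fed straight to the functional conv op, now get their collective injected before the op instead, bound to the call site by position or keyword. A condensed, self-contained sketch of that dispatch, with stub enums standing in for the ColossalAI definitions (the helper's shape is illustrative, not the library's API):

from enum import Enum, auto

class CommType(Enum):        # stub mirroring the names in the diff
    HOOK = auto()            # collective attached as a gradient hook on a parameter
    BEFORE = auto()          # collective injected before the op on a plain tensor

class CollectiveCommPattern(Enum):    # stub
    IDENTITY_FWD_ALLREDUCE_BWD = auto()

def comm_action_for(operand, is_parameter, axis, arg_index=None, key_for_kwarg=None):
    """Condensed form of the if/else branch each hunk adds (illustrative)."""
    action = {
        "operand": operand,
        "communication_pattern": CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
        "logical_process_axis": axis,
    }
    if is_parameter:
        action["comm_type"] = CommType.HOOK
    else:
        # No parameter to hook: run the collective before the op, bound to a
        # positional arg (weight: arg_index=1) or a keyword arg (bias:
        # key_for_kwarg='bias').
        action["comm_type"] = CommType.BEFORE
        if arg_index is not None:
            action["arg_index"] = arg_index
        else:
            action["key_for_kwarg"] = key_for_kwarg
    return action

# A non-parameter weight passed positionally to the conv op:
print(comm_action_for("other", is_parameter=False, axis=0, arg_index=1))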
@@ -290,7 +290,6 @@ def check_conv_function_handler(rank, bias, world_size, port):
     assert bias_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[1]
 
 
-@pytest.mark.skip("some cases need to be fixed")
 @run_on_environment_flag(name='AUTO_PARALLEL')
 @pytest.mark.dist
 # We temporarily ban the bias option before doing bias add
@@ -303,7 +302,6 @@ def test_conv_module_handler(bias=False):
     mp.spawn(run_func, nprocs=world_size)
 
 
-@pytest.mark.skip("some cases need to be fixed")
 @run_on_environment_flag(name='AUTO_PARALLEL')
 @pytest.mark.dist
 # We temporarily ban the bias option before doing bias add
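With the handler fixed, both @pytest.mark.skip("some cases need to be fixed") markers come off, so the conv handler tests run again whenever the AUTO_PARALLEL environment flag is set; they stay opt-in through run_on_environment_flag. A plausible sketch of such a flag gate (ColossalAI's real decorator may differ; this only illustrates the gating behaviour the decorators above rely on):

import os

import pytest

def run_on_environment_flag(name):
    # Skip the test unless the named environment variable is set truthy.
    enabled = os.environ.get(name, '0') not in ('', '0')
    return pytest.mark.skipif(not enabled, reason=f'environment flag {name} not set')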