|
|
|
@ -218,7 +218,7 @@ class BatchNormStrategyGenerator(StrategyGenerator):
|
|
|
|
|
sharding_spec=sharding_spec_mapping["output"], |
|
|
|
|
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, |
|
|
|
|
logical_process_axis=mesh_dim_0, |
|
|
|
|
comm_type=CommType.AFTER) |
|
|
|
|
comm_type=CommType.IMPLICIT) |
|
|
|
|
|
|
|
|
|
communication_action_mapping = {"output": output_comm_action} |
|
|
|
|
|
|
|
|
@ -254,7 +254,7 @@ class BatchNormStrategyGenerator(StrategyGenerator):
|
|
|
|
|
sharding_spec=sharding_spec_mapping["output"], |
|
|
|
|
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, |
|
|
|
|
logical_process_axis=[mesh_dim_0, mesh_dim_1], |
|
|
|
|
comm_type=CommType.AFTER) |
|
|
|
|
comm_type=CommType.IMPLICIT) |
|
|
|
|
|
|
|
|
|
communication_action_mapping = {"output": output_comm_action} |
|
|
|
|
|
|
|
|
@ -300,7 +300,7 @@ class BatchNormStrategyGenerator(StrategyGenerator):
|
|
|
|
|
sharding_spec=sharding_spec_mapping["output"], |
|
|
|
|
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, |
|
|
|
|
logical_process_axis=[mesh_dim_0], |
|
|
|
|
comm_type=CommType.AFTER) |
|
|
|
|
comm_type=CommType.IMPLICIT) |
|
|
|
|
|
|
|
|
|
communication_action_mapping = {"output": output_comm_action} |
|
|
|
|
|
|
|
|
@ -331,14 +331,14 @@ class BatchNormStrategyGenerator(StrategyGenerator):
|
|
|
|
|
# TODO: The strategies below should be uncommented after runtime |
|
|
|
|
# passes ready. |
|
|
|
|
# SR = SR x R WITH SYNC_BN |
|
|
|
|
# strategy_list.append(self.split_input_batch(0)) |
|
|
|
|
# strategy_list.append(self.split_input_batch(1)) |
|
|
|
|
strategy_list.append(self.split_input_batch(0)) |
|
|
|
|
strategy_list.append(self.split_input_batch(1)) |
|
|
|
|
|
|
|
|
|
# SS = SS x S WITH SYNC_BN |
|
|
|
|
# strategy_list.append(self.split_input_both_dim(0, 1)) |
|
|
|
|
# strategy_list.append(self.split_input_both_dim(1, 0)) |
|
|
|
|
strategy_list.append(self.split_input_both_dim(0, 1)) |
|
|
|
|
strategy_list.append(self.split_input_both_dim(1, 0)) |
|
|
|
|
|
|
|
|
|
# S01R = S01R x R WITH SYNC_BN |
|
|
|
|
# strategy_list.append(self.split_input_batch_1d(0, 1)) |
|
|
|
|
strategy_list.append(self.split_input_batch_1d(0, 1)) |
|
|
|
|
|
|
|
|
|
return strategy_list |
|
|
|
|