diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/linear_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/linear_handler.py index 37ff3c3ab..59091dab5 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/linear_handler.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/linear_handler.py @@ -152,7 +152,10 @@ class LinearModuleHandler(MetaInfoModuleHandler): op_data_mapping = self.get_operation_data_mapping() generators = [] generators.append( - LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh, linear_projection_type='linear')) + LinearProjectionStrategyGenerator(op_data_mapping, + self.device_mesh, + linear_projection_type='linear', + solver_perference=self.solver_perference)) return generators def get_operation_data_mapping(self) -> Dict[str, OperationData]: diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py index fa2246f95..5d70e131d 100644 --- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py +++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py @@ -3,6 +3,7 @@ from ast import arg from functools import reduce from typing import List +from colossalai.auto_parallel.tensor_shard.options import SolverPerference from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( CommType, MemoryCost, @@ -209,9 +210,14 @@ class MatVecStrategyGenerator(MatMulStrategyGenerator): class LinearProjectionStrategyGenerator(MatMulStrategyGenerator): - def __init__(self, operation_data_mapping, device_mesh, linear_projection_type='linear'): + def __init__(self, + operation_data_mapping, + device_mesh, + linear_projection_type='linear', + solver_perference=SolverPerference.STANDARD): super().__init__(operation_data_mapping, device_mesh) self.linear_projection_type = linear_projection_type + self.solver_perference = solver_perference def update_compute_cost(self, strategy: ShardingStrategy) -> ShardingStrategy: # C = AB @@ -231,16 +237,22 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator): total=fwd_compute_cost + bwd_compute_cost) strategy.compute_cost = compute_cost - def collate_strategies(self) -> List[ShardingStrategy]: + def dp_strategies(self) -> List[ShardingStrategy]: strategies = [] - # SS = SR x RS - strategies.append(self.split_lhs_space_rhs_space(0, 1)) - strategies.append(self.split_lhs_space_rhs_space(1, 0)) + # S01R = S01R x RR + strategies.append(self.split_lhs_1st_dim_1d(0, 1)) - # SR = SS x SR - strategies.append(self.split_lhs_space_both_contract(0, 1)) - strategies.append(self.split_lhs_space_both_contract(1, 0)) + return strategies + + def tp_strategies(self) -> List[ShardingStrategy]: + strategies = [] + + # RR = RS01 x S01R + strategies.append(self.split_lhs_2nd_dim_1d(0, 1)) + + # RS01 = RR x RS01 + strategies.append(self.split_rhs_2nd_dim_1d(0, 1)) # RS = RS x SS strategies.append(self.split_rhs_space_both_contract(0, 1)) @@ -254,20 +266,38 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator): strategies.append(self.split_rhs_space_only(0)) strategies.append(self.split_rhs_space_only(1)) - # S01R = S01R x RR - strategies.append(self.split_lhs_1st_dim_1d(0, 1)) + return strategies - # RR = RS01 x S01R - strategies.append(self.split_lhs_2nd_dim_1d(0, 1)) + def mix_strategies(self) -> List[ShardingStrategy]: + strategies = [] - # RS01 = RR x RS01 - strategies.append(self.split_rhs_2nd_dim_1d(0, 1)) + # SS = SR x RS + strategies.append(self.split_lhs_space_rhs_space(0, 1)) + strategies.append(self.split_lhs_space_rhs_space(1, 0)) + + # SR = SS x SR + strategies.append(self.split_lhs_space_both_contract(0, 1)) + strategies.append(self.split_lhs_space_both_contract(1, 0)) # RR = RR x RR strategies.append(self.non_split()) return strategies + def collate_strategies(self) -> List[ShardingStrategy]: + strategies = [] + + if self.solver_perference == SolverPerference.STANDARD: + strategies.extend(self.dp_strategies()) + strategies.extend(self.tp_strategies()) + strategies.extend(self.mix_strategies()) + elif self.solver_perference == SolverPerference.DP: + strategies.extend(self.dp_strategies()) + elif self.solver_perference == SolverPerference.TP: + strategies.extend(self.tp_strategies()) + + return strategies + @ignore_sharding_exception def split_lhs_space_rhs_space(self, mesh_dim_0, mesh_dim_1): # handle case SS = SR x RS diff --git a/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_runtime_with_gpt_modules.py b/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_runtime_with_gpt_modules.py index 753ecff53..ebeef9870 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_runtime_with_gpt_modules.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_runtime_with_gpt_modules.py @@ -117,7 +117,7 @@ def check_attention_layer(rank, model_cls, world_size, port): gm = GraphModule(model, graph, model.__class__.__name__) gm.recompile() - strategies_constructor = build_strategy_constructor(graph, device_mesh) + strategies_constructor = build_strategy_constructor(graph, device_mesh, 'standard', 'replicated', 'standard') solution = solve_solution(gm, strategies_constructor, memory_budget=-1) gm, sharding_spec_dicts = transform_to_sharded_model(gm, solution, device_mesh, strategies_constructor) gm = ModuleWrapper(gm, *sharding_spec_dicts) diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_permute_and_transpose_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_permute_and_transpose_handler.py index b12db1332..af03481d8 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_permute_and_transpose_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_permute_and_transpose_handler.py @@ -243,79 +243,79 @@ def check_view_handler(rank, call_function, reshape_dims, model_cls, world_size, if model_cls.__name__ == 'LinearReshapeModel': if reshape_dims == ((0, 2, 1, 3), (1, 2)): - assert '[S0, R, R, S1] -> [S0, R, R, S1]_0' in strategy_name_list - assert '[R, S0, R, S1] -> [R, R, S0, S1]_1' in strategy_name_list - assert '[R, R, S0, S1] -> [R, S0, R, S1]_2' in strategy_name_list - assert '[S1, R, R, S0] -> [S1, R, R, S0]_3' in strategy_name_list - assert '[R, S1, R, S0] -> [R, R, S1, S0]_4' in strategy_name_list - assert '[R, R, S1, S0] -> [R, S1, R, S0]_5' in strategy_name_list - assert '[S0, R, R, R] -> [S0, R, R, R]_6' in strategy_name_list - assert '[R, S0, R, R] -> [R, R, S0, R]_7' in strategy_name_list - assert '[R, R, S0, R] -> [R, S0, R, R]_8' in strategy_name_list - assert '[S1, R, R, R] -> [S1, R, R, R]_9' in strategy_name_list - assert '[R, S1, R, R] -> [R, R, S1, R]_10' in strategy_name_list - assert '[R, R, S1, R] -> [R, S1, R, R]_11' in strategy_name_list - assert '[R, R, R, S1] -> [R, R, R, S1]_12' in strategy_name_list - assert '[R, R, R, S0] -> [R, R, R, S0]_13' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_14' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_15' in strategy_name_list - assert '[R, R, R, S0] -> [R, R, R, S0]_16' in strategy_name_list - assert '[R, R, R, S1] -> [R, R, R, S1]_17' in strategy_name_list - assert '[S01, R, R, R] -> [S01, R, R, R]_18' in strategy_name_list - assert '[R, S01, R, R] -> [R, R, S01, R]_19' in strategy_name_list - assert '[R, R, S01, R] -> [R, S01, R, R]_20' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_21' in strategy_name_list - assert '[R, R, R, S01] -> [R, R, R, S01]_22' in strategy_name_list + assert '[S0, R, R, S1] -> [S0, R, R, S1]_11' in strategy_name_list + assert '[R, S0, R, S1] -> [R, R, S0, S1]_12' in strategy_name_list + assert '[R, R, S0, S1] -> [R, S0, R, S1]_13' in strategy_name_list + assert '[S1, R, R, S0] -> [S1, R, R, S0]_14' in strategy_name_list + assert '[R, S1, R, S0] -> [R, R, S1, S0]_15' in strategy_name_list + assert '[R, R, S1, S0] -> [R, S1, R, S0]_16' in strategy_name_list + assert '[S0, R, R, R] -> [S0, R, R, R]_17' in strategy_name_list + assert '[R, S0, R, R] -> [R, R, S0, R]_18' in strategy_name_list + assert '[R, R, S0, R] -> [R, S0, R, R]_19' in strategy_name_list + assert '[S1, R, R, R] -> [S1, R, R, R]_20' in strategy_name_list + assert '[R, S1, R, R] -> [R, R, S1, R]_21' in strategy_name_list + assert '[R, R, S1, R] -> [R, S1, R, R]_22' in strategy_name_list + assert '[R, R, R, S1] -> [R, R, R, S1]_10' in strategy_name_list + assert '[R, R, R, S0] -> [R, R, R, S0]_9' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_8' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_7' in strategy_name_list + assert '[R, R, R, S0] -> [R, R, R, S0]_6' in strategy_name_list + assert '[R, R, R, S1] -> [R, R, R, S1]_5' in strategy_name_list + assert '[S01, R, R, R] -> [S01, R, R, R]_0' in strategy_name_list + assert '[R, S01, R, R] -> [R, R, S01, R]_1' in strategy_name_list + assert '[R, R, S01, R] -> [R, S01, R, R]_2' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_3' in strategy_name_list + assert '[R, R, R, S01] -> [R, R, R, S01]_4' in strategy_name_list if reshape_dims == (2, 0, 1, 3): - assert '[S0, R, R, S1] -> [R, S0, R, S1]_0' in strategy_name_list - assert '[R, S0, R, S1] -> [R, R, S0, S1]_1' in strategy_name_list - assert '[R, R, S0, S1] -> [S0, R, R, S1]_2' in strategy_name_list - assert '[S1, R, R, S0] -> [R, S1, R, S0]_3' in strategy_name_list - assert '[R, S1, R, S0] -> [R, R, S1, S0]_4' in strategy_name_list - assert '[R, R, S1, S0] -> [S1, R, R, S0]_5' in strategy_name_list - assert '[S0, R, R, R] -> [R, S0, R, R]_6' in strategy_name_list - assert '[R, S0, R, R] -> [R, R, S0, R]_7' in strategy_name_list - assert '[R, R, S0, R] -> [S0, R, R, R]_8' in strategy_name_list - assert '[S1, R, R, R] -> [R, S1, R, R]_9' in strategy_name_list - assert '[R, S1, R, R] -> [R, R, S1, R]_10' in strategy_name_list - assert '[R, R, S1, R] -> [S1, R, R, R]_11' in strategy_name_list - assert '[R, R, R, S1] -> [R, R, R, S1]_12' in strategy_name_list - assert '[R, R, R, S0] -> [R, R, R, S0]_13' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_14' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_15' in strategy_name_list - assert '[R, R, R, S0] -> [R, R, R, S0]_16' in strategy_name_list - assert '[R, R, R, S1] -> [R, R, R, S1]_17' in strategy_name_list - assert '[S01, R, R, R] -> [R, S01, R, R]_18' in strategy_name_list - assert '[R, S01, R, R] -> [R, R, S01, R]_19' in strategy_name_list - assert '[R, R, S01, R] -> [S01, R, R, R]_20' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_21' in strategy_name_list - assert '[R, R, R, S01] -> [R, R, R, S01]_22' in strategy_name_list + assert '[S0, R, R, S1] -> [R, S0, R, S1]_11' in strategy_name_list + assert '[R, S0, R, S1] -> [R, R, S0, S1]_12' in strategy_name_list + assert '[R, R, S0, S1] -> [S0, R, R, S1]_13' in strategy_name_list + assert '[S1, R, R, S0] -> [R, S1, R, S0]_14' in strategy_name_list + assert '[R, S1, R, S0] -> [R, R, S1, S0]_15' in strategy_name_list + assert '[R, R, S1, S0] -> [S1, R, R, S0]_16' in strategy_name_list + assert '[S0, R, R, R] -> [R, S0, R, R]_17' in strategy_name_list + assert '[R, S0, R, R] -> [R, R, S0, R]_18' in strategy_name_list + assert '[R, R, S0, R] -> [S0, R, R, R]_19' in strategy_name_list + assert '[S1, R, R, R] -> [R, S1, R, R]_20' in strategy_name_list + assert '[R, S1, R, R] -> [R, R, S1, R]_21' in strategy_name_list + assert '[R, R, S1, R] -> [S1, R, R, R]_22' in strategy_name_list + assert '[R, R, R, S1] -> [R, R, R, S1]_10' in strategy_name_list + assert '[R, R, R, S0] -> [R, R, R, S0]_9' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_8' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_7' in strategy_name_list + assert '[R, R, R, S0] -> [R, R, R, S0]_6' in strategy_name_list + assert '[R, R, R, S1] -> [R, R, R, S1]_5' in strategy_name_list + assert '[S01, R, R, R] -> [R, S01, R, R]_0' in strategy_name_list + assert '[R, S01, R, R] -> [R, R, S01, R]_1' in strategy_name_list + assert '[R, R, S01, R] -> [S01, R, R, R]_2' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_3' in strategy_name_list + assert '[R, R, R, S01] -> [R, R, R, S01]_4' in strategy_name_list if reshape_dims == (1, 3): - assert '[S0, R, R, S1] -> [S0, S1, R, R]_0' in strategy_name_list - assert '[R, S0, R, S1] -> [R, S1, R, S0]_1' in strategy_name_list - assert '[R, R, S0, S1] -> [R, S1, S0, R]_2' in strategy_name_list - assert '[S1, R, R, S0] -> [S1, S0, R, R]_3' in strategy_name_list - assert '[R, S1, R, S0] -> [R, S0, R, S1]_4' in strategy_name_list - assert '[R, R, S1, S0] -> [R, S0, S1, R]_5' in strategy_name_list - assert '[S0, R, R, R] -> [S0, R, R, R]_6' in strategy_name_list - assert '[R, S0, R, R] -> [R, R, R, S0]_7' in strategy_name_list - assert '[R, R, S0, R] -> [R, R, S0, R]_8' in strategy_name_list - assert '[S1, R, R, R] -> [S1, R, R, R]_9' in strategy_name_list - assert '[R, S1, R, R] -> [R, R, R, S1]_10' in strategy_name_list - assert '[R, R, S1, R] -> [R, R, S1, R]_11' in strategy_name_list - assert '[R, R, R, S1] -> [R, S1, R, R]_12' in strategy_name_list - assert '[R, R, R, S0] -> [R, S0, R, R]_13' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_14' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_15' in strategy_name_list - assert '[R, R, R, S0] -> [R, S0, R, R]_16' in strategy_name_list - assert '[R, R, R, S1] -> [R, S1, R, R]_17' in strategy_name_list - assert '[S01, R, R, R] -> [S01, R, R, R]_18' in strategy_name_list - assert '[R, S01, R, R] -> [R, R, R, S01]_19' in strategy_name_list - assert '[R, R, S01, R] -> [R, R, S01, R]_20' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_21' in strategy_name_list - assert '[R, R, R, S01] -> [R, S01, R, R]_22' in strategy_name_list + assert '[S0, R, R, S1] -> [S0, S1, R, R]_11' in strategy_name_list + assert '[R, S0, R, S1] -> [R, S1, R, S0]_12' in strategy_name_list + assert '[R, R, S0, S1] -> [R, S1, S0, R]_13' in strategy_name_list + assert '[S1, R, R, S0] -> [S1, S0, R, R]_14' in strategy_name_list + assert '[R, S1, R, S0] -> [R, S0, R, S1]_15' in strategy_name_list + assert '[R, R, S1, S0] -> [R, S0, S1, R]_16' in strategy_name_list + assert '[S0, R, R, R] -> [S0, R, R, R]_17' in strategy_name_list + assert '[R, S0, R, R] -> [R, R, R, S0]_18' in strategy_name_list + assert '[R, R, S0, R] -> [R, R, S0, R]_19' in strategy_name_list + assert '[S1, R, R, R] -> [S1, R, R, R]_20' in strategy_name_list + assert '[R, S1, R, R] -> [R, R, R, S1]_21' in strategy_name_list + assert '[R, R, S1, R] -> [R, R, S1, R]_22' in strategy_name_list + assert '[R, R, R, S1] -> [R, S1, R, R]_10' in strategy_name_list + assert '[R, R, R, S0] -> [R, S0, R, R]_9' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_8' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_7' in strategy_name_list + assert '[R, R, R, S0] -> [R, S0, R, R]_6' in strategy_name_list + assert '[R, R, R, S1] -> [R, S1, R, R]_5' in strategy_name_list + assert '[S01, R, R, R] -> [S01, R, R, R]_0' in strategy_name_list + assert '[R, S01, R, R] -> [R, R, R, S01]_1' in strategy_name_list + assert '[R, R, S01, R] -> [R, R, S01, R]_2' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_3' in strategy_name_list + assert '[R, R, R, S01] -> [R, S01, R, R]_4' in strategy_name_list @run_on_environment_flag(name='AUTO_PARALLEL') diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_softmax_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_softmax_handler.py index b5e8e3277..c43ee292b 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_softmax_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_softmax_handler.py @@ -117,54 +117,54 @@ def check_split_handler(rank, softmax_dim, model_cls, world_size, port): strategy_name_list = [strategy.name for strategy in split_strategies_vector] if softmax_dim == 0: - assert '[R, R, R, S1] -> [R, R, R, S1]_0' in strategy_name_list - assert '[R, S0, R, S1] -> [R, S0, R, S1]_1' in strategy_name_list - assert '[R, R, S0, S1] -> [R, R, S0, S1]_2' in strategy_name_list - assert '[R, R, R, S0] -> [R, R, R, S0]_3' in strategy_name_list - assert '[R, S1, R, S0] -> [R, S1, R, S0]_4' in strategy_name_list - assert '[R, R, S1, S0] -> [R, R, S1, S0]_5' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_6' in strategy_name_list - assert '[R, S0, R, R] -> [R, S0, R, R]_7' in strategy_name_list - assert '[R, R, S0, R] -> [R, R, S0, R]_8' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_9' in strategy_name_list - assert '[R, S1, R, R] -> [R, S1, R, R]_10' in strategy_name_list - assert '[R, R, S1, R] -> [R, R, S1, R]_11' in strategy_name_list - assert '[R, R, R, S1] -> [R, R, R, S1]_12' in strategy_name_list - assert '[R, R, R, S0] -> [R, R, R, S0]_13' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_14' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_15' in strategy_name_list - assert '[R, R, R, S0] -> [R, R, R, S0]_16' in strategy_name_list - assert '[R, R, R, S1] -> [R, R, R, S1]_17' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_18' in strategy_name_list - assert '[R, S01, R, R] -> [R, S01, R, R]_19' in strategy_name_list - assert '[R, R, S01, R] -> [R, R, S01, R]_20' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_21' in strategy_name_list - assert '[R, R, R, S01] -> [R, R, R, S01]_22' in strategy_name_list + assert '[R, R, R, S1] -> [R, R, R, S1]_11' in strategy_name_list + assert '[R, S0, R, S1] -> [R, S0, R, S1]_12' in strategy_name_list + assert '[R, R, S0, S1] -> [R, R, S0, S1]_13' in strategy_name_list + assert '[R, R, R, S0] -> [R, R, R, S0]_14' in strategy_name_list + assert '[R, S1, R, S0] -> [R, S1, R, S0]_15' in strategy_name_list + assert '[R, R, S1, S0] -> [R, R, S1, S0]_16' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_17' in strategy_name_list + assert '[R, S0, R, R] -> [R, S0, R, R]_18' in strategy_name_list + assert '[R, R, S0, R] -> [R, R, S0, R]_19' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_20' in strategy_name_list + assert '[R, S1, R, R] -> [R, S1, R, R]_21' in strategy_name_list + assert '[R, R, S1, R] -> [R, R, S1, R]_22' in strategy_name_list + assert '[R, R, R, S1] -> [R, R, R, S1]_10' in strategy_name_list + assert '[R, R, R, S0] -> [R, R, R, S0]_9' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_8' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_7' in strategy_name_list + assert '[R, R, R, S0] -> [R, R, R, S0]_6' in strategy_name_list + assert '[R, R, R, S1] -> [R, R, R, S1]_5' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_0' in strategy_name_list + assert '[R, S01, R, R] -> [R, S01, R, R]_1' in strategy_name_list + assert '[R, R, S01, R] -> [R, R, S01, R]_2' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_3' in strategy_name_list + assert '[R, R, R, S01] -> [R, R, R, S01]_4' in strategy_name_list if softmax_dim == 1: - assert '[S0, R, R, S1] -> [S0, R, R, S1]_0' in strategy_name_list - assert '[R, R, R, S1] -> [R, R, R, S1]_1' in strategy_name_list - assert '[R, R, S0, S1] -> [R, R, S0, S1]_2' in strategy_name_list - assert '[S1, R, R, S0] -> [S1, R, R, S0]_3' in strategy_name_list - assert '[R, R, R, S0] -> [R, R, R, S0]_4' in strategy_name_list - assert '[R, R, S1, S0] -> [R, R, S1, S0]_5' in strategy_name_list - assert '[S0, R, R, R] -> [S0, R, R, R]_6' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_7' in strategy_name_list - assert '[R, R, S0, R] -> [R, R, S0, R]_8' in strategy_name_list - assert '[S1, R, R, R] -> [S1, R, R, R]_9' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_10' in strategy_name_list - assert '[R, R, S1, R] -> [R, R, S1, R]_11' in strategy_name_list + assert '[S0, R, R, S1] -> [S0, R, R, S1]_11' in strategy_name_list assert '[R, R, R, S1] -> [R, R, R, S1]_12' in strategy_name_list - assert '[R, R, R, S0] -> [R, R, R, S0]_13' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_14' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_15' in strategy_name_list - assert '[R, R, R, S0] -> [R, R, R, S0]_16' in strategy_name_list - assert '[R, R, R, S1] -> [R, R, R, S1]_17' in strategy_name_list - assert '[S01, R, R, R] -> [S01, R, R, R]_18' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R]_19' in strategy_name_list - assert '[R, R, S01, R] -> [R, R, S01, R]_20' in strategy_name_list + assert '[R, R, S0, S1] -> [R, R, S0, S1]_13' in strategy_name_list + assert '[S1, R, R, S0] -> [S1, R, R, S0]_14' in strategy_name_list + assert '[R, R, R, S0] -> [R, R, R, S0]_15' in strategy_name_list + assert '[R, R, S1, S0] -> [R, R, S1, S0]_16' in strategy_name_list + assert '[S0, R, R, R] -> [S0, R, R, R]_17' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_18' in strategy_name_list + assert '[R, R, S0, R] -> [R, R, S0, R]_19' in strategy_name_list + assert '[S1, R, R, R] -> [S1, R, R, R]_20' in strategy_name_list assert '[R, R, R, R] -> [R, R, R, R]_21' in strategy_name_list - assert '[R, R, R, S01] -> [R, R, R, S01]_22' in strategy_name_list + assert '[R, R, S1, R] -> [R, R, S1, R]_22' in strategy_name_list + assert '[R, R, R, S1] -> [R, R, R, S1]_10' in strategy_name_list + assert '[R, R, R, S0] -> [R, R, R, S0]_9' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_8' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_7' in strategy_name_list + assert '[R, R, R, S0] -> [R, R, R, S0]_6' in strategy_name_list + assert '[R, R, R, S1] -> [R, R, R, S1]_5' in strategy_name_list + assert '[S01, R, R, R] -> [S01, R, R, R]_0' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_1' in strategy_name_list + assert '[R, R, S01, R] -> [R, R, S01, R]_2' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R]_3' in strategy_name_list + assert '[R, R, R, S01] -> [R, R, R, S01]_4' in strategy_name_list @run_on_environment_flag(name='AUTO_PARALLEL') diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_split_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_split_handler.py index 813651869..044aef19d 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_split_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_split_handler.py @@ -198,54 +198,54 @@ def check_split_handler(rank, split_size, split_dim, model_cls, world_size, port if model_cls.__name__ == 'LinearSplitModel': if split_dim == 0: - assert '[R, R, R, S1]_0' in strategy_name_list - assert '[R, S0, R, S1]_1' in strategy_name_list - assert '[R, R, S0, S1]_2' in strategy_name_list - assert '[R, R, R, S0]_3' in strategy_name_list - assert '[R, S1, R, S0]_4' in strategy_name_list - assert '[R, R, S1, S0]_5' in strategy_name_list - assert '[R, R, R, R]_6' in strategy_name_list - assert '[R, S0, R, R]_7' in strategy_name_list - assert '[R, R, S0, R]_8' in strategy_name_list - assert '[R, R, R, R]_9' in strategy_name_list - assert '[R, S1, R, R]_10' in strategy_name_list - assert '[R, R, S1, R]_11' in strategy_name_list - assert '[R, R, R, S1]_12' in strategy_name_list - assert '[R, R, R, S0]_13' in strategy_name_list - assert '[R, R, R, R]_14' in strategy_name_list - assert '[R, R, R, R]_15' in strategy_name_list - assert '[R, R, R, S0]_16' in strategy_name_list - assert '[R, R, R, S1]_17' in strategy_name_list - assert '[R, R, R, R]_18' in strategy_name_list - assert '[R, S01, R, R]_19' in strategy_name_list - assert '[R, R, S01, R]_20' in strategy_name_list - assert '[R, R, R, R]_21' in strategy_name_list - assert '[R, R, R, S01]_22' in strategy_name_list + assert '[R, R, R, S1]_11' in strategy_name_list + assert '[R, S0, R, S1]_12' in strategy_name_list + assert '[R, R, S0, S1]_13' in strategy_name_list + assert '[R, R, R, S0]_14' in strategy_name_list + assert '[R, S1, R, S0]_15' in strategy_name_list + assert '[R, R, S1, S0]_16' in strategy_name_list + assert '[R, R, R, R]_17' in strategy_name_list + assert '[R, S0, R, R]_18' in strategy_name_list + assert '[R, R, S0, R]_19' in strategy_name_list + assert '[R, R, R, R]_20' in strategy_name_list + assert '[R, S1, R, R]_21' in strategy_name_list + assert '[R, R, S1, R]_22' in strategy_name_list + assert '[R, R, R, S1]_10' in strategy_name_list + assert '[R, R, R, S0]_9' in strategy_name_list + assert '[R, R, R, R]_8' in strategy_name_list + assert '[R, R, R, R]_7' in strategy_name_list + assert '[R, R, R, S0]_6' in strategy_name_list + assert '[R, R, R, S1]_5' in strategy_name_list + assert '[R, R, R, R]_0' in strategy_name_list + assert '[R, S01, R, R]_1' in strategy_name_list + assert '[R, R, S01, R]_2' in strategy_name_list + assert '[R, R, R, R]_3' in strategy_name_list + assert '[R, R, R, S01]_4' in strategy_name_list if split_dim == 1: - assert '[S0, R, R, S1]_0' in strategy_name_list - assert '[R, R, R, S1]_1' in strategy_name_list - assert '[R, R, S0, S1]_2' in strategy_name_list - assert '[S1, R, R, S0]_3' in strategy_name_list - assert '[R, R, R, S0]_4' in strategy_name_list - assert '[R, R, S1, S0]_5' in strategy_name_list - assert '[S0, R, R, R]_6' in strategy_name_list - assert '[R, R, R, R]_7' in strategy_name_list - assert '[R, R, S0, R]_8' in strategy_name_list - assert '[S1, R, R, R]_9' in strategy_name_list - assert '[R, R, R, R]_10' in strategy_name_list - assert '[R, R, S1, R]_11' in strategy_name_list + assert '[S0, R, R, S1]_11' in strategy_name_list assert '[R, R, R, S1]_12' in strategy_name_list - assert '[R, R, R, S0]_13' in strategy_name_list - assert '[R, R, R, R]_14' in strategy_name_list - assert '[R, R, R, R]_15' in strategy_name_list - assert '[R, R, R, S0]_16' in strategy_name_list - assert '[R, R, R, S1]_17' in strategy_name_list - assert '[S01, R, R, R]_18' in strategy_name_list - assert '[R, R, R, R]_19' in strategy_name_list - assert '[R, R, S01, R]_20' in strategy_name_list + assert '[R, R, S0, S1]_13' in strategy_name_list + assert '[S1, R, R, S0]_14' in strategy_name_list + assert '[R, R, R, S0]_15' in strategy_name_list + assert '[R, R, S1, S0]_16' in strategy_name_list + assert '[S0, R, R, R]_17' in strategy_name_list + assert '[R, R, R, R]_18' in strategy_name_list + assert '[R, R, S0, R]_19' in strategy_name_list + assert '[S1, R, R, R]_20' in strategy_name_list assert '[R, R, R, R]_21' in strategy_name_list - assert '[R, R, R, S01]_22' in strategy_name_list + assert '[R, R, S1, R]_22' in strategy_name_list + assert '[R, R, R, S1]_10' in strategy_name_list + assert '[R, R, R, S0]_9' in strategy_name_list + assert '[R, R, R, R]_8' in strategy_name_list + assert '[R, R, R, R]_7' in strategy_name_list + assert '[R, R, R, S0]_6' in strategy_name_list + assert '[R, R, R, S1]_5' in strategy_name_list + assert '[S01, R, R, R]_0' in strategy_name_list + assert '[R, R, R, R]_1' in strategy_name_list + assert '[R, R, S01, R]_2' in strategy_name_list + assert '[R, R, R, R]_3' in strategy_name_list + assert '[R, R, R, S01]_4' in strategy_name_list @run_on_environment_flag(name='AUTO_PARALLEL') diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_view_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_view_handler.py index d07d2f76c..8a96ac0d6 100644 --- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_view_handler.py +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_view_handler.py @@ -196,54 +196,57 @@ def check_view_handler(rank, tgt_shape, model_cls, world_size, port): if model_cls.__name__ == 'LinearViewModel': if tgt_shape == (32, 4, 64, 16, 4): - assert '[S0, R, R, S1] -> [S0, R, R, S1, R]_0' in strategy_name_list - assert '[R, S0, R, S1] -> FULLY REPLICATED_1' in strategy_name_list - assert '[R, R, S0, S1] -> [R, R, S0, S1, R]_2' in strategy_name_list - assert '[S1, R, R, S0] -> [S1, R, R, S0, R]_3' in strategy_name_list - assert '[R, S1, R, S0] -> FULLY REPLICATED_4' in strategy_name_list - assert '[R, R, S1, S0] -> [R, R, S1, S0, R]_5' in strategy_name_list - assert '[S0, R, R, R] -> [S0, R, R, R, R]_6' in strategy_name_list - assert '[R, S0, R, R] -> FULLY REPLICATED_7' in strategy_name_list - assert '[R, R, S0, R] -> [R, R, S0, R, R]_8' in strategy_name_list - assert '[S1, R, R, R] -> [S1, R, R, R, R]_9' in strategy_name_list - assert '[R, S1, R, R] -> FULLY REPLICATED_10' in strategy_name_list - assert '[R, R, S1, R] -> [R, R, S1, R, R]_11' in strategy_name_list - assert '[R, R, R, S1] -> [R, R, R, S1, R]_12' in strategy_name_list - assert '[R, R, R, S0] -> [R, R, R, S0, R]_13' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R, R]_14' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R, R]_15' in strategy_name_list - assert '[R, R, R, S0] -> [R, R, R, S0, R]_16' in strategy_name_list - assert '[R, R, R, S1] -> [R, R, R, S1, R]_17' in strategy_name_list - assert '[S01, R, R, R] -> [S01, R, R, R, R]_18' in strategy_name_list - assert '[R, S01, R, R] -> FULLY REPLICATED_19' in strategy_name_list - assert '[R, R, S01, R] -> [R, R, S01, R, R]_20' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R, R]_21' in strategy_name_list - assert '[R, R, R, S01] -> [R, R, R, S01, R]_22' in strategy_name_list + for strategy in strategy_name_list: + print(strategy) + # print(strategy_name_list) + assert '[S0, R, R, S1] -> [S0, R, R, S1, R]_11' in strategy_name_list + assert '[R, S0, R, S1] -> FULLY REPLICATED_12' in strategy_name_list + assert '[R, R, S0, S1] -> [R, R, S0, S1, R]_13' in strategy_name_list + assert '[S1, R, R, S0] -> [S1, R, R, S0, R]_14' in strategy_name_list + assert '[R, S1, R, S0] -> FULLY REPLICATED_15' in strategy_name_list + assert '[R, R, S1, S0] -> [R, R, S1, S0, R]_16' in strategy_name_list + assert '[S0, R, R, R] -> [S0, R, R, R, R]_17' in strategy_name_list + assert '[R, S0, R, R] -> FULLY REPLICATED_18' in strategy_name_list + assert '[R, R, S0, R] -> [R, R, S0, R, R]_19' in strategy_name_list + assert '[S1, R, R, R] -> [S1, R, R, R, R]_20' in strategy_name_list + assert '[R, S1, R, R] -> FULLY REPLICATED_21' in strategy_name_list + assert '[R, R, S1, R] -> [R, R, S1, R, R]_22' in strategy_name_list + assert '[R, R, R, S1] -> [R, R, R, S1, R]_10' in strategy_name_list + assert '[R, R, R, S0] -> [R, R, R, S0, R]_9' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R, R]_8' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R, R]_7' in strategy_name_list + assert '[R, R, R, S0] -> [R, R, R, S0, R]_6' in strategy_name_list + assert '[R, R, R, S1] -> [R, R, R, S1, R]_5' in strategy_name_list + assert '[S01, R, R, R] -> [S01, R, R, R, R]_0' in strategy_name_list + assert '[R, S01, R, R] -> FULLY REPLICATED_1' in strategy_name_list + assert '[R, R, S01, R] -> [R, R, S01, R, R]_2' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R, R]_3' in strategy_name_list + assert '[R, R, R, S01] -> [R, R, R, S01, R]_4' in strategy_name_list if tgt_shape == (8, 4, 4, 64, 16, 4): - assert '[S0, R, R, S1] -> [S0, R, R, R, S1, R]_0' in strategy_name_list - assert '[R, S0, R, S1] -> [R, S0, R, R, S1, R]_1' in strategy_name_list - assert '[R, R, S0, S1] -> [R, R, R, S0, S1, R]_2' in strategy_name_list - assert '[S1, R, R, S0] -> [S1, R, R, R, S0, R]_3' in strategy_name_list - assert '[R, S1, R, S0] -> [R, S1, R, R, S0, R]_4' in strategy_name_list - assert '[R, R, S1, S0] -> [R, R, R, S1, S0, R]_5' in strategy_name_list - assert '[S0, R, R, R] -> [S0, R, R, R, R, R]_6' in strategy_name_list - assert '[R, S0, R, R] -> [R, S0, R, R, R, R]_7' in strategy_name_list - assert '[R, R, S0, R] -> [R, R, R, S0, R, R]_8' in strategy_name_list - assert '[S1, R, R, R] -> [S1, R, R, R, R, R]_9' in strategy_name_list - assert '[R, S1, R, R] -> [R, S1, R, R, R, R]_10' in strategy_name_list - assert '[R, R, S1, R] -> [R, R, R, S1, R, R]_11' in strategy_name_list - assert '[R, R, R, S1] -> [R, R, R, R, S1, R]_12' in strategy_name_list - assert '[R, R, R, S0] -> [R, R, R, R, S0, R]_13' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R, R, R]_14' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R, R, R]_15' in strategy_name_list - assert '[R, R, R, S0] -> [R, R, R, R, S0, R]_16' in strategy_name_list - assert '[R, R, R, S1] -> [R, R, R, R, S1, R]_17' in strategy_name_list - assert '[S01, R, R, R] -> [S01, R, R, R, R, R]_18' in strategy_name_list - assert '[R, S01, R, R] -> [R, S01, R, R, R, R]_19' in strategy_name_list - assert '[R, R, S01, R] -> [R, R, R, S01, R, R]_20' in strategy_name_list - assert '[R, R, R, R] -> [R, R, R, R, R, R]_21' in strategy_name_list - assert '[R, R, R, S01] -> [R, R, R, R, S01, R]_22' in strategy_name_list + assert '[S0, R, R, S1] -> [S0, R, R, R, S1, R]_11' in strategy_name_list + assert '[R, S0, R, S1] -> [R, S0, R, R, S1, R]_12' in strategy_name_list + assert '[R, R, S0, S1] -> [R, R, R, S0, S1, R]_13' in strategy_name_list + assert '[S1, R, R, S0] -> [S1, R, R, R, S0, R]_14' in strategy_name_list + assert '[R, S1, R, S0] -> [R, S1, R, R, S0, R]_15' in strategy_name_list + assert '[R, R, S1, S0] -> [R, R, R, S1, S0, R]_16' in strategy_name_list + assert '[S0, R, R, R] -> [S0, R, R, R, R, R]_17' in strategy_name_list + assert '[R, S0, R, R] -> [R, S0, R, R, R, R]_18' in strategy_name_list + assert '[R, R, S0, R] -> [R, R, R, S0, R, R]_19' in strategy_name_list + assert '[S1, R, R, R] -> [S1, R, R, R, R, R]_20' in strategy_name_list + assert '[R, S1, R, R] -> [R, S1, R, R, R, R]_21' in strategy_name_list + assert '[R, R, S1, R] -> [R, R, R, S1, R, R]_22' in strategy_name_list + assert '[R, R, R, S1] -> [R, R, R, R, S1, R]_10' in strategy_name_list + assert '[R, R, R, S0] -> [R, R, R, R, S0, R]_9' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R, R, R]_8' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R, R, R]_7' in strategy_name_list + assert '[R, R, R, S0] -> [R, R, R, R, S0, R]_6' in strategy_name_list + assert '[R, R, R, S1] -> [R, R, R, R, S1, R]_5' in strategy_name_list + assert '[S01, R, R, R] -> [S01, R, R, R, R, R]_0' in strategy_name_list + assert '[R, S01, R, R] -> [R, S01, R, R, R, R]_1' in strategy_name_list + assert '[R, R, S01, R] -> [R, R, R, S01, R, R]_2' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R, R, R]_3' in strategy_name_list + assert '[R, R, R, S01] -> [R, R, R, R, S01, R]_4' in strategy_name_list @run_on_environment_flag(name='AUTO_PARALLEL')