[autoparallel] distinguish different parallel strategies (#2699)

pull/2738/head^2
YuliangLiu0306 2023-02-15 22:28:28 +08:00 committed by GitHub
parent ae86a29e23
commit 1dc003c169
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 255 additions and 219 deletions

View File

@ -152,7 +152,10 @@ class LinearModuleHandler(MetaInfoModuleHandler):
op_data_mapping = self.get_operation_data_mapping() op_data_mapping = self.get_operation_data_mapping()
generators = [] generators = []
generators.append( generators.append(
LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh, linear_projection_type='linear')) LinearProjectionStrategyGenerator(op_data_mapping,
self.device_mesh,
linear_projection_type='linear',
solver_perference=self.solver_perference))
return generators return generators
def get_operation_data_mapping(self) -> Dict[str, OperationData]: def get_operation_data_mapping(self) -> Dict[str, OperationData]:

View File

@ -3,6 +3,7 @@ from ast import arg
from functools import reduce from functools import reduce
from typing import List from typing import List
from colossalai.auto_parallel.tensor_shard.options import SolverPerference
from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
CommType, CommType,
MemoryCost, MemoryCost,
@ -209,9 +210,14 @@ class MatVecStrategyGenerator(MatMulStrategyGenerator):
class LinearProjectionStrategyGenerator(MatMulStrategyGenerator): class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
def __init__(self, operation_data_mapping, device_mesh, linear_projection_type='linear'): def __init__(self,
operation_data_mapping,
device_mesh,
linear_projection_type='linear',
solver_perference=SolverPerference.STANDARD):
super().__init__(operation_data_mapping, device_mesh) super().__init__(operation_data_mapping, device_mesh)
self.linear_projection_type = linear_projection_type self.linear_projection_type = linear_projection_type
self.solver_perference = solver_perference
def update_compute_cost(self, strategy: ShardingStrategy) -> ShardingStrategy: def update_compute_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
# C = AB # C = AB
@ -231,16 +237,22 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
total=fwd_compute_cost + bwd_compute_cost) total=fwd_compute_cost + bwd_compute_cost)
strategy.compute_cost = compute_cost strategy.compute_cost = compute_cost
def collate_strategies(self) -> List[ShardingStrategy]: def dp_strategies(self) -> List[ShardingStrategy]:
strategies = [] strategies = []
# SS = SR x RS # S01R = S01R x RR
strategies.append(self.split_lhs_space_rhs_space(0, 1)) strategies.append(self.split_lhs_1st_dim_1d(0, 1))
strategies.append(self.split_lhs_space_rhs_space(1, 0))
# SR = SS x SR return strategies
strategies.append(self.split_lhs_space_both_contract(0, 1))
strategies.append(self.split_lhs_space_both_contract(1, 0)) def tp_strategies(self) -> List[ShardingStrategy]:
strategies = []
# RR = RS01 x S01R
strategies.append(self.split_lhs_2nd_dim_1d(0, 1))
# RS01 = RR x RS01
strategies.append(self.split_rhs_2nd_dim_1d(0, 1))
# RS = RS x SS # RS = RS x SS
strategies.append(self.split_rhs_space_both_contract(0, 1)) strategies.append(self.split_rhs_space_both_contract(0, 1))
@ -254,20 +266,38 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
strategies.append(self.split_rhs_space_only(0)) strategies.append(self.split_rhs_space_only(0))
strategies.append(self.split_rhs_space_only(1)) strategies.append(self.split_rhs_space_only(1))
# S01R = S01R x RR return strategies
strategies.append(self.split_lhs_1st_dim_1d(0, 1))
# RR = RS01 x S01R def mix_strategies(self) -> List[ShardingStrategy]:
strategies.append(self.split_lhs_2nd_dim_1d(0, 1)) strategies = []
# RS01 = RR x RS01 # SS = SR x RS
strategies.append(self.split_rhs_2nd_dim_1d(0, 1)) strategies.append(self.split_lhs_space_rhs_space(0, 1))
strategies.append(self.split_lhs_space_rhs_space(1, 0))
# SR = SS x SR
strategies.append(self.split_lhs_space_both_contract(0, 1))
strategies.append(self.split_lhs_space_both_contract(1, 0))
# RR = RR x RR # RR = RR x RR
strategies.append(self.non_split()) strategies.append(self.non_split())
return strategies return strategies
def collate_strategies(self) -> List[ShardingStrategy]:
    """Gather the sharding strategies selected by ``self.solver_perference``.

    * ``SolverPerference.STANDARD`` -> results of ``dp_strategies()``,
      ``tp_strategies()`` and ``mix_strategies()``, concatenated in that order.
    * ``SolverPerference.DP`` -> results of ``dp_strategies()`` only.
    * ``SolverPerference.TP`` -> results of ``tp_strategies()`` only.
    * any other value -> an empty list.

    Returns:
        A new list holding the strategies produced by the selected groups.
    """
    preference = self.solver_perference

    # Choose which strategy groups take part; the tuple order fixes the
    # order in which their results are appended to the output.
    if preference == SolverPerference.STANDARD:
        builders = (self.dp_strategies, self.tp_strategies, self.mix_strategies)
    elif preference == SolverPerference.DP:
        builders = (self.dp_strategies,)
    elif preference == SolverPerference.TP:
        builders = (self.tp_strategies,)
    else:
        # Unmatched preference: no group is consulted.
        builders = ()

    collected = []
    for build in builders:
        collected.extend(build())
    return collected
@ignore_sharding_exception @ignore_sharding_exception
def split_lhs_space_rhs_space(self, mesh_dim_0, mesh_dim_1): def split_lhs_space_rhs_space(self, mesh_dim_0, mesh_dim_1):
# handle case SS = SR x RS # handle case SS = SR x RS

View File

@ -117,7 +117,7 @@ def check_attention_layer(rank, model_cls, world_size, port):
gm = GraphModule(model, graph, model.__class__.__name__) gm = GraphModule(model, graph, model.__class__.__name__)
gm.recompile() gm.recompile()
strategies_constructor = build_strategy_constructor(graph, device_mesh) strategies_constructor = build_strategy_constructor(graph, device_mesh, 'standard', 'replicated', 'standard')
solution = solve_solution(gm, strategies_constructor, memory_budget=-1) solution = solve_solution(gm, strategies_constructor, memory_budget=-1)
gm, sharding_spec_dicts = transform_to_sharded_model(gm, solution, device_mesh, strategies_constructor) gm, sharding_spec_dicts = transform_to_sharded_model(gm, solution, device_mesh, strategies_constructor)
gm = ModuleWrapper(gm, *sharding_spec_dicts) gm = ModuleWrapper(gm, *sharding_spec_dicts)

View File

@ -243,79 +243,79 @@ def check_view_handler(rank, call_function, reshape_dims, model_cls, world_size,
if model_cls.__name__ == 'LinearReshapeModel': if model_cls.__name__ == 'LinearReshapeModel':
if reshape_dims == ((0, 2, 1, 3), (1, 2)): if reshape_dims == ((0, 2, 1, 3), (1, 2)):
assert '[S0, R, R, S1] -> [S0, R, R, S1]_0' in strategy_name_list assert '[S0, R, R, S1] -> [S0, R, R, S1]_11' in strategy_name_list
assert '[R, S0, R, S1] -> [R, R, S0, S1]_1' in strategy_name_list assert '[R, S0, R, S1] -> [R, R, S0, S1]_12' in strategy_name_list
assert '[R, R, S0, S1] -> [R, S0, R, S1]_2' in strategy_name_list assert '[R, R, S0, S1] -> [R, S0, R, S1]_13' in strategy_name_list
assert '[S1, R, R, S0] -> [S1, R, R, S0]_3' in strategy_name_list assert '[S1, R, R, S0] -> [S1, R, R, S0]_14' in strategy_name_list
assert '[R, S1, R, S0] -> [R, R, S1, S0]_4' in strategy_name_list assert '[R, S1, R, S0] -> [R, R, S1, S0]_15' in strategy_name_list
assert '[R, R, S1, S0] -> [R, S1, R, S0]_5' in strategy_name_list assert '[R, R, S1, S0] -> [R, S1, R, S0]_16' in strategy_name_list
assert '[S0, R, R, R] -> [S0, R, R, R]_6' in strategy_name_list assert '[S0, R, R, R] -> [S0, R, R, R]_17' in strategy_name_list
assert '[R, S0, R, R] -> [R, R, S0, R]_7' in strategy_name_list assert '[R, S0, R, R] -> [R, R, S0, R]_18' in strategy_name_list
assert '[R, R, S0, R] -> [R, S0, R, R]_8' in strategy_name_list assert '[R, R, S0, R] -> [R, S0, R, R]_19' in strategy_name_list
assert '[S1, R, R, R] -> [S1, R, R, R]_9' in strategy_name_list assert '[S1, R, R, R] -> [S1, R, R, R]_20' in strategy_name_list
assert '[R, S1, R, R] -> [R, R, S1, R]_10' in strategy_name_list assert '[R, S1, R, R] -> [R, R, S1, R]_21' in strategy_name_list
assert '[R, R, S1, R] -> [R, S1, R, R]_11' in strategy_name_list assert '[R, R, S1, R] -> [R, S1, R, R]_22' in strategy_name_list
assert '[R, R, R, S1] -> [R, R, R, S1]_12' in strategy_name_list assert '[R, R, R, S1] -> [R, R, R, S1]_10' in strategy_name_list
assert '[R, R, R, S0] -> [R, R, R, S0]_13' in strategy_name_list assert '[R, R, R, S0] -> [R, R, R, S0]_9' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_14' in strategy_name_list assert '[R, R, R, R] -> [R, R, R, R]_8' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_15' in strategy_name_list assert '[R, R, R, R] -> [R, R, R, R]_7' in strategy_name_list
assert '[R, R, R, S0] -> [R, R, R, S0]_16' in strategy_name_list assert '[R, R, R, S0] -> [R, R, R, S0]_6' in strategy_name_list
assert '[R, R, R, S1] -> [R, R, R, S1]_17' in strategy_name_list assert '[R, R, R, S1] -> [R, R, R, S1]_5' in strategy_name_list
assert '[S01, R, R, R] -> [S01, R, R, R]_18' in strategy_name_list assert '[S01, R, R, R] -> [S01, R, R, R]_0' in strategy_name_list
assert '[R, S01, R, R] -> [R, R, S01, R]_19' in strategy_name_list assert '[R, S01, R, R] -> [R, R, S01, R]_1' in strategy_name_list
assert '[R, R, S01, R] -> [R, S01, R, R]_20' in strategy_name_list assert '[R, R, S01, R] -> [R, S01, R, R]_2' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_21' in strategy_name_list assert '[R, R, R, R] -> [R, R, R, R]_3' in strategy_name_list
assert '[R, R, R, S01] -> [R, R, R, S01]_22' in strategy_name_list assert '[R, R, R, S01] -> [R, R, R, S01]_4' in strategy_name_list
if reshape_dims == (2, 0, 1, 3): if reshape_dims == (2, 0, 1, 3):
assert '[S0, R, R, S1] -> [R, S0, R, S1]_0' in strategy_name_list assert '[S0, R, R, S1] -> [R, S0, R, S1]_11' in strategy_name_list
assert '[R, S0, R, S1] -> [R, R, S0, S1]_1' in strategy_name_list assert '[R, S0, R, S1] -> [R, R, S0, S1]_12' in strategy_name_list
assert '[R, R, S0, S1] -> [S0, R, R, S1]_2' in strategy_name_list assert '[R, R, S0, S1] -> [S0, R, R, S1]_13' in strategy_name_list
assert '[S1, R, R, S0] -> [R, S1, R, S0]_3' in strategy_name_list assert '[S1, R, R, S0] -> [R, S1, R, S0]_14' in strategy_name_list
assert '[R, S1, R, S0] -> [R, R, S1, S0]_4' in strategy_name_list assert '[R, S1, R, S0] -> [R, R, S1, S0]_15' in strategy_name_list
assert '[R, R, S1, S0] -> [S1, R, R, S0]_5' in strategy_name_list assert '[R, R, S1, S0] -> [S1, R, R, S0]_16' in strategy_name_list
assert '[S0, R, R, R] -> [R, S0, R, R]_6' in strategy_name_list assert '[S0, R, R, R] -> [R, S0, R, R]_17' in strategy_name_list
assert '[R, S0, R, R] -> [R, R, S0, R]_7' in strategy_name_list assert '[R, S0, R, R] -> [R, R, S0, R]_18' in strategy_name_list
assert '[R, R, S0, R] -> [S0, R, R, R]_8' in strategy_name_list assert '[R, R, S0, R] -> [S0, R, R, R]_19' in strategy_name_list
assert '[S1, R, R, R] -> [R, S1, R, R]_9' in strategy_name_list assert '[S1, R, R, R] -> [R, S1, R, R]_20' in strategy_name_list
assert '[R, S1, R, R] -> [R, R, S1, R]_10' in strategy_name_list assert '[R, S1, R, R] -> [R, R, S1, R]_21' in strategy_name_list
assert '[R, R, S1, R] -> [S1, R, R, R]_11' in strategy_name_list assert '[R, R, S1, R] -> [S1, R, R, R]_22' in strategy_name_list
assert '[R, R, R, S1] -> [R, R, R, S1]_12' in strategy_name_list assert '[R, R, R, S1] -> [R, R, R, S1]_10' in strategy_name_list
assert '[R, R, R, S0] -> [R, R, R, S0]_13' in strategy_name_list assert '[R, R, R, S0] -> [R, R, R, S0]_9' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_14' in strategy_name_list assert '[R, R, R, R] -> [R, R, R, R]_8' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_15' in strategy_name_list assert '[R, R, R, R] -> [R, R, R, R]_7' in strategy_name_list
assert '[R, R, R, S0] -> [R, R, R, S0]_16' in strategy_name_list assert '[R, R, R, S0] -> [R, R, R, S0]_6' in strategy_name_list
assert '[R, R, R, S1] -> [R, R, R, S1]_17' in strategy_name_list assert '[R, R, R, S1] -> [R, R, R, S1]_5' in strategy_name_list
assert '[S01, R, R, R] -> [R, S01, R, R]_18' in strategy_name_list assert '[S01, R, R, R] -> [R, S01, R, R]_0' in strategy_name_list
assert '[R, S01, R, R] -> [R, R, S01, R]_19' in strategy_name_list assert '[R, S01, R, R] -> [R, R, S01, R]_1' in strategy_name_list
assert '[R, R, S01, R] -> [S01, R, R, R]_20' in strategy_name_list assert '[R, R, S01, R] -> [S01, R, R, R]_2' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_21' in strategy_name_list assert '[R, R, R, R] -> [R, R, R, R]_3' in strategy_name_list
assert '[R, R, R, S01] -> [R, R, R, S01]_22' in strategy_name_list assert '[R, R, R, S01] -> [R, R, R, S01]_4' in strategy_name_list
if reshape_dims == (1, 3): if reshape_dims == (1, 3):
assert '[S0, R, R, S1] -> [S0, S1, R, R]_0' in strategy_name_list assert '[S0, R, R, S1] -> [S0, S1, R, R]_11' in strategy_name_list
assert '[R, S0, R, S1] -> [R, S1, R, S0]_1' in strategy_name_list assert '[R, S0, R, S1] -> [R, S1, R, S0]_12' in strategy_name_list
assert '[R, R, S0, S1] -> [R, S1, S0, R]_2' in strategy_name_list assert '[R, R, S0, S1] -> [R, S1, S0, R]_13' in strategy_name_list
assert '[S1, R, R, S0] -> [S1, S0, R, R]_3' in strategy_name_list assert '[S1, R, R, S0] -> [S1, S0, R, R]_14' in strategy_name_list
assert '[R, S1, R, S0] -> [R, S0, R, S1]_4' in strategy_name_list assert '[R, S1, R, S0] -> [R, S0, R, S1]_15' in strategy_name_list
assert '[R, R, S1, S0] -> [R, S0, S1, R]_5' in strategy_name_list assert '[R, R, S1, S0] -> [R, S0, S1, R]_16' in strategy_name_list
assert '[S0, R, R, R] -> [S0, R, R, R]_6' in strategy_name_list assert '[S0, R, R, R] -> [S0, R, R, R]_17' in strategy_name_list
assert '[R, S0, R, R] -> [R, R, R, S0]_7' in strategy_name_list assert '[R, S0, R, R] -> [R, R, R, S0]_18' in strategy_name_list
assert '[R, R, S0, R] -> [R, R, S0, R]_8' in strategy_name_list assert '[R, R, S0, R] -> [R, R, S0, R]_19' in strategy_name_list
assert '[S1, R, R, R] -> [S1, R, R, R]_9' in strategy_name_list assert '[S1, R, R, R] -> [S1, R, R, R]_20' in strategy_name_list
assert '[R, S1, R, R] -> [R, R, R, S1]_10' in strategy_name_list assert '[R, S1, R, R] -> [R, R, R, S1]_21' in strategy_name_list
assert '[R, R, S1, R] -> [R, R, S1, R]_11' in strategy_name_list assert '[R, R, S1, R] -> [R, R, S1, R]_22' in strategy_name_list
assert '[R, R, R, S1] -> [R, S1, R, R]_12' in strategy_name_list assert '[R, R, R, S1] -> [R, S1, R, R]_10' in strategy_name_list
assert '[R, R, R, S0] -> [R, S0, R, R]_13' in strategy_name_list assert '[R, R, R, S0] -> [R, S0, R, R]_9' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_14' in strategy_name_list assert '[R, R, R, R] -> [R, R, R, R]_8' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_15' in strategy_name_list assert '[R, R, R, R] -> [R, R, R, R]_7' in strategy_name_list
assert '[R, R, R, S0] -> [R, S0, R, R]_16' in strategy_name_list assert '[R, R, R, S0] -> [R, S0, R, R]_6' in strategy_name_list
assert '[R, R, R, S1] -> [R, S1, R, R]_17' in strategy_name_list assert '[R, R, R, S1] -> [R, S1, R, R]_5' in strategy_name_list
assert '[S01, R, R, R] -> [S01, R, R, R]_18' in strategy_name_list assert '[S01, R, R, R] -> [S01, R, R, R]_0' in strategy_name_list
assert '[R, S01, R, R] -> [R, R, R, S01]_19' in strategy_name_list assert '[R, S01, R, R] -> [R, R, R, S01]_1' in strategy_name_list
assert '[R, R, S01, R] -> [R, R, S01, R]_20' in strategy_name_list assert '[R, R, S01, R] -> [R, R, S01, R]_2' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_21' in strategy_name_list assert '[R, R, R, R] -> [R, R, R, R]_3' in strategy_name_list
assert '[R, R, R, S01] -> [R, S01, R, R]_22' in strategy_name_list assert '[R, R, R, S01] -> [R, S01, R, R]_4' in strategy_name_list
@run_on_environment_flag(name='AUTO_PARALLEL') @run_on_environment_flag(name='AUTO_PARALLEL')

View File

@ -117,54 +117,54 @@ def check_split_handler(rank, softmax_dim, model_cls, world_size, port):
strategy_name_list = [strategy.name for strategy in split_strategies_vector] strategy_name_list = [strategy.name for strategy in split_strategies_vector]
if softmax_dim == 0: if softmax_dim == 0:
assert '[R, R, R, S1] -> [R, R, R, S1]_0' in strategy_name_list assert '[R, R, R, S1] -> [R, R, R, S1]_11' in strategy_name_list
assert '[R, S0, R, S1] -> [R, S0, R, S1]_1' in strategy_name_list assert '[R, S0, R, S1] -> [R, S0, R, S1]_12' in strategy_name_list
assert '[R, R, S0, S1] -> [R, R, S0, S1]_2' in strategy_name_list assert '[R, R, S0, S1] -> [R, R, S0, S1]_13' in strategy_name_list
assert '[R, R, R, S0] -> [R, R, R, S0]_3' in strategy_name_list assert '[R, R, R, S0] -> [R, R, R, S0]_14' in strategy_name_list
assert '[R, S1, R, S0] -> [R, S1, R, S0]_4' in strategy_name_list assert '[R, S1, R, S0] -> [R, S1, R, S0]_15' in strategy_name_list
assert '[R, R, S1, S0] -> [R, R, S1, S0]_5' in strategy_name_list assert '[R, R, S1, S0] -> [R, R, S1, S0]_16' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_6' in strategy_name_list assert '[R, R, R, R] -> [R, R, R, R]_17' in strategy_name_list
assert '[R, S0, R, R] -> [R, S0, R, R]_7' in strategy_name_list assert '[R, S0, R, R] -> [R, S0, R, R]_18' in strategy_name_list
assert '[R, R, S0, R] -> [R, R, S0, R]_8' in strategy_name_list assert '[R, R, S0, R] -> [R, R, S0, R]_19' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_9' in strategy_name_list assert '[R, R, R, R] -> [R, R, R, R]_20' in strategy_name_list
assert '[R, S1, R, R] -> [R, S1, R, R]_10' in strategy_name_list assert '[R, S1, R, R] -> [R, S1, R, R]_21' in strategy_name_list
assert '[R, R, S1, R] -> [R, R, S1, R]_11' in strategy_name_list assert '[R, R, S1, R] -> [R, R, S1, R]_22' in strategy_name_list
assert '[R, R, R, S1] -> [R, R, R, S1]_12' in strategy_name_list assert '[R, R, R, S1] -> [R, R, R, S1]_10' in strategy_name_list
assert '[R, R, R, S0] -> [R, R, R, S0]_13' in strategy_name_list assert '[R, R, R, S0] -> [R, R, R, S0]_9' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_14' in strategy_name_list assert '[R, R, R, R] -> [R, R, R, R]_8' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_15' in strategy_name_list assert '[R, R, R, R] -> [R, R, R, R]_7' in strategy_name_list
assert '[R, R, R, S0] -> [R, R, R, S0]_16' in strategy_name_list assert '[R, R, R, S0] -> [R, R, R, S0]_6' in strategy_name_list
assert '[R, R, R, S1] -> [R, R, R, S1]_17' in strategy_name_list assert '[R, R, R, S1] -> [R, R, R, S1]_5' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_18' in strategy_name_list assert '[R, R, R, R] -> [R, R, R, R]_0' in strategy_name_list
assert '[R, S01, R, R] -> [R, S01, R, R]_19' in strategy_name_list assert '[R, S01, R, R] -> [R, S01, R, R]_1' in strategy_name_list
assert '[R, R, S01, R] -> [R, R, S01, R]_20' in strategy_name_list assert '[R, R, S01, R] -> [R, R, S01, R]_2' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_21' in strategy_name_list assert '[R, R, R, R] -> [R, R, R, R]_3' in strategy_name_list
assert '[R, R, R, S01] -> [R, R, R, S01]_22' in strategy_name_list assert '[R, R, R, S01] -> [R, R, R, S01]_4' in strategy_name_list
if softmax_dim == 1: if softmax_dim == 1:
assert '[S0, R, R, S1] -> [S0, R, R, S1]_0' in strategy_name_list assert '[S0, R, R, S1] -> [S0, R, R, S1]_11' in strategy_name_list
assert '[R, R, R, S1] -> [R, R, R, S1]_1' in strategy_name_list
assert '[R, R, S0, S1] -> [R, R, S0, S1]_2' in strategy_name_list
assert '[S1, R, R, S0] -> [S1, R, R, S0]_3' in strategy_name_list
assert '[R, R, R, S0] -> [R, R, R, S0]_4' in strategy_name_list
assert '[R, R, S1, S0] -> [R, R, S1, S0]_5' in strategy_name_list
assert '[S0, R, R, R] -> [S0, R, R, R]_6' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_7' in strategy_name_list
assert '[R, R, S0, R] -> [R, R, S0, R]_8' in strategy_name_list
assert '[S1, R, R, R] -> [S1, R, R, R]_9' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_10' in strategy_name_list
assert '[R, R, S1, R] -> [R, R, S1, R]_11' in strategy_name_list
assert '[R, R, R, S1] -> [R, R, R, S1]_12' in strategy_name_list assert '[R, R, R, S1] -> [R, R, R, S1]_12' in strategy_name_list
assert '[R, R, R, S0] -> [R, R, R, S0]_13' in strategy_name_list assert '[R, R, S0, S1] -> [R, R, S0, S1]_13' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_14' in strategy_name_list assert '[S1, R, R, S0] -> [S1, R, R, S0]_14' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_15' in strategy_name_list assert '[R, R, R, S0] -> [R, R, R, S0]_15' in strategy_name_list
assert '[R, R, R, S0] -> [R, R, R, S0]_16' in strategy_name_list assert '[R, R, S1, S0] -> [R, R, S1, S0]_16' in strategy_name_list
assert '[R, R, R, S1] -> [R, R, R, S1]_17' in strategy_name_list assert '[S0, R, R, R] -> [S0, R, R, R]_17' in strategy_name_list
assert '[S01, R, R, R] -> [S01, R, R, R]_18' in strategy_name_list assert '[R, R, R, R] -> [R, R, R, R]_18' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_19' in strategy_name_list assert '[R, R, S0, R] -> [R, R, S0, R]_19' in strategy_name_list
assert '[R, R, S01, R] -> [R, R, S01, R]_20' in strategy_name_list assert '[S1, R, R, R] -> [S1, R, R, R]_20' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_21' in strategy_name_list assert '[R, R, R, R] -> [R, R, R, R]_21' in strategy_name_list
assert '[R, R, R, S01] -> [R, R, R, S01]_22' in strategy_name_list assert '[R, R, S1, R] -> [R, R, S1, R]_22' in strategy_name_list
assert '[R, R, R, S1] -> [R, R, R, S1]_10' in strategy_name_list
assert '[R, R, R, S0] -> [R, R, R, S0]_9' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_8' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_7' in strategy_name_list
assert '[R, R, R, S0] -> [R, R, R, S0]_6' in strategy_name_list
assert '[R, R, R, S1] -> [R, R, R, S1]_5' in strategy_name_list
assert '[S01, R, R, R] -> [S01, R, R, R]_0' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_1' in strategy_name_list
assert '[R, R, S01, R] -> [R, R, S01, R]_2' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R]_3' in strategy_name_list
assert '[R, R, R, S01] -> [R, R, R, S01]_4' in strategy_name_list
@run_on_environment_flag(name='AUTO_PARALLEL') @run_on_environment_flag(name='AUTO_PARALLEL')

View File

@ -198,54 +198,54 @@ def check_split_handler(rank, split_size, split_dim, model_cls, world_size, port
if model_cls.__name__ == 'LinearSplitModel': if model_cls.__name__ == 'LinearSplitModel':
if split_dim == 0: if split_dim == 0:
assert '[R, R, R, S1]_0' in strategy_name_list assert '[R, R, R, S1]_11' in strategy_name_list
assert '[R, S0, R, S1]_1' in strategy_name_list assert '[R, S0, R, S1]_12' in strategy_name_list
assert '[R, R, S0, S1]_2' in strategy_name_list assert '[R, R, S0, S1]_13' in strategy_name_list
assert '[R, R, R, S0]_3' in strategy_name_list assert '[R, R, R, S0]_14' in strategy_name_list
assert '[R, S1, R, S0]_4' in strategy_name_list assert '[R, S1, R, S0]_15' in strategy_name_list
assert '[R, R, S1, S0]_5' in strategy_name_list assert '[R, R, S1, S0]_16' in strategy_name_list
assert '[R, R, R, R]_6' in strategy_name_list assert '[R, R, R, R]_17' in strategy_name_list
assert '[R, S0, R, R]_7' in strategy_name_list assert '[R, S0, R, R]_18' in strategy_name_list
assert '[R, R, S0, R]_8' in strategy_name_list assert '[R, R, S0, R]_19' in strategy_name_list
assert '[R, R, R, R]_9' in strategy_name_list assert '[R, R, R, R]_20' in strategy_name_list
assert '[R, S1, R, R]_10' in strategy_name_list assert '[R, S1, R, R]_21' in strategy_name_list
assert '[R, R, S1, R]_11' in strategy_name_list assert '[R, R, S1, R]_22' in strategy_name_list
assert '[R, R, R, S1]_12' in strategy_name_list assert '[R, R, R, S1]_10' in strategy_name_list
assert '[R, R, R, S0]_13' in strategy_name_list assert '[R, R, R, S0]_9' in strategy_name_list
assert '[R, R, R, R]_14' in strategy_name_list assert '[R, R, R, R]_8' in strategy_name_list
assert '[R, R, R, R]_15' in strategy_name_list assert '[R, R, R, R]_7' in strategy_name_list
assert '[R, R, R, S0]_16' in strategy_name_list assert '[R, R, R, S0]_6' in strategy_name_list
assert '[R, R, R, S1]_17' in strategy_name_list assert '[R, R, R, S1]_5' in strategy_name_list
assert '[R, R, R, R]_18' in strategy_name_list assert '[R, R, R, R]_0' in strategy_name_list
assert '[R, S01, R, R]_19' in strategy_name_list assert '[R, S01, R, R]_1' in strategy_name_list
assert '[R, R, S01, R]_20' in strategy_name_list assert '[R, R, S01, R]_2' in strategy_name_list
assert '[R, R, R, R]_21' in strategy_name_list assert '[R, R, R, R]_3' in strategy_name_list
assert '[R, R, R, S01]_22' in strategy_name_list assert '[R, R, R, S01]_4' in strategy_name_list
if split_dim == 1: if split_dim == 1:
assert '[S0, R, R, S1]_0' in strategy_name_list assert '[S0, R, R, S1]_11' in strategy_name_list
assert '[R, R, R, S1]_1' in strategy_name_list
assert '[R, R, S0, S1]_2' in strategy_name_list
assert '[S1, R, R, S0]_3' in strategy_name_list
assert '[R, R, R, S0]_4' in strategy_name_list
assert '[R, R, S1, S0]_5' in strategy_name_list
assert '[S0, R, R, R]_6' in strategy_name_list
assert '[R, R, R, R]_7' in strategy_name_list
assert '[R, R, S0, R]_8' in strategy_name_list
assert '[S1, R, R, R]_9' in strategy_name_list
assert '[R, R, R, R]_10' in strategy_name_list
assert '[R, R, S1, R]_11' in strategy_name_list
assert '[R, R, R, S1]_12' in strategy_name_list assert '[R, R, R, S1]_12' in strategy_name_list
assert '[R, R, R, S0]_13' in strategy_name_list assert '[R, R, S0, S1]_13' in strategy_name_list
assert '[R, R, R, R]_14' in strategy_name_list assert '[S1, R, R, S0]_14' in strategy_name_list
assert '[R, R, R, R]_15' in strategy_name_list assert '[R, R, R, S0]_15' in strategy_name_list
assert '[R, R, R, S0]_16' in strategy_name_list assert '[R, R, S1, S0]_16' in strategy_name_list
assert '[R, R, R, S1]_17' in strategy_name_list assert '[S0, R, R, R]_17' in strategy_name_list
assert '[S01, R, R, R]_18' in strategy_name_list assert '[R, R, R, R]_18' in strategy_name_list
assert '[R, R, R, R]_19' in strategy_name_list assert '[R, R, S0, R]_19' in strategy_name_list
assert '[R, R, S01, R]_20' in strategy_name_list assert '[S1, R, R, R]_20' in strategy_name_list
assert '[R, R, R, R]_21' in strategy_name_list assert '[R, R, R, R]_21' in strategy_name_list
assert '[R, R, R, S01]_22' in strategy_name_list assert '[R, R, S1, R]_22' in strategy_name_list
assert '[R, R, R, S1]_10' in strategy_name_list
assert '[R, R, R, S0]_9' in strategy_name_list
assert '[R, R, R, R]_8' in strategy_name_list
assert '[R, R, R, R]_7' in strategy_name_list
assert '[R, R, R, S0]_6' in strategy_name_list
assert '[R, R, R, S1]_5' in strategy_name_list
assert '[S01, R, R, R]_0' in strategy_name_list
assert '[R, R, R, R]_1' in strategy_name_list
assert '[R, R, S01, R]_2' in strategy_name_list
assert '[R, R, R, R]_3' in strategy_name_list
assert '[R, R, R, S01]_4' in strategy_name_list
@run_on_environment_flag(name='AUTO_PARALLEL') @run_on_environment_flag(name='AUTO_PARALLEL')

View File

@ -196,54 +196,57 @@ def check_view_handler(rank, tgt_shape, model_cls, world_size, port):
if model_cls.__name__ == 'LinearViewModel': if model_cls.__name__ == 'LinearViewModel':
if tgt_shape == (32, 4, 64, 16, 4): if tgt_shape == (32, 4, 64, 16, 4):
assert '[S0, R, R, S1] -> [S0, R, R, S1, R]_0' in strategy_name_list for strategy in strategy_name_list:
assert '[R, S0, R, S1] -> FULLY REPLICATED_1' in strategy_name_list print(strategy)
assert '[R, R, S0, S1] -> [R, R, S0, S1, R]_2' in strategy_name_list # print(strategy_name_list)
assert '[S1, R, R, S0] -> [S1, R, R, S0, R]_3' in strategy_name_list assert '[S0, R, R, S1] -> [S0, R, R, S1, R]_11' in strategy_name_list
assert '[R, S1, R, S0] -> FULLY REPLICATED_4' in strategy_name_list assert '[R, S0, R, S1] -> FULLY REPLICATED_12' in strategy_name_list
assert '[R, R, S1, S0] -> [R, R, S1, S0, R]_5' in strategy_name_list assert '[R, R, S0, S1] -> [R, R, S0, S1, R]_13' in strategy_name_list
assert '[S0, R, R, R] -> [S0, R, R, R, R]_6' in strategy_name_list assert '[S1, R, R, S0] -> [S1, R, R, S0, R]_14' in strategy_name_list
assert '[R, S0, R, R] -> FULLY REPLICATED_7' in strategy_name_list assert '[R, S1, R, S0] -> FULLY REPLICATED_15' in strategy_name_list
assert '[R, R, S0, R] -> [R, R, S0, R, R]_8' in strategy_name_list assert '[R, R, S1, S0] -> [R, R, S1, S0, R]_16' in strategy_name_list
assert '[S1, R, R, R] -> [S1, R, R, R, R]_9' in strategy_name_list assert '[S0, R, R, R] -> [S0, R, R, R, R]_17' in strategy_name_list
assert '[R, S1, R, R] -> FULLY REPLICATED_10' in strategy_name_list assert '[R, S0, R, R] -> FULLY REPLICATED_18' in strategy_name_list
assert '[R, R, S1, R] -> [R, R, S1, R, R]_11' in strategy_name_list assert '[R, R, S0, R] -> [R, R, S0, R, R]_19' in strategy_name_list
assert '[R, R, R, S1] -> [R, R, R, S1, R]_12' in strategy_name_list assert '[S1, R, R, R] -> [S1, R, R, R, R]_20' in strategy_name_list
assert '[R, R, R, S0] -> [R, R, R, S0, R]_13' in strategy_name_list assert '[R, S1, R, R] -> FULLY REPLICATED_21' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R, R]_14' in strategy_name_list assert '[R, R, S1, R] -> [R, R, S1, R, R]_22' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R, R]_15' in strategy_name_list assert '[R, R, R, S1] -> [R, R, R, S1, R]_10' in strategy_name_list
assert '[R, R, R, S0] -> [R, R, R, S0, R]_16' in strategy_name_list assert '[R, R, R, S0] -> [R, R, R, S0, R]_9' in strategy_name_list
assert '[R, R, R, S1] -> [R, R, R, S1, R]_17' in strategy_name_list assert '[R, R, R, R] -> [R, R, R, R, R]_8' in strategy_name_list
assert '[S01, R, R, R] -> [S01, R, R, R, R]_18' in strategy_name_list assert '[R, R, R, R] -> [R, R, R, R, R]_7' in strategy_name_list
assert '[R, S01, R, R] -> FULLY REPLICATED_19' in strategy_name_list assert '[R, R, R, S0] -> [R, R, R, S0, R]_6' in strategy_name_list
assert '[R, R, S01, R] -> [R, R, S01, R, R]_20' in strategy_name_list assert '[R, R, R, S1] -> [R, R, R, S1, R]_5' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R, R]_21' in strategy_name_list assert '[S01, R, R, R] -> [S01, R, R, R, R]_0' in strategy_name_list
assert '[R, R, R, S01] -> [R, R, R, S01, R]_22' in strategy_name_list assert '[R, S01, R, R] -> FULLY REPLICATED_1' in strategy_name_list
assert '[R, R, S01, R] -> [R, R, S01, R, R]_2' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R, R]_3' in strategy_name_list
assert '[R, R, R, S01] -> [R, R, R, S01, R]_4' in strategy_name_list
if tgt_shape == (8, 4, 4, 64, 16, 4): if tgt_shape == (8, 4, 4, 64, 16, 4):
assert '[S0, R, R, S1] -> [S0, R, R, R, S1, R]_0' in strategy_name_list assert '[S0, R, R, S1] -> [S0, R, R, R, S1, R]_11' in strategy_name_list
assert '[R, S0, R, S1] -> [R, S0, R, R, S1, R]_1' in strategy_name_list assert '[R, S0, R, S1] -> [R, S0, R, R, S1, R]_12' in strategy_name_list
assert '[R, R, S0, S1] -> [R, R, R, S0, S1, R]_2' in strategy_name_list assert '[R, R, S0, S1] -> [R, R, R, S0, S1, R]_13' in strategy_name_list
assert '[S1, R, R, S0] -> [S1, R, R, R, S0, R]_3' in strategy_name_list assert '[S1, R, R, S0] -> [S1, R, R, R, S0, R]_14' in strategy_name_list
assert '[R, S1, R, S0] -> [R, S1, R, R, S0, R]_4' in strategy_name_list assert '[R, S1, R, S0] -> [R, S1, R, R, S0, R]_15' in strategy_name_list
assert '[R, R, S1, S0] -> [R, R, R, S1, S0, R]_5' in strategy_name_list assert '[R, R, S1, S0] -> [R, R, R, S1, S0, R]_16' in strategy_name_list
assert '[S0, R, R, R] -> [S0, R, R, R, R, R]_6' in strategy_name_list assert '[S0, R, R, R] -> [S0, R, R, R, R, R]_17' in strategy_name_list
assert '[R, S0, R, R] -> [R, S0, R, R, R, R]_7' in strategy_name_list assert '[R, S0, R, R] -> [R, S0, R, R, R, R]_18' in strategy_name_list
assert '[R, R, S0, R] -> [R, R, R, S0, R, R]_8' in strategy_name_list assert '[R, R, S0, R] -> [R, R, R, S0, R, R]_19' in strategy_name_list
assert '[S1, R, R, R] -> [S1, R, R, R, R, R]_9' in strategy_name_list assert '[S1, R, R, R] -> [S1, R, R, R, R, R]_20' in strategy_name_list
assert '[R, S1, R, R] -> [R, S1, R, R, R, R]_10' in strategy_name_list assert '[R, S1, R, R] -> [R, S1, R, R, R, R]_21' in strategy_name_list
assert '[R, R, S1, R] -> [R, R, R, S1, R, R]_11' in strategy_name_list assert '[R, R, S1, R] -> [R, R, R, S1, R, R]_22' in strategy_name_list
assert '[R, R, R, S1] -> [R, R, R, R, S1, R]_12' in strategy_name_list assert '[R, R, R, S1] -> [R, R, R, R, S1, R]_10' in strategy_name_list
assert '[R, R, R, S0] -> [R, R, R, R, S0, R]_13' in strategy_name_list assert '[R, R, R, S0] -> [R, R, R, R, S0, R]_9' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R, R, R]_14' in strategy_name_list assert '[R, R, R, R] -> [R, R, R, R, R, R]_8' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R, R, R]_15' in strategy_name_list assert '[R, R, R, R] -> [R, R, R, R, R, R]_7' in strategy_name_list
assert '[R, R, R, S0] -> [R, R, R, R, S0, R]_16' in strategy_name_list assert '[R, R, R, S0] -> [R, R, R, R, S0, R]_6' in strategy_name_list
assert '[R, R, R, S1] -> [R, R, R, R, S1, R]_17' in strategy_name_list assert '[R, R, R, S1] -> [R, R, R, R, S1, R]_5' in strategy_name_list
assert '[S01, R, R, R] -> [S01, R, R, R, R, R]_18' in strategy_name_list assert '[S01, R, R, R] -> [S01, R, R, R, R, R]_0' in strategy_name_list
assert '[R, S01, R, R] -> [R, S01, R, R, R, R]_19' in strategy_name_list assert '[R, S01, R, R] -> [R, S01, R, R, R, R]_1' in strategy_name_list
assert '[R, R, S01, R] -> [R, R, R, S01, R, R]_20' in strategy_name_list assert '[R, R, S01, R] -> [R, R, R, S01, R, R]_2' in strategy_name_list
assert '[R, R, R, R] -> [R, R, R, R, R, R]_21' in strategy_name_list assert '[R, R, R, R] -> [R, R, R, R, R, R]_3' in strategy_name_list
assert '[R, R, R, S01] -> [R, R, R, R, S01, R]_22' in strategy_name_list assert '[R, R, R, S01] -> [R, R, R, R, S01, R]_4' in strategy_name_list
@run_on_environment_flag(name='AUTO_PARALLEL') @run_on_environment_flag(name='AUTO_PARALLEL')