mirror of https://github.com/hpcaitech/ColossalAI
[autoparallel] add non_split linear strategy (#2078)
* [autoparallel] add non_split linear stategy * polishpull/2083/head
parent
cf0268da93
commit
cdf537a648
|
@ -263,6 +263,9 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
|
||||||
# RS01 = RR x RS01
|
# RS01 = RR x RS01
|
||||||
strategies.append(self.split_rhs_2nd_dim_1d(0, 1))
|
strategies.append(self.split_rhs_2nd_dim_1d(0, 1))
|
||||||
|
|
||||||
|
# RR = RR x RR
|
||||||
|
strategies.append(self.non_split())
|
||||||
|
|
||||||
return strategies
|
return strategies
|
||||||
|
|
||||||
@ignore_sharding_exception
|
@ignore_sharding_exception
|
||||||
|
@ -665,6 +668,29 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
|
||||||
sharding_spec_mapping=sharding_spec_mapping,
|
sharding_spec_mapping=sharding_spec_mapping,
|
||||||
communication_action_mapping=communication_action_mapping)
|
communication_action_mapping=communication_action_mapping)
|
||||||
|
|
||||||
|
@ignore_sharding_exception
|
||||||
|
def non_split(self):
|
||||||
|
name = f'RR = RR x RR'
|
||||||
|
|
||||||
|
# get sharding spec
|
||||||
|
dim_partition_dict_mapping = {
|
||||||
|
"input": {},
|
||||||
|
"other": {},
|
||||||
|
"bias": {},
|
||||||
|
"output": {},
|
||||||
|
}
|
||||||
|
|
||||||
|
# We don't have to do anything special for bias here, because
|
||||||
|
# the bias is already the same sharding spec as the output.
|
||||||
|
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
|
||||||
|
|
||||||
|
# get communication action
|
||||||
|
communication_action_mapping = {}
|
||||||
|
|
||||||
|
return self.get_sharding_strategy(name=name,
|
||||||
|
sharding_spec_mapping=sharding_spec_mapping,
|
||||||
|
communication_action_mapping=communication_action_mapping)
|
||||||
|
|
||||||
def validate(self) -> bool:
|
def validate(self) -> bool:
|
||||||
assert "input" in self.op_data
|
assert "input" in self.op_data
|
||||||
assert "other" in self.op_data
|
assert "other" in self.op_data
|
||||||
|
|
|
@ -204,9 +204,15 @@ class ShardingStrategy:
|
||||||
def _deepcopy_dict_vals(data: Dict):
|
def _deepcopy_dict_vals(data: Dict):
|
||||||
return {k: deepcopy(v) for k, v in data.items()}
|
return {k: deepcopy(v) for k, v in data.items()}
|
||||||
|
|
||||||
sharding_specs = _deepcopy_dict_vals(self.sharding_specs) if self.sharding_specs else None
|
sharding_specs = _deepcopy_dict_vals(self.sharding_specs) if self.sharding_specs is not None else None
|
||||||
communication_actions = _deepcopy_dict_vals(self.communication_actions) if self.communication_actions else None
|
# We need to deepcopy it when self.communication_actions is not None, instead of checking its __bool__ value.
|
||||||
resharding_costs = _deepcopy_dict_vals(self.resharding_costs) if self.resharding_costs else None
|
# Consider the examples below:
|
||||||
|
# If self.communication_actions is an empty dictionary {}, then self.communication_actions is not None, but its __bool__ value is False.
|
||||||
|
# In this case, if we set None to the new object, program will crash when we try to access the communication_actions.items.
|
||||||
|
communication_actions = _deepcopy_dict_vals(
|
||||||
|
self.communication_actions) if self.communication_actions is not None else None
|
||||||
|
# same reason as communication_actions
|
||||||
|
resharding_costs = _deepcopy_dict_vals(self.resharding_costs) if self.resharding_costs is not None else None
|
||||||
compute_cost = deepcopy(self.compute_cost)
|
compute_cost = deepcopy(self.compute_cost)
|
||||||
communication_cost = deepcopy(self.communication_cost)
|
communication_cost = deepcopy(self.communication_cost)
|
||||||
memory_cost = deepcopy(self.memory_cost)
|
memory_cost = deepcopy(self.memory_cost)
|
||||||
|
|
|
@ -45,11 +45,11 @@ def check_linear_module_handler(rank, bias, world_size, port):
|
||||||
physical_mesh_id = torch.arange(0, 4)
|
physical_mesh_id = torch.arange(0, 4)
|
||||||
mesh_shape = (2, 2)
|
mesh_shape = (2, 2)
|
||||||
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
|
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
|
||||||
input = torch.rand(2, 2, 4, 16).cuda()
|
input = torch.rand(4, 4, 4, 16).cuda()
|
||||||
# the index of linear node in computation graph
|
# the index of linear node in computation graph
|
||||||
node_index = 3
|
node_index = 3
|
||||||
# strategy number of linear node
|
# strategy number of linear node
|
||||||
strategy_number = 10
|
strategy_number = 24
|
||||||
# construct input args
|
# construct input args
|
||||||
input_args = [input]
|
input_args = [input]
|
||||||
# construct meta arg names
|
# construct meta arg names
|
||||||
|
@ -63,7 +63,7 @@ def check_linear_module_handler(rank, bias, world_size, port):
|
||||||
node_type='bias_module')
|
node_type='bias_module')
|
||||||
|
|
||||||
tracer = ColoTracer()
|
tracer = ColoTracer()
|
||||||
graph = tracer.trace(model, meta_args={"x": torch.rand(2, 2, 4, 16).to('meta')})
|
graph = tracer.trace(model, meta_args={"x": torch.rand(4, 4, 4, 16).to('meta')})
|
||||||
gm = ColoGraphModule(model, graph)
|
gm = ColoGraphModule(model, graph)
|
||||||
|
|
||||||
linear_mod_node = list(graph.nodes)[3]
|
linear_mod_node = list(graph.nodes)[3]
|
||||||
|
@ -81,9 +81,9 @@ def check_linear_module_handler(rank, bias, world_size, port):
|
||||||
assert op_data.data is not None
|
assert op_data.data is not None
|
||||||
|
|
||||||
assert mapping['input'].name == "x"
|
assert mapping['input'].name == "x"
|
||||||
assert mapping['input'].data.shape == torch.Size([2, 2, 4, 16])
|
assert mapping['input'].data.shape == torch.Size([4, 4, 4, 16])
|
||||||
assert mapping['input'].type == OperationDataType.ARG
|
assert mapping['input'].type == OperationDataType.ARG
|
||||||
assert mapping['input'].logical_shape == torch.Size([16, 16])
|
assert mapping['input'].logical_shape == torch.Size([64, 16])
|
||||||
|
|
||||||
assert mapping['other'].name == "linear_weight"
|
assert mapping['other'].name == "linear_weight"
|
||||||
assert mapping['other'].data.shape == torch.Size([32, 16])
|
assert mapping['other'].data.shape == torch.Size([32, 16])
|
||||||
|
@ -93,21 +93,27 @@ def check_linear_module_handler(rank, bias, world_size, port):
|
||||||
assert 'bias' not in mapping
|
assert 'bias' not in mapping
|
||||||
|
|
||||||
assert mapping['output'].name == "linear"
|
assert mapping['output'].name == "linear"
|
||||||
assert mapping['output'].data.shape == torch.Size([2, 2, 4, 32])
|
assert mapping['output'].data.shape == torch.Size([4, 4, 4, 32])
|
||||||
assert mapping['output'].type == OperationDataType.OUTPUT
|
assert mapping['output'].type == OperationDataType.OUTPUT
|
||||||
|
|
||||||
strategies_vector = handler.register_strategy(compute_resharding_cost=False)
|
strategies_vector = handler.register_strategy(compute_resharding_cost=False)
|
||||||
strategy_name_list = [val.name for val in strategies_vector]
|
strategy_name_list = [val.name for val in strategies_vector]
|
||||||
# one strategy will be converted to different physical sharding spec
|
|
||||||
assert len(strategy_name_list) > 8
|
|
||||||
|
|
||||||
# SS = SR x RS
|
# SS = SR x RS
|
||||||
assert 'S0S1 = S0R x RS1_0' in strategy_name_list
|
assert 'S0S1 = S0R x RS1_0' in strategy_name_list
|
||||||
|
assert 'S0S1 = S0R x RS1_1' in strategy_name_list
|
||||||
|
assert 'S0S1 = S0R x RS1_2' in strategy_name_list
|
||||||
assert 'S1S0 = S1R x RS0_0' in strategy_name_list
|
assert 'S1S0 = S1R x RS0_0' in strategy_name_list
|
||||||
|
assert 'S1S0 = S1R x RS0_1' in strategy_name_list
|
||||||
|
assert 'S1S0 = S1R x RS0_2' in strategy_name_list
|
||||||
|
|
||||||
# SR = SS x SR
|
# SR = SS x SR
|
||||||
assert 'S0R = S0S1 x S1R_0' in strategy_name_list
|
assert 'S0R = S0S1 x S1R_0' in strategy_name_list
|
||||||
|
assert 'S0R = S0S1 x S1R_1' in strategy_name_list
|
||||||
|
assert 'S0R = S0S1 x S1R_2' in strategy_name_list
|
||||||
assert 'S1R = S1S0 x S0R_0' in strategy_name_list
|
assert 'S1R = S1S0 x S0R_0' in strategy_name_list
|
||||||
|
assert 'S1R = S1S0 x S0R_1' in strategy_name_list
|
||||||
|
assert 'S1R = S1S0 x S0R_2' in strategy_name_list
|
||||||
|
|
||||||
# RS = RS x SS
|
# RS = RS x SS
|
||||||
assert 'RS0 = RS1 x S1S0' in strategy_name_list
|
assert 'RS0 = RS1 x S1S0' in strategy_name_list
|
||||||
|
@ -121,6 +127,20 @@ def check_linear_module_handler(rank, bias, world_size, port):
|
||||||
assert 'RS0 = RR x RS0' in strategy_name_list
|
assert 'RS0 = RR x RS0' in strategy_name_list
|
||||||
assert 'RS1 = RR x RS1' in strategy_name_list
|
assert 'RS1 = RR x RS1' in strategy_name_list
|
||||||
|
|
||||||
|
# S01R = S01R x RR
|
||||||
|
assert 'S01R = S01R x RR_0' in strategy_name_list
|
||||||
|
assert 'S01R = S01R x RR_1' in strategy_name_list
|
||||||
|
assert 'S01R = S01R x RR_2' in strategy_name_list
|
||||||
|
|
||||||
|
# RR = RS01 x S01R
|
||||||
|
assert 'RR = RS01 x S01R' in strategy_name_list
|
||||||
|
|
||||||
|
# RS01 = RR x RS01
|
||||||
|
assert 'RS01 = RR x RS01' in strategy_name_list
|
||||||
|
|
||||||
|
# RR = RR x RR
|
||||||
|
assert 'RR = RR x RR' in strategy_name_list
|
||||||
|
|
||||||
for strategy in strategies_vector:
|
for strategy in strategies_vector:
|
||||||
strategy: ShardingStrategy
|
strategy: ShardingStrategy
|
||||||
input_sharding_spec = strategy.get_sharding_spec_by_name('x')
|
input_sharding_spec = strategy.get_sharding_spec_by_name('x')
|
||||||
|
|
|
@ -33,11 +33,11 @@ def check_linear_module_handler(rank, bias, world_size, port):
|
||||||
physical_mesh_id = torch.arange(0, 4)
|
physical_mesh_id = torch.arange(0, 4)
|
||||||
mesh_shape = (2, 2)
|
mesh_shape = (2, 2)
|
||||||
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
|
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
|
||||||
input = torch.rand(2, 2, 4, 16).cuda()
|
input = torch.rand(4, 4, 4, 16).cuda()
|
||||||
# the index of linear node in computation graph
|
# the index of linear node in computation graph
|
||||||
node_index = 1
|
node_index = 1
|
||||||
# strategy number of linear node
|
# strategy number of linear node
|
||||||
strategy_number = 10
|
strategy_number = 24
|
||||||
# construct input args
|
# construct input args
|
||||||
input_args = [input]
|
input_args = [input]
|
||||||
# construct meta arg names
|
# construct meta arg names
|
||||||
|
@ -50,7 +50,7 @@ def check_linear_module_handler(rank, bias, world_size, port):
|
||||||
meta_arg_names=meta_arg_names)
|
meta_arg_names=meta_arg_names)
|
||||||
|
|
||||||
tracer = ColoTracer()
|
tracer = ColoTracer()
|
||||||
graph = tracer.trace(model, meta_args={"input": torch.rand(2, 2, 4, 16).to('meta')})
|
graph = tracer.trace(model, meta_args={"input": torch.rand(4, 4, 4, 16).to('meta')})
|
||||||
gm = ColoGraphModule(model, graph)
|
gm = ColoGraphModule(model, graph)
|
||||||
|
|
||||||
linear_mod_node = list(graph.nodes)[1]
|
linear_mod_node = list(graph.nodes)[1]
|
||||||
|
@ -69,9 +69,9 @@ def check_linear_module_handler(rank, bias, world_size, port):
|
||||||
assert op_data.data is not None
|
assert op_data.data is not None
|
||||||
|
|
||||||
assert mapping['input'].name == "input_1"
|
assert mapping['input'].name == "input_1"
|
||||||
assert mapping['input'].data.shape == torch.Size([2, 2, 4, 16])
|
assert mapping['input'].data.shape == torch.Size([4, 4, 4, 16])
|
||||||
assert mapping['input'].type == OperationDataType.ARG
|
assert mapping['input'].type == OperationDataType.ARG
|
||||||
assert mapping['input'].logical_shape == torch.Size([16, 16])
|
assert mapping['input'].logical_shape == torch.Size([64, 16])
|
||||||
|
|
||||||
assert mapping['other'].name == "weight"
|
assert mapping['other'].name == "weight"
|
||||||
assert mapping['other'].data.shape == torch.Size([32, 16])
|
assert mapping['other'].data.shape == torch.Size([32, 16])
|
||||||
|
@ -85,9 +85,9 @@ def check_linear_module_handler(rank, bias, world_size, port):
|
||||||
assert mapping['bias'].logical_shape == torch.Size([32])
|
assert mapping['bias'].logical_shape == torch.Size([32])
|
||||||
|
|
||||||
assert mapping['output'].name == "_0"
|
assert mapping['output'].name == "_0"
|
||||||
assert mapping['output'].data.shape == torch.Size([2, 2, 4, 32])
|
assert mapping['output'].data.shape == torch.Size([4, 4, 4, 32])
|
||||||
assert mapping['output'].type == OperationDataType.OUTPUT
|
assert mapping['output'].type == OperationDataType.OUTPUT
|
||||||
assert mapping['output'].logical_shape == torch.Size([16, 32])
|
assert mapping['output'].logical_shape == torch.Size([64, 32])
|
||||||
|
|
||||||
strategies_vector = handler.register_strategy(compute_resharding_cost=False)
|
strategies_vector = handler.register_strategy(compute_resharding_cost=False)
|
||||||
strategy_name_list = [val.name for val in strategies_vector]
|
strategy_name_list = [val.name for val in strategies_vector]
|
||||||
|
@ -96,11 +96,19 @@ def check_linear_module_handler(rank, bias, world_size, port):
|
||||||
|
|
||||||
# SS = SR x RS
|
# SS = SR x RS
|
||||||
assert 'S0S1 = S0R x RS1_0' in strategy_name_list
|
assert 'S0S1 = S0R x RS1_0' in strategy_name_list
|
||||||
|
assert 'S0S1 = S0R x RS1_1' in strategy_name_list
|
||||||
|
assert 'S0S1 = S0R x RS1_2' in strategy_name_list
|
||||||
assert 'S1S0 = S1R x RS0_0' in strategy_name_list
|
assert 'S1S0 = S1R x RS0_0' in strategy_name_list
|
||||||
|
assert 'S1S0 = S1R x RS0_1' in strategy_name_list
|
||||||
|
assert 'S1S0 = S1R x RS0_2' in strategy_name_list
|
||||||
|
|
||||||
# SR = SS x SR
|
# SR = SS x SR
|
||||||
assert 'S0R = S0S1 x S1R_0' in strategy_name_list
|
assert 'S0R = S0S1 x S1R_0' in strategy_name_list
|
||||||
|
assert 'S0R = S0S1 x S1R_1' in strategy_name_list
|
||||||
|
assert 'S0R = S0S1 x S1R_2' in strategy_name_list
|
||||||
assert 'S1R = S1S0 x S0R_0' in strategy_name_list
|
assert 'S1R = S1S0 x S0R_0' in strategy_name_list
|
||||||
|
assert 'S1R = S1S0 x S0R_1' in strategy_name_list
|
||||||
|
assert 'S1R = S1S0 x S0R_2' in strategy_name_list
|
||||||
|
|
||||||
# RS = RS x SS
|
# RS = RS x SS
|
||||||
assert 'RS0 = RS1 x S1S0' in strategy_name_list
|
assert 'RS0 = RS1 x S1S0' in strategy_name_list
|
||||||
|
@ -114,6 +122,20 @@ def check_linear_module_handler(rank, bias, world_size, port):
|
||||||
assert 'RS0 = RR x RS0' in strategy_name_list
|
assert 'RS0 = RR x RS0' in strategy_name_list
|
||||||
assert 'RS1 = RR x RS1' in strategy_name_list
|
assert 'RS1 = RR x RS1' in strategy_name_list
|
||||||
|
|
||||||
|
# S01R = S01R x RR
|
||||||
|
assert 'S01R = S01R x RR_0' in strategy_name_list
|
||||||
|
assert 'S01R = S01R x RR_1' in strategy_name_list
|
||||||
|
assert 'S01R = S01R x RR_2' in strategy_name_list
|
||||||
|
|
||||||
|
# RR = RS01 x S01R
|
||||||
|
assert 'RR = RS01 x S01R' in strategy_name_list
|
||||||
|
|
||||||
|
# RS01 = RR x RS01
|
||||||
|
assert 'RS01 = RR x RS01' in strategy_name_list
|
||||||
|
|
||||||
|
# RR = RR x RR
|
||||||
|
assert 'RR = RR x RR' in strategy_name_list
|
||||||
|
|
||||||
for strategy in strategies_vector:
|
for strategy in strategies_vector:
|
||||||
strategy: ShardingStrategy
|
strategy: ShardingStrategy
|
||||||
input_sharding_spec = strategy.get_sharding_spec_by_name('input_1')
|
input_sharding_spec = strategy.get_sharding_spec_by_name('input_1')
|
||||||
|
@ -150,12 +172,12 @@ def check_linear_function_handler(rank, bias, world_size, port):
|
||||||
mesh_shape = (2, 2)
|
mesh_shape = (2, 2)
|
||||||
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
|
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
|
||||||
|
|
||||||
input = torch.rand(2, 2, 4, 16).cuda()
|
input = torch.rand(4, 4, 4, 16).cuda()
|
||||||
other = torch.rand(32, 16).cuda()
|
other = torch.rand(32, 16).cuda()
|
||||||
# the index of linear node in computation graph
|
# the index of linear node in computation graph
|
||||||
node_index = 2
|
node_index = 2
|
||||||
# strategy number of linear node
|
# strategy number of linear node
|
||||||
strategy_number = 10
|
strategy_number = 24
|
||||||
# construct input args
|
# construct input args
|
||||||
input_args = [input, other]
|
input_args = [input, other]
|
||||||
# construct meta arg names
|
# construct meta arg names
|
||||||
|
@ -170,7 +192,7 @@ def check_linear_function_handler(rank, bias, world_size, port):
|
||||||
tracer = ColoTracer()
|
tracer = ColoTracer()
|
||||||
graph = tracer.trace(model,
|
graph = tracer.trace(model,
|
||||||
meta_args={
|
meta_args={
|
||||||
"input": torch.rand(2, 2, 4, 16).to('meta'),
|
"input": torch.rand(4, 4, 4, 16).to('meta'),
|
||||||
'others': torch.rand(32, 16).to('meta')
|
'others': torch.rand(32, 16).to('meta')
|
||||||
})
|
})
|
||||||
gm = ColoGraphModule(model, graph)
|
gm = ColoGraphModule(model, graph)
|
||||||
|
@ -187,9 +209,9 @@ def check_linear_function_handler(rank, bias, world_size, port):
|
||||||
mapping = handler.get_operation_data_mapping()
|
mapping = handler.get_operation_data_mapping()
|
||||||
|
|
||||||
assert mapping['input'].name == "input_1"
|
assert mapping['input'].name == "input_1"
|
||||||
assert mapping['input'].data.shape == torch.Size([2, 2, 4, 16])
|
assert mapping['input'].data.shape == torch.Size([4, 4, 4, 16])
|
||||||
assert mapping['input'].type == OperationDataType.ARG
|
assert mapping['input'].type == OperationDataType.ARG
|
||||||
assert mapping['input'].logical_shape == torch.Size([16, 16])
|
assert mapping['input'].logical_shape == torch.Size([64, 16])
|
||||||
|
|
||||||
assert mapping['other'].name == "others"
|
assert mapping['other'].name == "others"
|
||||||
assert mapping['other'].data.shape == torch.Size([32, 16])
|
assert mapping['other'].data.shape == torch.Size([32, 16])
|
||||||
|
@ -203,7 +225,7 @@ def check_linear_function_handler(rank, bias, world_size, port):
|
||||||
assert mapping['other'].logical_shape == torch.Size([16, 32])
|
assert mapping['other'].logical_shape == torch.Size([16, 32])
|
||||||
|
|
||||||
assert mapping['output'].name == "linear"
|
assert mapping['output'].name == "linear"
|
||||||
assert mapping['output'].data.shape == torch.Size([2, 2, 4, 32])
|
assert mapping['output'].data.shape == torch.Size([4, 4, 4, 32])
|
||||||
assert mapping['output'].type == OperationDataType.OUTPUT
|
assert mapping['output'].type == OperationDataType.OUTPUT
|
||||||
|
|
||||||
strategies_vector = handler.register_strategy(compute_resharding_cost=False)
|
strategies_vector = handler.register_strategy(compute_resharding_cost=False)
|
||||||
|
@ -213,11 +235,19 @@ def check_linear_function_handler(rank, bias, world_size, port):
|
||||||
|
|
||||||
# SS = SR x RS
|
# SS = SR x RS
|
||||||
assert 'S0S1 = S0R x RS1_0' in strategy_name_list
|
assert 'S0S1 = S0R x RS1_0' in strategy_name_list
|
||||||
|
assert 'S0S1 = S0R x RS1_1' in strategy_name_list
|
||||||
|
assert 'S0S1 = S0R x RS1_2' in strategy_name_list
|
||||||
assert 'S1S0 = S1R x RS0_0' in strategy_name_list
|
assert 'S1S0 = S1R x RS0_0' in strategy_name_list
|
||||||
|
assert 'S1S0 = S1R x RS0_1' in strategy_name_list
|
||||||
|
assert 'S1S0 = S1R x RS0_2' in strategy_name_list
|
||||||
|
|
||||||
# SR = SS x SR
|
# SR = SS x SR
|
||||||
assert 'S0R = S0S1 x S1R_0' in strategy_name_list
|
assert 'S0R = S0S1 x S1R_0' in strategy_name_list
|
||||||
|
assert 'S0R = S0S1 x S1R_1' in strategy_name_list
|
||||||
|
assert 'S0R = S0S1 x S1R_2' in strategy_name_list
|
||||||
assert 'S1R = S1S0 x S0R_0' in strategy_name_list
|
assert 'S1R = S1S0 x S0R_0' in strategy_name_list
|
||||||
|
assert 'S1R = S1S0 x S0R_1' in strategy_name_list
|
||||||
|
assert 'S1R = S1S0 x S0R_2' in strategy_name_list
|
||||||
|
|
||||||
# RS = RS x SS
|
# RS = RS x SS
|
||||||
assert 'RS0 = RS1 x S1S0' in strategy_name_list
|
assert 'RS0 = RS1 x S1S0' in strategy_name_list
|
||||||
|
@ -231,6 +261,20 @@ def check_linear_function_handler(rank, bias, world_size, port):
|
||||||
assert 'RS0 = RR x RS0' in strategy_name_list
|
assert 'RS0 = RR x RS0' in strategy_name_list
|
||||||
assert 'RS1 = RR x RS1' in strategy_name_list
|
assert 'RS1 = RR x RS1' in strategy_name_list
|
||||||
|
|
||||||
|
# S01R = S01R x RR
|
||||||
|
assert 'S01R = S01R x RR_0' in strategy_name_list
|
||||||
|
assert 'S01R = S01R x RR_1' in strategy_name_list
|
||||||
|
assert 'S01R = S01R x RR_2' in strategy_name_list
|
||||||
|
|
||||||
|
# RR = RS01 x S01R
|
||||||
|
assert 'RR = RS01 x S01R' in strategy_name_list
|
||||||
|
|
||||||
|
# RS01 = RR x RS01
|
||||||
|
assert 'RS01 = RR x RS01' in strategy_name_list
|
||||||
|
|
||||||
|
# RR = RR x RR
|
||||||
|
assert 'RR = RR x RR' in strategy_name_list
|
||||||
|
|
||||||
for strategy in strategies_vector:
|
for strategy in strategies_vector:
|
||||||
strategy: ShardingStrategy
|
strategy: ShardingStrategy
|
||||||
input_sharding_spec = strategy.get_sharding_spec_by_name('input_1')
|
input_sharding_spec = strategy.get_sharding_spec_by_name('input_1')
|
||||||
|
|
Loading…
Reference in New Issue