mirror of https://github.com/hpcaitech/ColossalAI
[autoparallel] add experimental view handler (#2011)
* [autoparallel] add experimental view handler * polish * polish * polish code * rename variablespull/2018/head
parent
d655eea515
commit
1438993113
|
@ -0,0 +1,4 @@
|
|||
from .view_generator import ViewGenerator
|
||||
from .view_handler import ViewHandler
|
||||
|
||||
__all__ = ['ViewGenerator', 'ViewHandler']
|
|
@ -0,0 +1,133 @@
|
|||
import copy
|
||||
from typing import List
|
||||
|
||||
from colossalai.auto_parallel.tensor_shard.node_handler.strategy.strategy_generator import FollowingStrategyGenerator
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
|
||||
CommAction,
|
||||
CommType,
|
||||
MemoryCost,
|
||||
ShardingStrategy,
|
||||
TrainCycleItem,
|
||||
)
|
||||
from colossalai.auto_parallel.tensor_shard.utils import (
|
||||
check_keep_sharding_status,
|
||||
detect_reshape_mapping,
|
||||
infer_output_dim_partition_dict,
|
||||
)
|
||||
from colossalai.tensor.shape_consistency import CollectiveCommPattern
|
||||
from colossalai.tensor.sharding_spec import ShardingSpec
|
||||
|
||||
__all__ = ['ViewGenerator']
|
||||
|
||||
|
||||
class ViewGenerator(FollowingStrategyGenerator):
|
||||
"""
|
||||
ViewGenerator which deals with the sharding strategies of view op.
|
||||
"""
|
||||
|
||||
def validate(self) -> bool:
|
||||
return super().validate()
|
||||
|
||||
def update_compute_cost(self, strategy: ShardingStrategy):
|
||||
compute_cost = TrainCycleItem(fwd=10, bwd=10, total=20)
|
||||
strategy.compute_cost = compute_cost
|
||||
|
||||
def update_memory_cost(self, strategy: ShardingStrategy):
|
||||
'''
|
||||
Compute the memory cost per device with this specific strategy.
|
||||
'''
|
||||
forward_size_mapping = {
|
||||
'input': self._compute_size_in_bytes(strategy, "input"),
|
||||
'output': self._compute_size_in_bytes(strategy, "output")
|
||||
}
|
||||
|
||||
backward_size_mapping = copy.deepcopy(forward_size_mapping)
|
||||
backward_size_mapping.pop("output")
|
||||
# compute fwd cost incurred
|
||||
# fwd_cost = input + output
|
||||
fwd_activation_cost = sum([v for k, v in forward_size_mapping.items() if not self.is_param(k)])
|
||||
fwd_parameter_cost = sum([v for k, v in forward_size_mapping.items() if self.is_param(k)])
|
||||
fwd_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=fwd_parameter_cost)
|
||||
|
||||
# compute bwd cost incurred
|
||||
# bwd_cost = input_grad
|
||||
bwd_activation_cost = sum([v for k, v in backward_size_mapping.items() if not self.is_param(k)])
|
||||
bwd_parameter_cost = sum([v for k, v in backward_size_mapping.items() if self.is_param(k)])
|
||||
bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost)
|
||||
|
||||
# compute total cost
|
||||
total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost,
|
||||
parameter=fwd_parameter_cost + bwd_parameter_cost)
|
||||
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
|
||||
strategy.memory_cost = memory_cost
|
||||
|
||||
def collate_strategies(self) -> List[ShardingStrategy]:
|
||||
strategy_list = []
|
||||
for index, strategy in enumerate(self.predecessor_node.strategies_vector):
|
||||
dim_partition_dict_mapping = {}
|
||||
communication_action_mapping = {}
|
||||
input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]]
|
||||
|
||||
origin_shape = self.op_data['input'].data.shape
|
||||
tgt_shape = self.op_data['tgt_shape'].data
|
||||
|
||||
reshape_mapping_dict = detect_reshape_mapping(origin_shape, tgt_shape)
|
||||
|
||||
dim_partition_dict_for_input = input_sharding_spec.dim_partition_dict
|
||||
keep_sharding_status = check_keep_sharding_status(dim_partition_dict_for_input, reshape_mapping_dict)
|
||||
|
||||
if keep_sharding_status:
|
||||
dim_partition_dict_for_output = infer_output_dim_partition_dict(dim_partition_dict_for_input,
|
||||
reshape_mapping_dict)
|
||||
else:
|
||||
dim_partition_dict_for_output = {}
|
||||
|
||||
dim_partition_dict_mapping = {
|
||||
"input": dim_partition_dict_for_input,
|
||||
"output": dim_partition_dict_for_output,
|
||||
}
|
||||
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
|
||||
|
||||
# add index into name to pass the duplicated check
|
||||
# we keep same strategies with different name for node merging, and it will not increase the searching space,
|
||||
# because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node.
|
||||
if keep_sharding_status:
|
||||
name = f'{sharding_spec_mapping["input"].sharding_sequence} -> {sharding_spec_mapping["output"].sharding_sequence}_{index}'
|
||||
else:
|
||||
name = f'{sharding_spec_mapping["input"].sharding_sequence} -> FULLY REPLICATED_{index}'
|
||||
|
||||
# add comm action for converting input to fully replicated
|
||||
total_mesh_dim_list = []
|
||||
for mesh_dim_list in dim_partition_dict_for_input.values():
|
||||
total_mesh_dim_list.extend(mesh_dim_list)
|
||||
# if there is only one sharding dimension, we should use the value instead of list as logical_process_axis.
|
||||
if len(total_mesh_dim_list) == 1:
|
||||
total_mesh_dim_list = total_mesh_dim_list[0]
|
||||
input_comm_action = self.get_communication_action(
|
||||
sharding_spec=sharding_spec_mapping["input"],
|
||||
communication_pattern=CollectiveCommPattern.GATHER_FWD_SPLIT_BWD,
|
||||
logical_process_axis=total_mesh_dim_list,
|
||||
comm_type=CommType.BEFORE,
|
||||
arg_index=0)
|
||||
input_comm_action.comm_spec.gather_dim = total_mesh_dim_list
|
||||
|
||||
elif len(total_mesh_dim_list) >= 2:
|
||||
source_spec = sharding_spec_mapping["input"]
|
||||
target_spec = ShardingSpec(device_mesh=self.device_mesh,
|
||||
entire_shape=source_spec.entire_shape,
|
||||
dim_partition_dict={})
|
||||
comm_spec = {'src_spec': source_spec, 'tgt_spec': target_spec}
|
||||
input_comm_action = CommAction(comm_spec=comm_spec, comm_type=CommType.BEFORE, arg_index=0)
|
||||
|
||||
else:
|
||||
input_comm_action = None
|
||||
|
||||
if input_comm_action is not None:
|
||||
communication_action_mapping["input"] = input_comm_action
|
||||
|
||||
strategy = self.get_sharding_strategy(name=name,
|
||||
sharding_spec_mapping=sharding_spec_mapping,
|
||||
communication_action_mapping=communication_action_mapping)
|
||||
strategy_list.append(strategy)
|
||||
|
||||
return strategy_list
|
|
@ -0,0 +1,51 @@
|
|||
from typing import Dict, List
|
||||
|
||||
import torch
|
||||
|
||||
from ...sharding_strategy import OperationData, OperationDataType
|
||||
from ..node_handler import NodeHandler
|
||||
from ..registry import operator_registry
|
||||
from ..strategy import StrategyGenerator
|
||||
from .view_generator import ViewGenerator
|
||||
|
||||
__all__ = ['ViewHandler']
|
||||
|
||||
|
||||
@operator_registry.register(torch.Tensor.view)
|
||||
class ViewHandler(NodeHandler):
|
||||
"""
|
||||
A ViewHandler which deals with the sharding strategies for Reshape Op, such as torch.reshape.
|
||||
"""
|
||||
|
||||
def get_strategy_generator(self) -> List[StrategyGenerator]:
|
||||
op_data_mapping = self.get_operation_data_mapping()
|
||||
generators = []
|
||||
generators.append(ViewGenerator(op_data_mapping, self.device_mesh, self.node.args[0]))
|
||||
return generators
|
||||
|
||||
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
|
||||
# use transposed shape for strategies
|
||||
# the strategies will be transformed back to its original shape in self.post_process
|
||||
|
||||
# check if the input operand is a parameter
|
||||
if isinstance(self.node.args[0]._meta_data, torch.nn.parameter.Parameter):
|
||||
data_type = OperationDataType.PARAM
|
||||
else:
|
||||
data_type = OperationDataType.ARG
|
||||
|
||||
input_data = self.node.args[0]._meta_data
|
||||
physical_input_operand = OperationData(name=str(self.node.args[0]), type=data_type, data=input_data)
|
||||
|
||||
target_shape = self.node._meta_data.shape
|
||||
physical_shape_operand = OperationData(name='tgt_shape', type=OperationDataType.ARG, data=target_shape)
|
||||
|
||||
output_data = self.node._meta_data
|
||||
physical_output_operand = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=output_data)
|
||||
|
||||
mapping = {
|
||||
"input": physical_input_operand,
|
||||
"tgt_shape": physical_shape_operand,
|
||||
"output": physical_output_operand
|
||||
}
|
||||
|
||||
return mapping
|
|
@ -51,6 +51,8 @@ class OperationData:
|
|||
"""
|
||||
if isinstance(data, torch.Tensor):
|
||||
return data.shape
|
||||
elif isinstance(data, torch.Size):
|
||||
return None
|
||||
elif isinstance(data, (tuple, list)):
|
||||
data_type = type(data)
|
||||
return data_type([_infer_logical_shape(d) for d in data])
|
||||
|
|
|
@ -82,7 +82,6 @@ class StrategiesConstructor:
|
|||
for node in self.nodes:
|
||||
strategies_vector = StrategiesVector(node)
|
||||
|
||||
print(node)
|
||||
if _check_no_strategy_for_node(node):
|
||||
no_strategy_node.append(node)
|
||||
pass
|
||||
|
|
|
@ -7,6 +7,7 @@ from .broadcast import (
|
|||
)
|
||||
from .factory import generate_resharding_costs, generate_sharding_spec
|
||||
from .misc import check_sharding_spec_validity, ignore_sharding_exception, pytree_map
|
||||
from .reshape import check_keep_sharding_status, detect_reshape_mapping, infer_output_dim_partition_dict
|
||||
from .sharding import (
|
||||
enumerate_all_possible_1d_sharding,
|
||||
enumerate_all_possible_2d_sharding,
|
||||
|
@ -19,5 +20,6 @@ __all__ = [
|
|||
'BroadcastType', 'get_broadcast_shape', 'is_broadcastable', 'recover_sharding_spec_for_broadcast_shape',
|
||||
'generate_resharding_costs', 'generate_sharding_spec', 'ignore_sharding_exception', 'check_sharding_spec_validity'
|
||||
'transpose_partition_dim', 'update_partition_dim', 'enumerate_all_possible_1d_sharding',
|
||||
'enumerate_all_possible_2d_sharding', 'generate_sharding_size', 'comm_actions_for_oprands', 'pytree_map'
|
||||
'enumerate_all_possible_2d_sharding', 'generate_sharding_size', 'comm_actions_for_oprands', 'pytree_map',
|
||||
'detect_reshape_mapping', 'check_keep_sharding_status', 'infer_output_dim_partition_dict'
|
||||
]
|
||||
|
|
|
@ -0,0 +1,168 @@
|
|||
from enum import Enum
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class PreviousStatus(Enum):
|
||||
"""
|
||||
This class shows the status of previous comparision.
|
||||
"""
|
||||
RESET = 0
|
||||
# ORIGIN means the dimension size of original tensor is larger in the previous comparision.
|
||||
ORIGIN = 1
|
||||
# TGT means the dimension size of target tensor is larger in the previous comparision.
|
||||
TGT = 2
|
||||
|
||||
|
||||
def detect_reshape_mapping(origin_shape: torch.Size, tgt_shape: torch.Size) -> Dict[Tuple[int], Tuple[int]]:
|
||||
"""
|
||||
This method is used to detect the reshape mapping between original tensor and target tensor.
|
||||
|
||||
Returns:
|
||||
reshape_mapping_dict: The dictionary shows how a tuple of origin dims(keys) mapping to the related
|
||||
target dims(values) during reshaping operation.
|
||||
Examples:
|
||||
import torch
|
||||
origin_shape = torch.Size([4, 4, 4])
|
||||
tgt_shape = torch.Size([2, 8, 2, 2])
|
||||
reshape_mapping_dict = detect_reshape_mapping(origin_shape, tgt_shape)
|
||||
print(reshape_mapping_dict)
|
||||
Output:
|
||||
{(2,): (3, 2), (1, 0): (1,), (0,): (0, 1)}
|
||||
"""
|
||||
|
||||
# reverse the shape object
|
||||
origin_shape = list(origin_shape)
|
||||
tgt_shape = list(tgt_shape)
|
||||
origin_shape.reverse()
|
||||
tgt_shape.reverse()
|
||||
|
||||
# initialize arguments
|
||||
reshape_mapping_dict = {}
|
||||
origin_len = len(origin_shape)
|
||||
tgt_len = len(tgt_shape)
|
||||
origin_index = 0
|
||||
tgt_index = 0
|
||||
original_dimension_size = origin_shape[origin_index]
|
||||
tgt_dimension_size = tgt_shape[tgt_index]
|
||||
tgt_dims = [tgt_len - tgt_index - 1]
|
||||
origin_dims = [origin_len - origin_index - 1]
|
||||
previous_label = PreviousStatus.RESET
|
||||
|
||||
while origin_index != len(origin_shape) or tgt_index != len(tgt_shape):
|
||||
if original_dimension_size == tgt_dimension_size:
|
||||
reshape_mapping_dict[tuple(origin_dims)] = tuple(tgt_dims)
|
||||
origin_index += 1
|
||||
tgt_index += 1
|
||||
# the last step of loop should always end with condition
|
||||
# so we need to manually skip the preparation for next step
|
||||
# in the last step.
|
||||
if origin_index == len(origin_shape):
|
||||
continue
|
||||
original_dimension_size = origin_shape[origin_index]
|
||||
tgt_dimension_size = tgt_shape[tgt_index]
|
||||
origin_dims = [origin_len - origin_index - 1]
|
||||
tgt_dims = [tgt_len - tgt_index - 1]
|
||||
previous_label = PreviousStatus.RESET
|
||||
|
||||
elif original_dimension_size > tgt_dimension_size:
|
||||
tgt_index += 1
|
||||
|
||||
if previous_label == PreviousStatus.TGT:
|
||||
# if the target dimension size is larger in the previous comparision, which means
|
||||
# the origin dimension size has already accumulated larger than target dimension size, so
|
||||
# we need to offload the origin dims and tgt dims into the reshape_mapping_dict.
|
||||
reshape_mapping_dict[tuple(origin_dims)] = tuple(tgt_dims)
|
||||
original_dimension_size = original_dimension_size // tgt_dimension_size
|
||||
origin_dims = [origin_len - origin_index - 1]
|
||||
tgt_dimension_size = tgt_shape[tgt_index]
|
||||
tgt_dims = [tgt_len - tgt_index - 1, tgt_len - tgt_index]
|
||||
# reset the previous_label after offloading the origin dims and tgt dims
|
||||
previous_label = PreviousStatus.RESET
|
||||
else:
|
||||
# accumulate the tgt_dimension_size until tgt_dimension_size larger than original_dimension_size
|
||||
tgt_dimension_size *= tgt_shape[tgt_index]
|
||||
tgt_dims.append(tgt_len - tgt_index - 1)
|
||||
previous_label = PreviousStatus.ORIGIN
|
||||
|
||||
else:
|
||||
origin_index += 1
|
||||
|
||||
if previous_label == PreviousStatus.ORIGIN:
|
||||
# if the origin element is larger in the previous comparision, which means
|
||||
# the target element has already accumulated larger than origin element, so
|
||||
# we need to offload the origin dims and tgt dims into the reshape_mapping_dict.
|
||||
reshape_mapping_dict[tuple(origin_dims)] = tuple(tgt_dims)
|
||||
tgt_dimension_size = tgt_dimension_size // original_dimension_size
|
||||
tgt_dims = [tgt_len - tgt_index - 1]
|
||||
original_dimension_size = origin_shape[origin_index]
|
||||
origin_dims = [origin_len - origin_index - 1, origin_len - origin_index]
|
||||
# reset the previous_label after offloading the origin dims and tgt dims
|
||||
previous_label = PreviousStatus.RESET
|
||||
else:
|
||||
# accumulate the original_dimension_size until original_dimension_size larger than tgt_dimension_size
|
||||
original_dimension_size *= origin_shape[origin_index]
|
||||
origin_dims.append(origin_len - origin_index - 1)
|
||||
previous_label = PreviousStatus.TGT
|
||||
|
||||
return reshape_mapping_dict
|
||||
|
||||
|
||||
def check_keep_sharding_status(input_dim_partition_dict: Dict[int, List[int]],
|
||||
reshape_mapping_dict: Dict[Tuple[int], Tuple[int]]) -> bool:
|
||||
"""
|
||||
This method is used to check whether the reshape operation could implement without converting
|
||||
the input to fully replicated status.
|
||||
|
||||
Rule:
|
||||
For a sharded dimension of input tensor, if it is not the minimum element of the input tuple,
|
||||
the function will return false.
|
||||
To illustrate this issue, there are two cases to analyse:
|
||||
1. no sharded dims in the input tuple: we could do the reshape operation safely just as the normal
|
||||
operation without distributed tensor.
|
||||
2. sharded dims in the input tuple: the sharded dim must be the minimum element, then during shape
|
||||
consistency process, torch.cat will be implemented on the sharded dim, and everything after the sharded
|
||||
dim get recovered.
|
||||
|
||||
Examples:
|
||||
# the second dimension of the input has been sharded.
|
||||
input_dim_partition_dict = {1: [1]}
|
||||
origin_shape = torch.Size([8, 4, 2])
|
||||
tgt_shape = torch.Size([2, 4, 8])
|
||||
reshape_mapping_dict = detect_reshape_mapping(origin_shape, tgt_shape)
|
||||
# {(2, 1): (2,), (0,): (1, 0)}
|
||||
# the sharded dim of input is 1, which is the minimum element of the tuple (2, 1),
|
||||
# so we do not have to convert the input to fully replicated status.
|
||||
print(check_keep_sharding_status(input_dim_partition_dict, reshape_mapping_dict))
|
||||
|
||||
Output:
|
||||
True
|
||||
"""
|
||||
sharded_dims = list(input_dim_partition_dict.keys())
|
||||
for input_dims in reshape_mapping_dict.keys():
|
||||
min_element = min(input_dims)
|
||||
for dim in input_dims:
|
||||
if dim in sharded_dims and dim is not min_element:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def infer_output_dim_partition_dict(input_dim_partition_dict: Dict[int, List[int]],
|
||||
reshape_mapping_dict: Dict[Tuple[int], Tuple[int]]) -> Dict[Tuple[int], Tuple[int]]:
|
||||
"""
|
||||
This method is used to infer the output dim partition dict for a reshape operation,
|
||||
given the input dim partition dict and reshape mapping dict.
|
||||
"""
|
||||
assert check_keep_sharding_status(input_dim_partition_dict, reshape_mapping_dict), \
|
||||
'we only infer output dim partition dict for the reshape operation could keep sharding spec.'
|
||||
sharded_dims = list(input_dim_partition_dict.keys())
|
||||
output_dim_partition_dict = {}
|
||||
for input_dims, output_dims in reshape_mapping_dict.items():
|
||||
for dim in input_dims:
|
||||
if dim in sharded_dims:
|
||||
output_dim_partition_dict[min(output_dims)] = input_dim_partition_dict[dim]
|
||||
# we could break because input dims cannot contain two sharded dims, otherwise
|
||||
# the keep sharding status check will fail.
|
||||
break
|
||||
return output_dim_partition_dict
|
|
@ -0,0 +1,98 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler
|
||||
from colossalai.auto_parallel.tensor_shard.node_handler.experimental import ViewHandler
|
||||
from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.fx import ColoGraphModule, ColoTracer
|
||||
from colossalai.testing.pytest_wrapper import run_on_environment_flag
|
||||
|
||||
|
||||
class ViewModel(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, input, other):
|
||||
conv_node = nn.functional.conv2d(input, other)
|
||||
reshape_node = conv_node.view(32, 4, 32, 32, 4)
|
||||
return reshape_node
|
||||
|
||||
|
||||
def test_view_handler():
|
||||
model = ViewModel()
|
||||
tracer = ColoTracer()
|
||||
# graph():
|
||||
# %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
|
||||
# %other : torch.Tensor [#users=1] = placeholder[target=other]
|
||||
# %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%input_1, %other), kwargs = {})
|
||||
# %view : [#users=1] = call_method[target=view](args = (%conv2d, 2, -1), kwargs = {})
|
||||
# return view
|
||||
graph = tracer.trace(model,
|
||||
meta_args={
|
||||
"input": torch.rand(8, 8, 66, 66).to('meta'),
|
||||
"other": torch.rand(16, 8, 3, 3).to('meta'),
|
||||
})
|
||||
gm = ColoGraphModule(model, graph)
|
||||
physical_mesh_id = torch.arange(0, 4)
|
||||
|
||||
mesh_shape = (2, 2)
|
||||
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
|
||||
conv_mod_node = list(graph.nodes)[2]
|
||||
view_node = list(graph.nodes)[3]
|
||||
view_strategies_vector = StrategiesVector(view_node)
|
||||
conv_strategies_vector = StrategiesVector(conv_mod_node)
|
||||
|
||||
# build handler
|
||||
conv_handler = ConvFunctionHandler(node=conv_mod_node,
|
||||
device_mesh=device_mesh,
|
||||
strategies_vector=conv_strategies_vector)
|
||||
conv_handler.register_strategy(compute_resharding_cost=False)
|
||||
setattr(conv_mod_node, 'strategies_vector', conv_strategies_vector)
|
||||
view_handler = ViewHandler(node=view_node, device_mesh=device_mesh, strategies_vector=view_strategies_vector)
|
||||
|
||||
view_handler.register_strategy(compute_resharding_cost=False)
|
||||
|
||||
# check operation data mapping
|
||||
mapping = view_handler.get_operation_data_mapping()
|
||||
|
||||
for name, op_data in mapping.items():
|
||||
op_data: OperationData
|
||||
# make sure they have valid values
|
||||
assert op_data.data is not None
|
||||
|
||||
assert mapping['input'].name == "conv2d"
|
||||
assert mapping['input'].data.is_meta
|
||||
assert mapping['input'].data.shape == torch.Size([8, 16, 64, 64])
|
||||
assert mapping['input'].type == OperationDataType.ARG
|
||||
assert mapping['input'].logical_shape == torch.Size([8, 16, 64, 64])
|
||||
|
||||
assert mapping['output'].name == "view"
|
||||
assert mapping['output'].data.is_meta
|
||||
assert mapping['output'].data.shape == torch.Size([32, 4, 32, 32, 4])
|
||||
assert mapping['output'].type == OperationDataType.OUTPUT
|
||||
|
||||
# reshape handler is a following strategy handler, so the number of strategies is equal to the predecessor node.
|
||||
assert len(view_strategies_vector) == len(conv_strategies_vector)
|
||||
strategy_name_list = [strategy.name for strategy in view_strategies_vector]
|
||||
assert '[S0, S1, R, R] -> FULLY REPLICATED_0' in strategy_name_list
|
||||
assert '[S1, S0, R, R] -> FULLY REPLICATED_1' in strategy_name_list
|
||||
assert '[S0, R, R, R] -> [S0, R, R, R, R]_2' in strategy_name_list
|
||||
assert '[S1, R, R, R] -> [S1, R, R, R, R]_3' in strategy_name_list
|
||||
assert '[S0, R, R, R] -> [S0, R, R, R, R]_4' in strategy_name_list
|
||||
assert '[S1, R, R, R] -> [S1, R, R, R, R]_5' in strategy_name_list
|
||||
assert '[R, S1, R, R] -> FULLY REPLICATED_6' in strategy_name_list
|
||||
assert '[R, S0, R, R] -> FULLY REPLICATED_7' in strategy_name_list
|
||||
assert '[R, R, R, R] -> [R, R, R, R, R]_8' in strategy_name_list
|
||||
assert '[R, R, R, R] -> [R, R, R, R, R]_9' in strategy_name_list
|
||||
assert '[R, S0, R, R] -> FULLY REPLICATED_10' in strategy_name_list
|
||||
assert '[R, S1, R, R] -> FULLY REPLICATED_11' in strategy_name_list
|
||||
assert '[R, R, R, R] -> [R, R, R, R, R]_12' in strategy_name_list
|
||||
assert '[S01, R, R, R] -> [S01, R, R, R, R]_13' in strategy_name_list
|
||||
assert '[R, R, R, R] -> [R, R, R, R, R]_14' in strategy_name_list
|
||||
assert '[R, S01, R, R] -> FULLY REPLICATED_15' in strategy_name_list
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_view_handler()
|
Loading…
Reference in New Issue