From 14389931138d1397ca3e070f6b0cee6d685a2b6d Mon Sep 17 00:00:00 2001 From: YuliangLiu0306 <72588413+YuliangLiu0306@users.noreply.github.com> Date: Thu, 24 Nov 2022 11:34:41 +0800 Subject: [PATCH] [autoparallel] add experimental view handler (#2011) * [autoparallel] add experimental view handler * polish * polish * polish code * rename variables --- .../node_handler/experimental/__init__.py | 4 + .../experimental/view_generator.py | 133 ++++++++++++++ .../node_handler/experimental/view_handler.py | 51 ++++++ .../tensor_shard/sharding_strategy.py | 2 + .../solver/strategies_constructor.py | 1 - .../tensor_shard/utils/__init__.py | 4 +- .../tensor_shard/utils/reshape.py | 168 ++++++++++++++++++ .../test_node_handler/test_view_handler.py | 98 ++++++++++ 8 files changed, 459 insertions(+), 2 deletions(-) create mode 100644 colossalai/auto_parallel/tensor_shard/node_handler/experimental/__init__.py create mode 100644 colossalai/auto_parallel/tensor_shard/node_handler/experimental/view_generator.py create mode 100644 colossalai/auto_parallel/tensor_shard/node_handler/experimental/view_handler.py create mode 100644 colossalai/auto_parallel/tensor_shard/utils/reshape.py create mode 100644 tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_view_handler.py diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/experimental/__init__.py b/colossalai/auto_parallel/tensor_shard/node_handler/experimental/__init__.py new file mode 100644 index 000000000..7f644c0e1 --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/node_handler/experimental/__init__.py @@ -0,0 +1,4 @@ +from .view_generator import ViewGenerator +from .view_handler import ViewHandler + +__all__ = ['ViewGenerator', 'ViewHandler'] diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/experimental/view_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/experimental/view_generator.py new file mode 100644 index 000000000..cdfa8b4eb --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/node_handler/experimental/view_generator.py @@ -0,0 +1,133 @@ +import copy +from typing import List + +from colossalai.auto_parallel.tensor_shard.node_handler.strategy.strategy_generator import FollowingStrategyGenerator +from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( + CommAction, + CommType, + MemoryCost, + ShardingStrategy, + TrainCycleItem, +) +from colossalai.auto_parallel.tensor_shard.utils import ( + check_keep_sharding_status, + detect_reshape_mapping, + infer_output_dim_partition_dict, +) +from colossalai.tensor.shape_consistency import CollectiveCommPattern +from colossalai.tensor.sharding_spec import ShardingSpec + +__all__ = ['ViewGenerator'] + + +class ViewGenerator(FollowingStrategyGenerator): + """ + ViewGenerator which deals with the sharding strategies of view op. + """ + + def validate(self) -> bool: + return super().validate() + + def update_compute_cost(self, strategy: ShardingStrategy): + compute_cost = TrainCycleItem(fwd=10, bwd=10, total=20) + strategy.compute_cost = compute_cost + + def update_memory_cost(self, strategy: ShardingStrategy): + ''' + Compute the memory cost per device with this specific strategy. + ''' + forward_size_mapping = { + 'input': self._compute_size_in_bytes(strategy, "input"), + 'output': self._compute_size_in_bytes(strategy, "output") + } + + backward_size_mapping = copy.deepcopy(forward_size_mapping) + backward_size_mapping.pop("output") + # compute fwd cost incurred + # fwd_cost = input + output + fwd_activation_cost = sum([v for k, v in forward_size_mapping.items() if not self.is_param(k)]) + fwd_parameter_cost = sum([v for k, v in forward_size_mapping.items() if self.is_param(k)]) + fwd_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=fwd_parameter_cost) + + # compute bwd cost incurred + # bwd_cost = input_grad + bwd_activation_cost = sum([v for k, v in backward_size_mapping.items() if not self.is_param(k)]) + bwd_parameter_cost = sum([v for k, v in backward_size_mapping.items() if self.is_param(k)]) + bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost) + + # compute total cost + total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost, + parameter=fwd_parameter_cost + bwd_parameter_cost) + memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) + strategy.memory_cost = memory_cost + + def collate_strategies(self) -> List[ShardingStrategy]: + strategy_list = [] + for index, strategy in enumerate(self.predecessor_node.strategies_vector): + dim_partition_dict_mapping = {} + communication_action_mapping = {} + input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]] + + origin_shape = self.op_data['input'].data.shape + tgt_shape = self.op_data['tgt_shape'].data + + reshape_mapping_dict = detect_reshape_mapping(origin_shape, tgt_shape) + + dim_partition_dict_for_input = input_sharding_spec.dim_partition_dict + keep_sharding_status = check_keep_sharding_status(dim_partition_dict_for_input, reshape_mapping_dict) + + if keep_sharding_status: + dim_partition_dict_for_output = infer_output_dim_partition_dict(dim_partition_dict_for_input, + reshape_mapping_dict) + else: + dim_partition_dict_for_output = {} + + dim_partition_dict_mapping = { + "input": dim_partition_dict_for_input, + "output": dim_partition_dict_for_output, + } + sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) + + # add index into name to pass the duplicated check + # we keep same strategies with different name for node merging, and it will not increase the searching space, + # because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node. + if keep_sharding_status: + name = f'{sharding_spec_mapping["input"].sharding_sequence} -> {sharding_spec_mapping["output"].sharding_sequence}_{index}' + else: + name = f'{sharding_spec_mapping["input"].sharding_sequence} -> FULLY REPLICATED_{index}' + + # add comm action for converting input to fully replicated + total_mesh_dim_list = [] + for mesh_dim_list in dim_partition_dict_for_input.values(): + total_mesh_dim_list.extend(mesh_dim_list) + # if there is only one sharding dimension, we should use the value instead of list as logical_process_axis. + if len(total_mesh_dim_list) == 1: + total_mesh_dim_list = total_mesh_dim_list[0] + input_comm_action = self.get_communication_action( + sharding_spec=sharding_spec_mapping["input"], + communication_pattern=CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, + logical_process_axis=total_mesh_dim_list, + comm_type=CommType.BEFORE, + arg_index=0) + input_comm_action.comm_spec.gather_dim = total_mesh_dim_list + + elif len(total_mesh_dim_list) >= 2: + source_spec = sharding_spec_mapping["input"] + target_spec = ShardingSpec(device_mesh=self.device_mesh, + entire_shape=source_spec.entire_shape, + dim_partition_dict={}) + comm_spec = {'src_spec': source_spec, 'tgt_spec': target_spec} + input_comm_action = CommAction(comm_spec=comm_spec, comm_type=CommType.BEFORE, arg_index=0) + + else: + input_comm_action = None + + if input_comm_action is not None: + communication_action_mapping["input"] = input_comm_action + + strategy = self.get_sharding_strategy(name=name, + sharding_spec_mapping=sharding_spec_mapping, + communication_action_mapping=communication_action_mapping) + strategy_list.append(strategy) + + return strategy_list diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/experimental/view_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/experimental/view_handler.py new file mode 100644 index 000000000..bab4e0d76 --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/node_handler/experimental/view_handler.py @@ -0,0 +1,51 @@ +from typing import Dict, List + +import torch + +from ...sharding_strategy import OperationData, OperationDataType +from ..node_handler import NodeHandler +from ..registry import operator_registry +from ..strategy import StrategyGenerator +from .view_generator import ViewGenerator + +__all__ = ['ViewHandler'] + + +@operator_registry.register(torch.Tensor.view) +class ViewHandler(NodeHandler): + """ + A ViewHandler which deals with the sharding strategies for Reshape Op, such as torch.reshape. + """ + + def get_strategy_generator(self) -> List[StrategyGenerator]: + op_data_mapping = self.get_operation_data_mapping() + generators = [] + generators.append(ViewGenerator(op_data_mapping, self.device_mesh, self.node.args[0])) + return generators + + def get_operation_data_mapping(self) -> Dict[str, OperationData]: + # use transposed shape for strategies + # the strategies will be transformed back to its original shape in self.post_process + + # check if the input operand is a parameter + if isinstance(self.node.args[0]._meta_data, torch.nn.parameter.Parameter): + data_type = OperationDataType.PARAM + else: + data_type = OperationDataType.ARG + + input_data = self.node.args[0]._meta_data + physical_input_operand = OperationData(name=str(self.node.args[0]), type=data_type, data=input_data) + + target_shape = self.node._meta_data.shape + physical_shape_operand = OperationData(name='tgt_shape', type=OperationDataType.ARG, data=target_shape) + + output_data = self.node._meta_data + physical_output_operand = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=output_data) + + mapping = { + "input": physical_input_operand, + "tgt_shape": physical_shape_operand, + "output": physical_output_operand + } + + return mapping diff --git a/colossalai/auto_parallel/tensor_shard/sharding_strategy.py b/colossalai/auto_parallel/tensor_shard/sharding_strategy.py index efe484917..bbf4215d9 100644 --- a/colossalai/auto_parallel/tensor_shard/sharding_strategy.py +++ b/colossalai/auto_parallel/tensor_shard/sharding_strategy.py @@ -51,6 +51,8 @@ class OperationData: """ if isinstance(data, torch.Tensor): return data.shape + elif isinstance(data, torch.Size): + return None elif isinstance(data, (tuple, list)): data_type = type(data) return data_type([_infer_logical_shape(d) for d in data]) diff --git a/colossalai/auto_parallel/tensor_shard/solver/strategies_constructor.py b/colossalai/auto_parallel/tensor_shard/solver/strategies_constructor.py index 6342feeee..adfd03d7d 100644 --- a/colossalai/auto_parallel/tensor_shard/solver/strategies_constructor.py +++ b/colossalai/auto_parallel/tensor_shard/solver/strategies_constructor.py @@ -82,7 +82,6 @@ class StrategiesConstructor: for node in self.nodes: strategies_vector = StrategiesVector(node) - print(node) if _check_no_strategy_for_node(node): no_strategy_node.append(node) pass diff --git a/colossalai/auto_parallel/tensor_shard/utils/__init__.py b/colossalai/auto_parallel/tensor_shard/utils/__init__.py index 63c48195d..b7fe5430b 100644 --- a/colossalai/auto_parallel/tensor_shard/utils/__init__.py +++ b/colossalai/auto_parallel/tensor_shard/utils/__init__.py @@ -7,6 +7,7 @@ from .broadcast import ( ) from .factory import generate_resharding_costs, generate_sharding_spec from .misc import check_sharding_spec_validity, ignore_sharding_exception, pytree_map +from .reshape import check_keep_sharding_status, detect_reshape_mapping, infer_output_dim_partition_dict from .sharding import ( enumerate_all_possible_1d_sharding, enumerate_all_possible_2d_sharding, @@ -19,5 +20,6 @@ __all__ = [ 'BroadcastType', 'get_broadcast_shape', 'is_broadcastable', 'recover_sharding_spec_for_broadcast_shape', 'generate_resharding_costs', 'generate_sharding_spec', 'ignore_sharding_exception', 'check_sharding_spec_validity' 'transpose_partition_dim', 'update_partition_dim', 'enumerate_all_possible_1d_sharding', - 'enumerate_all_possible_2d_sharding', 'generate_sharding_size', 'comm_actions_for_oprands', 'pytree_map' + 'enumerate_all_possible_2d_sharding', 'generate_sharding_size', 'comm_actions_for_oprands', 'pytree_map', + 'detect_reshape_mapping', 'check_keep_sharding_status', 'infer_output_dim_partition_dict' ] diff --git a/colossalai/auto_parallel/tensor_shard/utils/reshape.py b/colossalai/auto_parallel/tensor_shard/utils/reshape.py new file mode 100644 index 000000000..8e02544f7 --- /dev/null +++ b/colossalai/auto_parallel/tensor_shard/utils/reshape.py @@ -0,0 +1,168 @@ +from enum import Enum +from typing import Dict, List, Tuple + +import torch + + +class PreviousStatus(Enum): + """ + This class shows the status of previous comparision. + """ + RESET = 0 + # ORIGIN means the dimension size of original tensor is larger in the previous comparision. + ORIGIN = 1 + # TGT means the dimension size of target tensor is larger in the previous comparision. + TGT = 2 + + +def detect_reshape_mapping(origin_shape: torch.Size, tgt_shape: torch.Size) -> Dict[Tuple[int], Tuple[int]]: + """ + This method is used to detect the reshape mapping between original tensor and target tensor. + + Returns: + reshape_mapping_dict: The dictionary shows how a tuple of origin dims(keys) mapping to the related + target dims(values) during reshaping operation. + Examples: + import torch + origin_shape = torch.Size([4, 4, 4]) + tgt_shape = torch.Size([2, 8, 2, 2]) + reshape_mapping_dict = detect_reshape_mapping(origin_shape, tgt_shape) + print(reshape_mapping_dict) + Output: + {(2,): (3, 2), (1, 0): (1,), (0,): (0, 1)} + """ + + # reverse the shape object + origin_shape = list(origin_shape) + tgt_shape = list(tgt_shape) + origin_shape.reverse() + tgt_shape.reverse() + + # initialize arguments + reshape_mapping_dict = {} + origin_len = len(origin_shape) + tgt_len = len(tgt_shape) + origin_index = 0 + tgt_index = 0 + original_dimension_size = origin_shape[origin_index] + tgt_dimension_size = tgt_shape[tgt_index] + tgt_dims = [tgt_len - tgt_index - 1] + origin_dims = [origin_len - origin_index - 1] + previous_label = PreviousStatus.RESET + + while origin_index != len(origin_shape) or tgt_index != len(tgt_shape): + if original_dimension_size == tgt_dimension_size: + reshape_mapping_dict[tuple(origin_dims)] = tuple(tgt_dims) + origin_index += 1 + tgt_index += 1 + # the last step of loop should always end with condition + # so we need to manually skip the preparation for next step + # in the last step. + if origin_index == len(origin_shape): + continue + original_dimension_size = origin_shape[origin_index] + tgt_dimension_size = tgt_shape[tgt_index] + origin_dims = [origin_len - origin_index - 1] + tgt_dims = [tgt_len - tgt_index - 1] + previous_label = PreviousStatus.RESET + + elif original_dimension_size > tgt_dimension_size: + tgt_index += 1 + + if previous_label == PreviousStatus.TGT: + # if the target dimension size is larger in the previous comparision, which means + # the origin dimension size has already accumulated larger than target dimension size, so + # we need to offload the origin dims and tgt dims into the reshape_mapping_dict. + reshape_mapping_dict[tuple(origin_dims)] = tuple(tgt_dims) + original_dimension_size = original_dimension_size // tgt_dimension_size + origin_dims = [origin_len - origin_index - 1] + tgt_dimension_size = tgt_shape[tgt_index] + tgt_dims = [tgt_len - tgt_index - 1, tgt_len - tgt_index] + # reset the previous_label after offloading the origin dims and tgt dims + previous_label = PreviousStatus.RESET + else: + # accumulate the tgt_dimension_size until tgt_dimension_size larger than original_dimension_size + tgt_dimension_size *= tgt_shape[tgt_index] + tgt_dims.append(tgt_len - tgt_index - 1) + previous_label = PreviousStatus.ORIGIN + + else: + origin_index += 1 + + if previous_label == PreviousStatus.ORIGIN: + # if the origin element is larger in the previous comparision, which means + # the target element has already accumulated larger than origin element, so + # we need to offload the origin dims and tgt dims into the reshape_mapping_dict. + reshape_mapping_dict[tuple(origin_dims)] = tuple(tgt_dims) + tgt_dimension_size = tgt_dimension_size // original_dimension_size + tgt_dims = [tgt_len - tgt_index - 1] + original_dimension_size = origin_shape[origin_index] + origin_dims = [origin_len - origin_index - 1, origin_len - origin_index] + # reset the previous_label after offloading the origin dims and tgt dims + previous_label = PreviousStatus.RESET + else: + # accumulate the original_dimension_size until original_dimension_size larger than tgt_dimension_size + original_dimension_size *= origin_shape[origin_index] + origin_dims.append(origin_len - origin_index - 1) + previous_label = PreviousStatus.TGT + + return reshape_mapping_dict + + +def check_keep_sharding_status(input_dim_partition_dict: Dict[int, List[int]], + reshape_mapping_dict: Dict[Tuple[int], Tuple[int]]) -> bool: + """ + This method is used to check whether the reshape operation could implement without converting + the input to fully replicated status. + + Rule: + For a sharded dimension of input tensor, if it is not the minimum element of the input tuple, + the function will return false. + To illustrate this issue, there are two cases to analyse: + 1. no sharded dims in the input tuple: we could do the reshape operation safely just as the normal + operation without distributed tensor. + 2. sharded dims in the input tuple: the sharded dim must be the minimum element, then during shape + consistency process, torch.cat will be implemented on the sharded dim, and everything after the sharded + dim get recovered. + + Examples: + # the second dimension of the input has been sharded. + input_dim_partition_dict = {1: [1]} + origin_shape = torch.Size([8, 4, 2]) + tgt_shape = torch.Size([2, 4, 8]) + reshape_mapping_dict = detect_reshape_mapping(origin_shape, tgt_shape) + # {(2, 1): (2,), (0,): (1, 0)} + # the sharded dim of input is 1, which is the minimum element of the tuple (2, 1), + # so we do not have to convert the input to fully replicated status. + print(check_keep_sharding_status(input_dim_partition_dict, reshape_mapping_dict)) + + Output: + True + """ + sharded_dims = list(input_dim_partition_dict.keys()) + for input_dims in reshape_mapping_dict.keys(): + min_element = min(input_dims) + for dim in input_dims: + if dim in sharded_dims and dim is not min_element: + return False + return True + + +def infer_output_dim_partition_dict(input_dim_partition_dict: Dict[int, List[int]], + reshape_mapping_dict: Dict[Tuple[int], Tuple[int]]) -> Dict[Tuple[int], Tuple[int]]: + """ + This method is used to infer the output dim partition dict for a reshape operation, + given the input dim partition dict and reshape mapping dict. + """ + assert check_keep_sharding_status(input_dim_partition_dict, reshape_mapping_dict), \ + 'we only infer output dim partition dict for the reshape operation could keep sharding spec.' + sharded_dims = list(input_dim_partition_dict.keys()) + output_dim_partition_dict = {} + for input_dims, output_dims in reshape_mapping_dict.items(): + for dim in input_dims: + if dim in sharded_dims: + output_dim_partition_dict[min(output_dims)] = input_dim_partition_dict[dim] + # we could break because input dims cannot contain two sharded dims, otherwise + # the keep sharding status check will fail. + break + return output_dim_partition_dict diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_view_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_view_handler.py new file mode 100644 index 000000000..fd219404e --- /dev/null +++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_view_handler.py @@ -0,0 +1,98 @@ +import torch +import torch.nn as nn + +from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler +from colossalai.auto_parallel.tensor_shard.node_handler.experimental import ViewHandler +from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector +from colossalai.device.device_mesh import DeviceMesh +from colossalai.fx import ColoGraphModule, ColoTracer +from colossalai.testing.pytest_wrapper import run_on_environment_flag + + +class ViewModel(nn.Module): + + def __init__(self): + super().__init__() + + def forward(self, input, other): + conv_node = nn.functional.conv2d(input, other) + reshape_node = conv_node.view(32, 4, 32, 32, 4) + return reshape_node + + +def test_view_handler(): + model = ViewModel() + tracer = ColoTracer() + # graph(): + # %input_1 : torch.Tensor [#users=1] = placeholder[target=input] + # %other : torch.Tensor [#users=1] = placeholder[target=other] + # %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%input_1, %other), kwargs = {}) + # %view : [#users=1] = call_method[target=view](args = (%conv2d, 2, -1), kwargs = {}) + # return view + graph = tracer.trace(model, + meta_args={ + "input": torch.rand(8, 8, 66, 66).to('meta'), + "other": torch.rand(16, 8, 3, 3).to('meta'), + }) + gm = ColoGraphModule(model, graph) + physical_mesh_id = torch.arange(0, 4) + + mesh_shape = (2, 2) + device_mesh = DeviceMesh(physical_mesh_id, mesh_shape) + conv_mod_node = list(graph.nodes)[2] + view_node = list(graph.nodes)[3] + view_strategies_vector = StrategiesVector(view_node) + conv_strategies_vector = StrategiesVector(conv_mod_node) + + # build handler + conv_handler = ConvFunctionHandler(node=conv_mod_node, + device_mesh=device_mesh, + strategies_vector=conv_strategies_vector) + conv_handler.register_strategy(compute_resharding_cost=False) + setattr(conv_mod_node, 'strategies_vector', conv_strategies_vector) + view_handler = ViewHandler(node=view_node, device_mesh=device_mesh, strategies_vector=view_strategies_vector) + + view_handler.register_strategy(compute_resharding_cost=False) + + # check operation data mapping + mapping = view_handler.get_operation_data_mapping() + + for name, op_data in mapping.items(): + op_data: OperationData + # make sure they have valid values + assert op_data.data is not None + + assert mapping['input'].name == "conv2d" + assert mapping['input'].data.is_meta + assert mapping['input'].data.shape == torch.Size([8, 16, 64, 64]) + assert mapping['input'].type == OperationDataType.ARG + assert mapping['input'].logical_shape == torch.Size([8, 16, 64, 64]) + + assert mapping['output'].name == "view" + assert mapping['output'].data.is_meta + assert mapping['output'].data.shape == torch.Size([32, 4, 32, 32, 4]) + assert mapping['output'].type == OperationDataType.OUTPUT + + # reshape handler is a following strategy handler, so the number of strategies is equal to the predecessor node. + assert len(view_strategies_vector) == len(conv_strategies_vector) + strategy_name_list = [strategy.name for strategy in view_strategies_vector] + assert '[S0, S1, R, R] -> FULLY REPLICATED_0' in strategy_name_list + assert '[S1, S0, R, R] -> FULLY REPLICATED_1' in strategy_name_list + assert '[S0, R, R, R] -> [S0, R, R, R, R]_2' in strategy_name_list + assert '[S1, R, R, R] -> [S1, R, R, R, R]_3' in strategy_name_list + assert '[S0, R, R, R] -> [S0, R, R, R, R]_4' in strategy_name_list + assert '[S1, R, R, R] -> [S1, R, R, R, R]_5' in strategy_name_list + assert '[R, S1, R, R] -> FULLY REPLICATED_6' in strategy_name_list + assert '[R, S0, R, R] -> FULLY REPLICATED_7' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R, R]_8' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R, R]_9' in strategy_name_list + assert '[R, S0, R, R] -> FULLY REPLICATED_10' in strategy_name_list + assert '[R, S1, R, R] -> FULLY REPLICATED_11' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R, R]_12' in strategy_name_list + assert '[S01, R, R, R] -> [S01, R, R, R, R]_13' in strategy_name_list + assert '[R, R, R, R] -> [R, R, R, R, R]_14' in strategy_name_list + assert '[R, S01, R, R] -> FULLY REPLICATED_15' in strategy_name_list + + +if __name__ == '__main__': + test_view_handler()