From 14389931138d1397ca3e070f6b0cee6d685a2b6d Mon Sep 17 00:00:00 2001
From: YuliangLiu0306 <72588413+YuliangLiu0306@users.noreply.github.com>
Date: Thu, 24 Nov 2022 11:34:41 +0800
Subject: [PATCH] [autoparallel] add experimental view handler (#2011)

* [autoparallel] add experimental view handler

* polish

* polish

* polish code

* rename variables
---
 .../node_handler/experimental/__init__.py     |   4 +
 .../experimental/view_generator.py            | 133 ++++++++++++++
 .../node_handler/experimental/view_handler.py |  51 ++++++
 .../tensor_shard/sharding_strategy.py         |   2 +
 .../solver/strategies_constructor.py          |   1 -
 .../tensor_shard/utils/__init__.py            |   4 +-
 .../tensor_shard/utils/reshape.py             | 168 ++++++++++++++++++
 .../test_node_handler/test_view_handler.py    |  98 ++++++++++
 8 files changed, 459 insertions(+), 2 deletions(-)
 create mode 100644 colossalai/auto_parallel/tensor_shard/node_handler/experimental/__init__.py
 create mode 100644 colossalai/auto_parallel/tensor_shard/node_handler/experimental/view_generator.py
 create mode 100644 colossalai/auto_parallel/tensor_shard/node_handler/experimental/view_handler.py
 create mode 100644 colossalai/auto_parallel/tensor_shard/utils/reshape.py
 create mode 100644 tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_view_handler.py

diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/experimental/__init__.py b/colossalai/auto_parallel/tensor_shard/node_handler/experimental/__init__.py
new file mode 100644
index 000000000..7f644c0e1
--- /dev/null
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/experimental/__init__.py
@@ -0,0 +1,4 @@
+from .view_generator import ViewGenerator
+from .view_handler import ViewHandler
+
+__all__ = ['ViewGenerator', 'ViewHandler']
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/experimental/view_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/experimental/view_generator.py
new file mode 100644
index 000000000..cdfa8b4eb
--- /dev/null
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/experimental/view_generator.py
@@ -0,0 +1,133 @@
+import copy
+from typing import List
+
+from colossalai.auto_parallel.tensor_shard.node_handler.strategy.strategy_generator import FollowingStrategyGenerator
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
+    CommAction,
+    CommType,
+    MemoryCost,
+    ShardingStrategy,
+    TrainCycleItem,
+)
+from colossalai.auto_parallel.tensor_shard.utils import (
+    check_keep_sharding_status,
+    detect_reshape_mapping,
+    infer_output_dim_partition_dict,
+)
+from colossalai.tensor.shape_consistency import CollectiveCommPattern
+from colossalai.tensor.sharding_spec import ShardingSpec
+
+__all__ = ['ViewGenerator']
+
+
+class ViewGenerator(FollowingStrategyGenerator):
+    """
+    ViewGenerator which deals with the sharding strategies of view op.
+    """
+
+    def validate(self) -> bool:
+        return super().validate()
+
+    def update_compute_cost(self, strategy: ShardingStrategy):
+        compute_cost = TrainCycleItem(fwd=10, bwd=10, total=20)
+        strategy.compute_cost = compute_cost
+
+    def update_memory_cost(self, strategy: ShardingStrategy):
+        '''
+        Compute the memory cost per device with this specific strategy.
+        '''
+        forward_size_mapping = {
+            'input': self._compute_size_in_bytes(strategy, "input"),
+            'output': self._compute_size_in_bytes(strategy, "output")
+        }
+
+        backward_size_mapping = copy.deepcopy(forward_size_mapping)
+        backward_size_mapping.pop("output")
+        # compute fwd cost incurred
+        # fwd_cost = input + output
+        fwd_activation_cost = sum([v for k, v in forward_size_mapping.items() if not self.is_param(k)])
+        fwd_parameter_cost = sum([v for k, v in forward_size_mapping.items() if self.is_param(k)])
+        fwd_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=fwd_parameter_cost)
+
+        # compute bwd cost incurred
+        # bwd_cost = input_grad
+        bwd_activation_cost = sum([v for k, v in backward_size_mapping.items() if not self.is_param(k)])
+        bwd_parameter_cost = sum([v for k, v in backward_size_mapping.items() if self.is_param(k)])
+        bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost)
+
+        # compute total cost
+        total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost,
+                                    parameter=fwd_parameter_cost + bwd_parameter_cost)
+        memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
+        strategy.memory_cost = memory_cost
+
+    def collate_strategies(self) -> List[ShardingStrategy]:
+        strategy_list = []
+        for index, strategy in enumerate(self.predecessor_node.strategies_vector):
+            dim_partition_dict_mapping = {}
+            communication_action_mapping = {}
+            input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]]
+
+            origin_shape = self.op_data['input'].data.shape
+            tgt_shape = self.op_data['tgt_shape'].data
+
+            reshape_mapping_dict = detect_reshape_mapping(origin_shape, tgt_shape)
+
+            dim_partition_dict_for_input = input_sharding_spec.dim_partition_dict
+            keep_sharding_status = check_keep_sharding_status(dim_partition_dict_for_input, reshape_mapping_dict)
+
+            if keep_sharding_status:
+                dim_partition_dict_for_output = infer_output_dim_partition_dict(dim_partition_dict_for_input,
+                                                                                reshape_mapping_dict)
+            else:
+                dim_partition_dict_for_output = {}
+
+            dim_partition_dict_mapping = {
+                "input": dim_partition_dict_for_input,
+                "output": dim_partition_dict_for_output,
+            }
+            sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
+
+            # add index into name to pass the duplicated check
+            # we keep same strategies with different name for node merging, and it will not increase the searching space,
+            # because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node.
+            if keep_sharding_status:
+                name = f'{sharding_spec_mapping["input"].sharding_sequence} -> {sharding_spec_mapping["output"].sharding_sequence}_{index}'
+            else:
+                name = f'{sharding_spec_mapping["input"].sharding_sequence} -> FULLY REPLICATED_{index}'
+
+                # add comm action for converting input to fully replicated
+                total_mesh_dim_list = []
+                for mesh_dim_list in dim_partition_dict_for_input.values():
+                    total_mesh_dim_list.extend(mesh_dim_list)
+                # if there is only one sharding dimension, we should use the value instead of list as logical_process_axis.
+                if len(total_mesh_dim_list) == 1:
+                    total_mesh_dim_list = total_mesh_dim_list[0]
+                    input_comm_action = self.get_communication_action(
+                        sharding_spec=sharding_spec_mapping["input"],
+                        communication_pattern=CollectiveCommPattern.GATHER_FWD_SPLIT_BWD,
+                        logical_process_axis=total_mesh_dim_list,
+                        comm_type=CommType.BEFORE,
+                        arg_index=0)
+                    input_comm_action.comm_spec.gather_dim = total_mesh_dim_list
+
+                elif len(total_mesh_dim_list) >= 2:
+                    source_spec = sharding_spec_mapping["input"]
+                    target_spec = ShardingSpec(device_mesh=self.device_mesh,
+                                               entire_shape=source_spec.entire_shape,
+                                               dim_partition_dict={})
+                    comm_spec = {'src_spec': source_spec, 'tgt_spec': target_spec}
+                    input_comm_action = CommAction(comm_spec=comm_spec, comm_type=CommType.BEFORE, arg_index=0)
+
+                else:
+                    input_comm_action = None
+
+                if input_comm_action is not None:
+                    communication_action_mapping["input"] = input_comm_action
+
+            strategy = self.get_sharding_strategy(name=name,
+                                                  sharding_spec_mapping=sharding_spec_mapping,
+                                                  communication_action_mapping=communication_action_mapping)
+            strategy_list.append(strategy)
+
+        return strategy_list
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/experimental/view_handler.py b/colossalai/auto_parallel/tensor_shard/node_handler/experimental/view_handler.py
new file mode 100644
index 000000000..bab4e0d76
--- /dev/null
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/experimental/view_handler.py
@@ -0,0 +1,51 @@
+from typing import Dict, List
+
+import torch
+
+from ...sharding_strategy import OperationData, OperationDataType
+from ..node_handler import NodeHandler
+from ..registry import operator_registry
+from ..strategy import StrategyGenerator
+from .view_generator import ViewGenerator
+
+__all__ = ['ViewHandler']
+
+
+@operator_registry.register(torch.Tensor.view)
+class ViewHandler(NodeHandler):
+    """
+    A ViewHandler which deals with the sharding strategies for Reshape Op, such as torch.reshape.
+    """
+
+    def get_strategy_generator(self) -> List[StrategyGenerator]:
+        op_data_mapping = self.get_operation_data_mapping()
+        generators = []
+        generators.append(ViewGenerator(op_data_mapping, self.device_mesh, self.node.args[0]))
+        return generators
+
+    def get_operation_data_mapping(self) -> Dict[str, OperationData]:
+        # use transposed shape for strategies
+        # the strategies will be transformed back to its original shape in self.post_process
+
+        # check if the input operand is a parameter
+        if isinstance(self.node.args[0]._meta_data, torch.nn.parameter.Parameter):
+            data_type = OperationDataType.PARAM
+        else:
+            data_type = OperationDataType.ARG
+
+        input_data = self.node.args[0]._meta_data
+        physical_input_operand = OperationData(name=str(self.node.args[0]), type=data_type, data=input_data)
+
+        target_shape = self.node._meta_data.shape
+        physical_shape_operand = OperationData(name='tgt_shape', type=OperationDataType.ARG, data=target_shape)
+
+        output_data = self.node._meta_data
+        physical_output_operand = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=output_data)
+
+        mapping = {
+            "input": physical_input_operand,
+            "tgt_shape": physical_shape_operand,
+            "output": physical_output_operand
+        }
+
+        return mapping
diff --git a/colossalai/auto_parallel/tensor_shard/sharding_strategy.py b/colossalai/auto_parallel/tensor_shard/sharding_strategy.py
index efe484917..bbf4215d9 100644
--- a/colossalai/auto_parallel/tensor_shard/sharding_strategy.py
+++ b/colossalai/auto_parallel/tensor_shard/sharding_strategy.py
@@ -51,6 +51,8 @@ class OperationData:
                 """
                 if isinstance(data, torch.Tensor):
                     return data.shape
+                elif isinstance(data, torch.Size):
+                    return None
                 elif isinstance(data, (tuple, list)):
                     data_type = type(data)
                     return data_type([_infer_logical_shape(d) for d in data])
diff --git a/colossalai/auto_parallel/tensor_shard/solver/strategies_constructor.py b/colossalai/auto_parallel/tensor_shard/solver/strategies_constructor.py
index 6342feeee..adfd03d7d 100644
--- a/colossalai/auto_parallel/tensor_shard/solver/strategies_constructor.py
+++ b/colossalai/auto_parallel/tensor_shard/solver/strategies_constructor.py
@@ -82,7 +82,6 @@ class StrategiesConstructor:
         for node in self.nodes:
             strategies_vector = StrategiesVector(node)
 
-            print(node)
             if _check_no_strategy_for_node(node):
                 no_strategy_node.append(node)
                 pass
diff --git a/colossalai/auto_parallel/tensor_shard/utils/__init__.py b/colossalai/auto_parallel/tensor_shard/utils/__init__.py
index 63c48195d..b7fe5430b 100644
--- a/colossalai/auto_parallel/tensor_shard/utils/__init__.py
+++ b/colossalai/auto_parallel/tensor_shard/utils/__init__.py
@@ -7,6 +7,7 @@ from .broadcast import (
 )
 from .factory import generate_resharding_costs, generate_sharding_spec
 from .misc import check_sharding_spec_validity, ignore_sharding_exception, pytree_map
+from .reshape import check_keep_sharding_status, detect_reshape_mapping, infer_output_dim_partition_dict
 from .sharding import (
     enumerate_all_possible_1d_sharding,
     enumerate_all_possible_2d_sharding,
@@ -19,5 +20,6 @@ __all__ = [
     'BroadcastType', 'get_broadcast_shape', 'is_broadcastable', 'recover_sharding_spec_for_broadcast_shape',
     'generate_resharding_costs', 'generate_sharding_spec', 'ignore_sharding_exception', 'check_sharding_spec_validity'
     'transpose_partition_dim', 'update_partition_dim', 'enumerate_all_possible_1d_sharding',
-    'enumerate_all_possible_2d_sharding', 'generate_sharding_size', 'comm_actions_for_oprands', 'pytree_map'
+    'enumerate_all_possible_2d_sharding', 'generate_sharding_size', 'comm_actions_for_oprands', 'pytree_map',
+    'detect_reshape_mapping', 'check_keep_sharding_status', 'infer_output_dim_partition_dict'
 ]
diff --git a/colossalai/auto_parallel/tensor_shard/utils/reshape.py b/colossalai/auto_parallel/tensor_shard/utils/reshape.py
new file mode 100644
index 000000000..8e02544f7
--- /dev/null
+++ b/colossalai/auto_parallel/tensor_shard/utils/reshape.py
@@ -0,0 +1,168 @@
+from enum import Enum
+from typing import Dict, List, Tuple
+
+import torch
+
+
+class PreviousStatus(Enum):
+    """
+    This class shows the status of previous comparision.
+    """
+    RESET = 0
+    # ORIGIN means the dimension size of original tensor is larger in the previous comparision.
+    ORIGIN = 1
+    # TGT means the dimension size of target tensor is larger in the previous comparision.
+    TGT = 2
+
+
+def detect_reshape_mapping(origin_shape: torch.Size, tgt_shape: torch.Size) -> Dict[Tuple[int], Tuple[int]]:
+    """
+    This method is used to detect the reshape mapping between original tensor and target tensor.
+
+    Returns:
+        reshape_mapping_dict: The dictionary shows how a tuple of origin dims(keys) mapping to the related
+        target dims(values) during reshaping operation.
+    Examples:
+        import torch
+        origin_shape = torch.Size([4, 4, 4])
+        tgt_shape = torch.Size([2, 8, 2, 2])
+        reshape_mapping_dict = detect_reshape_mapping(origin_shape, tgt_shape)
+        print(reshape_mapping_dict)
+    Output:
+        {(2,): (3, 2), (1, 0): (1,), (0,): (0, 1)}
+    """
+
+    # reverse the shape object
+    origin_shape = list(origin_shape)
+    tgt_shape = list(tgt_shape)
+    origin_shape.reverse()
+    tgt_shape.reverse()
+
+    # initialize arguments
+    reshape_mapping_dict = {}
+    origin_len = len(origin_shape)
+    tgt_len = len(tgt_shape)
+    origin_index = 0
+    tgt_index = 0
+    original_dimension_size = origin_shape[origin_index]
+    tgt_dimension_size = tgt_shape[tgt_index]
+    tgt_dims = [tgt_len - tgt_index - 1]
+    origin_dims = [origin_len - origin_index - 1]
+    previous_label = PreviousStatus.RESET
+
+    while origin_index != len(origin_shape) or tgt_index != len(tgt_shape):
+        if original_dimension_size == tgt_dimension_size:
+            reshape_mapping_dict[tuple(origin_dims)] = tuple(tgt_dims)
+            origin_index += 1
+            tgt_index += 1
+            # the last step of loop should always end with condition
+            # so we need to manually skip the preparation for next step
+            # in the last step.
+            if origin_index == len(origin_shape):
+                continue
+            original_dimension_size = origin_shape[origin_index]
+            tgt_dimension_size = tgt_shape[tgt_index]
+            origin_dims = [origin_len - origin_index - 1]
+            tgt_dims = [tgt_len - tgt_index - 1]
+            previous_label = PreviousStatus.RESET
+
+        elif original_dimension_size > tgt_dimension_size:
+            tgt_index += 1
+
+            if previous_label == PreviousStatus.TGT:
+                # if the target dimension size is larger in the previous comparision, which means
+                # the origin dimension size has already accumulated larger than target dimension size, so
+                # we need to offload the origin dims and tgt dims into the reshape_mapping_dict.
+                reshape_mapping_dict[tuple(origin_dims)] = tuple(tgt_dims)
+                original_dimension_size = original_dimension_size // tgt_dimension_size
+                origin_dims = [origin_len - origin_index - 1]
+                tgt_dimension_size = tgt_shape[tgt_index]
+                tgt_dims = [tgt_len - tgt_index - 1, tgt_len - tgt_index]
+                # reset the previous_label after offloading the origin dims and tgt dims
+                previous_label = PreviousStatus.RESET
+            else:
+                # accumulate the tgt_dimension_size until tgt_dimension_size larger than original_dimension_size
+                tgt_dimension_size *= tgt_shape[tgt_index]
+                tgt_dims.append(tgt_len - tgt_index - 1)
+                previous_label = PreviousStatus.ORIGIN
+
+        else:
+            origin_index += 1
+
+            if previous_label == PreviousStatus.ORIGIN:
+                # if the origin element is larger in the previous comparision, which means
+                # the target element has already accumulated larger than origin element, so
+                # we need to offload the origin dims and tgt dims into the reshape_mapping_dict.
+                reshape_mapping_dict[tuple(origin_dims)] = tuple(tgt_dims)
+                tgt_dimension_size = tgt_dimension_size // original_dimension_size
+                tgt_dims = [tgt_len - tgt_index - 1]
+                original_dimension_size = origin_shape[origin_index]
+                origin_dims = [origin_len - origin_index - 1, origin_len - origin_index]
+                # reset the previous_label after offloading the origin dims and tgt dims
+                previous_label = PreviousStatus.RESET
+            else:
+                # accumulate the original_dimension_size until original_dimension_size larger than tgt_dimension_size
+                original_dimension_size *= origin_shape[origin_index]
+                origin_dims.append(origin_len - origin_index - 1)
+                previous_label = PreviousStatus.TGT
+
+    return reshape_mapping_dict
+
+
+def check_keep_sharding_status(input_dim_partition_dict: Dict[int, List[int]],
+                               reshape_mapping_dict: Dict[Tuple[int], Tuple[int]]) -> bool:
+    """
+    This method is used to check whether the reshape operation could implement without converting
+    the input to fully replicated status.
+
+    Rule:
+        For a sharded dimension of input tensor, if it is not the minimum element of the input tuple,
+        the function will return false.
+        To illustrate this issue, there are two cases to analyse:
+        1. no sharded dims in the input tuple: we could do the reshape operation safely just as the normal
+        operation without distributed tensor.
+        2. sharded dims in the input tuple: the sharded dim must be the minimum element, then during shape
+        consistency process, torch.cat will be implemented on the sharded dim, and everything after the sharded
+        dim get recovered.
+
+    Examples:
+        # the second dimension of the input has been sharded.
+        input_dim_partition_dict = {1: [1]}
+        origin_shape = torch.Size([8, 4, 2])
+        tgt_shape = torch.Size([2, 4, 8])
+        reshape_mapping_dict = detect_reshape_mapping(origin_shape, tgt_shape)
+        # {(2, 1): (2,), (0,): (1, 0)}
+        # the sharded dim of input is 1, which is the minimum element of the tuple (2, 1),
+        # so we do not have to convert the input to fully replicated status.
+        print(check_keep_sharding_status(input_dim_partition_dict, reshape_mapping_dict))
+
+    Output:
+        True
+    """
+    sharded_dims = list(input_dim_partition_dict.keys())
+    for input_dims in reshape_mapping_dict.keys():
+        min_element = min(input_dims)
+        for dim in input_dims:
+            if dim in sharded_dims and dim is not min_element:
+                return False
+    return True
+
+
+def infer_output_dim_partition_dict(input_dim_partition_dict: Dict[int, List[int]],
+                                    reshape_mapping_dict: Dict[Tuple[int], Tuple[int]]) -> Dict[Tuple[int], Tuple[int]]:
+    """
+    This method is used to infer the output dim partition dict for a reshape operation,
+    given the input dim partition dict and reshape mapping dict.
+    """
+    assert check_keep_sharding_status(input_dim_partition_dict, reshape_mapping_dict), \
+        'we only infer output dim partition dict for the reshape operation could keep sharding spec.'
+    sharded_dims = list(input_dim_partition_dict.keys())
+    output_dim_partition_dict = {}
+    for input_dims, output_dims in reshape_mapping_dict.items():
+        for dim in input_dims:
+            if dim in sharded_dims:
+                output_dim_partition_dict[min(output_dims)] = input_dim_partition_dict[dim]
+                # we could break because input dims cannot contain two sharded dims, otherwise
+                # the keep sharding status check will fail.
+                break
+    return output_dim_partition_dict
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_view_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_view_handler.py
new file mode 100644
index 000000000..fd219404e
--- /dev/null
+++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_view_handler.py
@@ -0,0 +1,98 @@
+import torch
+import torch.nn as nn
+
+from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler
+from colossalai.auto_parallel.tensor_shard.node_handler.experimental import ViewHandler
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
+from colossalai.device.device_mesh import DeviceMesh
+from colossalai.fx import ColoGraphModule, ColoTracer
+from colossalai.testing.pytest_wrapper import run_on_environment_flag
+
+
+class ViewModel(nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, input, other):
+        conv_node = nn.functional.conv2d(input, other)
+        reshape_node = conv_node.view(32, 4, 32, 32, 4)
+        return reshape_node
+
+
+def test_view_handler():
+    model = ViewModel()
+    tracer = ColoTracer()
+    # graph():
+    #     %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
+    #     %other : torch.Tensor [#users=1] = placeholder[target=other]
+    #     %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%input_1, %other), kwargs = {})
+    #     %view : [#users=1] = call_method[target=view](args = (%conv2d, 2, -1), kwargs = {})
+    #     return view
+    graph = tracer.trace(model,
+                         meta_args={
+                             "input": torch.rand(8, 8, 66, 66).to('meta'),
+                             "other": torch.rand(16, 8, 3, 3).to('meta'),
+                         })
+    gm = ColoGraphModule(model, graph)
+    physical_mesh_id = torch.arange(0, 4)
+
+    mesh_shape = (2, 2)
+    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
+    conv_mod_node = list(graph.nodes)[2]
+    view_node = list(graph.nodes)[3]
+    view_strategies_vector = StrategiesVector(view_node)
+    conv_strategies_vector = StrategiesVector(conv_mod_node)
+
+    # build handler
+    conv_handler = ConvFunctionHandler(node=conv_mod_node,
+                                       device_mesh=device_mesh,
+                                       strategies_vector=conv_strategies_vector)
+    conv_handler.register_strategy(compute_resharding_cost=False)
+    setattr(conv_mod_node, 'strategies_vector', conv_strategies_vector)
+    view_handler = ViewHandler(node=view_node, device_mesh=device_mesh, strategies_vector=view_strategies_vector)
+
+    view_handler.register_strategy(compute_resharding_cost=False)
+
+    # check operation data mapping
+    mapping = view_handler.get_operation_data_mapping()
+
+    for name, op_data in mapping.items():
+        op_data: OperationData
+        # make sure they have valid values
+        assert op_data.data is not None
+
+    assert mapping['input'].name == "conv2d"
+    assert mapping['input'].data.is_meta
+    assert mapping['input'].data.shape == torch.Size([8, 16, 64, 64])
+    assert mapping['input'].type == OperationDataType.ARG
+    assert mapping['input'].logical_shape == torch.Size([8, 16, 64, 64])
+
+    assert mapping['output'].name == "view"
+    assert mapping['output'].data.is_meta
+    assert mapping['output'].data.shape == torch.Size([32, 4, 32, 32, 4])
+    assert mapping['output'].type == OperationDataType.OUTPUT
+
+    # reshape handler is a following strategy handler, so the number of strategies is equal to the predecessor node.
+    assert len(view_strategies_vector) == len(conv_strategies_vector)
+    strategy_name_list = [strategy.name for strategy in view_strategies_vector]
+    assert '[S0, S1, R, R] -> FULLY REPLICATED_0' in strategy_name_list
+    assert '[S1, S0, R, R] -> FULLY REPLICATED_1' in strategy_name_list
+    assert '[S0, R, R, R] -> [S0, R, R, R, R]_2' in strategy_name_list
+    assert '[S1, R, R, R] -> [S1, R, R, R, R]_3' in strategy_name_list
+    assert '[S0, R, R, R] -> [S0, R, R, R, R]_4' in strategy_name_list
+    assert '[S1, R, R, R] -> [S1, R, R, R, R]_5' in strategy_name_list
+    assert '[R, S1, R, R] -> FULLY REPLICATED_6' in strategy_name_list
+    assert '[R, S0, R, R] -> FULLY REPLICATED_7' in strategy_name_list
+    assert '[R, R, R, R] -> [R, R, R, R, R]_8' in strategy_name_list
+    assert '[R, R, R, R] -> [R, R, R, R, R]_9' in strategy_name_list
+    assert '[R, S0, R, R] -> FULLY REPLICATED_10' in strategy_name_list
+    assert '[R, S1, R, R] -> FULLY REPLICATED_11' in strategy_name_list
+    assert '[R, R, R, R] -> [R, R, R, R, R]_12' in strategy_name_list
+    assert '[S01, R, R, R] -> [S01, R, R, R, R]_13' in strategy_name_list
+    assert '[R, R, R, R] -> [R, R, R, R, R]_14' in strategy_name_list
+    assert '[R, S01, R, R] -> FULLY REPLICATED_15' in strategy_name_list
+
+
+if __name__ == '__main__':
+    test_view_handler()