From b4cc59b61e4f8921eb2a06417279cddc3c5b6e33 Mon Sep 17 00:00:00 2001
From: YuliangLiu0306 <72588413+YuliangLiu0306@users.noreply.github.com>
Date: Thu, 27 Oct 2022 10:42:54 +0800
Subject: [PATCH] [autoparallel] add numerical test for node strategies (#1760)

* [autoparallel] add numerical test for node strategies

* polish code

* polish code
---
 .../passes/runtime_apply_pass.py              |  52 ++++++--
 .../passes/runtime_preparation_pass.py        |   1 +
 .../strategy/conv_strategy_generator.py       |  24 ++--
 .../strategy/strategy_generator.py            |   6 +-
 .../tensor_shard/sharding_strategy.py         |   1 +
 colossalai/device/device_mesh.py              |  19 ++-
 colossalai/tensor/shape_consistency.py        |   9 ++
 colossalai/tensor/sharding_spec.py            |  13 +-
 .../test_node_handler/test_conv_handler.py    |  92 ++++++++++---
 .../test_node_handler/utils.py                | 126 ++++++++++++++++++
 10 files changed, 283 insertions(+), 60 deletions(-)
 create mode 100644 tests/test_auto_parallel/test_tensor_shard/test_node_handler/utils.py

diff --git a/colossalai/auto_parallel/passes/runtime_apply_pass.py b/colossalai/auto_parallel/passes/runtime_apply_pass.py
index 09f123665..cc2466273 100644
--- a/colossalai/auto_parallel/passes/runtime_apply_pass.py
+++ b/colossalai/auto_parallel/passes/runtime_apply_pass.py
@@ -24,7 +24,6 @@ def runtime_apply(node: Node, origin_dict: Dict, input_dict: Dict, node_index: i
     """
     origin_sharding_spec = origin_dict[node_index]
     target_sharding_spec = input_dict[node_index][user_node_index]
-
     return shape_consistency_manager.apply_for_autoparallel_runtime(node, origin_sharding_spec, target_sharding_spec)
 
 
@@ -81,18 +80,24 @@ def _shape_consistency_apply(gm: torch.fx.GraphModule):
         if not hasattr(node, 'best_strategy') or node.op == 'output':
             continue
 
-        for user_node in node.strategies_vector.successor_nodes:
-            user_node_index = user_node.strategies_vector.predecessor_nodes.index(node)
+        for user_node_index, user_node in enumerate(node.strategies_vector.successor_nodes):
             with mod_graph.inserting_before(user_node):
                 shape_consistency_node = mod_graph.create_node('call_function',
                                                                runtime_apply,
                                                                args=(node, origin_dict_node, input_dict_node,
                                                                      node_to_index_dict[node], user_node_index))
-
-            origin_index_args = user_node.args.index(node)
             new_args = list(user_node.args)
-            new_args[origin_index_args] = shape_consistency_node
-            user_node.args = new_args
+            new_kwargs = dict(user_node.kwargs)
+            # the origin node may be a positional argument or keyword argument of the user node
+            if node in new_args:
+                # substitute the origin node with shape_consistency_node
+                origin_index_args = new_args.index(node)
+                new_args[origin_index_args] = shape_consistency_node
+                user_node.args = new_args
+            elif str(node) in new_kwargs:
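+                # kwargs are keyed by argument name; the lookup relies on the node's name (str(node)) matching that key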
+                # substitute the origin node with shape_consistency_node
+                new_kwargs[str(node)] = shape_consistency_node
+                user_node.kwargs = new_kwargs
 
     return gm
 
@@ -112,18 +117,31 @@ def _comm_spec_apply(gm: torch.fx.GraphModule):
 
         comm_actions = node.best_strategy.communication_actions
         for op_data, comm_action in comm_actions.items():
-            comm_object = node.args[comm_action.arg_index]
+
             if op_data.type == OperationDataType.PARAM:
                 continue
             if comm_action.comm_type == CommType.BEFORE:
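+                # locate the object this communication acts on, whether it was passed positionally or by keyword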
+                if comm_action.key_for_kwarg is not None:
+                    comm_object = node.kwargs[comm_action.key_for_kwarg]
+                else:
+                    comm_object = node.args[comm_action.arg_index]
                 with mod_graph.inserting_before(node):
                     comm_spec_apply_node = mod_graph.create_node('call_function',
                                                                  runtime_comm_spec_apply,
                                                                  args=(comm_object, comm_actions_dict_node,
                                                                        node_to_index_dict[node], op_data.name))
-                new_args = list(node.args)
-                new_args[comm_action.arg_index] = comm_spec_apply_node
-                node.args = new_args
+                # the communication object may be a positional argument or keyword argument of the node
+                if comm_action.key_for_kwarg is not None:
+                    # substitute the origin node with comm_spec_apply_node
+                    new_kwargs = dict(node.kwargs)
+                    new_kwargs[comm_action.key_for_kwarg] = comm_spec_apply_node
+                    node.kwargs = new_kwargs
+                else:
+                    # substitute the origin node with comm_spec_apply_node
+                    new_args = list(node.args)
+                    new_args[comm_action.arg_index] = comm_spec_apply_node
+                    node.args = new_args
+
             elif comm_action.comm_type == CommType.AFTER:
                 with mod_graph.inserting_after(node):
                     comm_spec_apply_node = mod_graph.create_node('call_function',
@@ -135,8 +153,16 @@ def _comm_spec_apply(gm: torch.fx.GraphModule):
                     if user == comm_spec_apply_node:
                         continue
                     new_args = list(user.args)
-                    new_args[new_args.index(node)] = comm_spec_apply_node
-                    user.args = tuple(new_args)
+                    new_kwargs = dict(user.kwargs)
+                    # the origin node may be a positional argument or keyword argument of the user node
+                    if node in new_args:
+                        # substitute the origin node with comm_spec_apply_node
+                        new_args[new_args.index(node)] = comm_spec_apply_node
+                        user.args = tuple(new_args)
+                    elif str(node) in new_kwargs:
+                        # substitute the origin node with comm_spec_apply_node
+                        new_kwargs[str(node)] = comm_spec_apply_node
+                        user.kwargs = new_kwargs
 
     return gm
 
diff --git a/colossalai/auto_parallel/passes/runtime_preparation_pass.py b/colossalai/auto_parallel/passes/runtime_preparation_pass.py
index 796a95ee4..00268e3f5 100644
--- a/colossalai/auto_parallel/passes/runtime_preparation_pass.py
+++ b/colossalai/auto_parallel/passes/runtime_preparation_pass.py
@@ -77,6 +77,7 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh):
                 if target_sharding_spec.dim_partition_dict != {}:
                     origin_sharding_spec = ShardingSpec(device_mesh, param.shape, {})
                     setattr(param, 'sharding_spec', origin_sharding_spec)
+                    # TODO: build a ColoParameter class to manage the distributed parameters
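+                    # convert the full parameter into its local shard according to target_sharding_spec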
                     param_sharded = torch.nn.Parameter(
                         shape_consistency_manager.apply_for_autoparallel_runtime(param.data, param.sharding_spec,
                                                                                  target_sharding_spec).detach().clone())
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/conv_strategy_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/conv_strategy_generator.py
index 83476e4fe..f7e4543f8 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/conv_strategy_generator.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/conv_strategy_generator.py
@@ -4,7 +4,6 @@ import warnings
 from functools import reduce
 from typing import List
 
-
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
     CommAction,
     CommType,
@@ -12,10 +11,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
     ShardingStrategy,
     TrainCycleItem,
 )
-
-from colossalai.auto_parallel.tensor_shard.utils import \
-    ignore_sharding_exception
-
+from colossalai.auto_parallel.tensor_shard.utils import ignore_sharding_exception
 from colossalai.tensor.shape_consistency import CollectiveCommPattern
 
 from .strategy_generator import StrategyGenerator
@@ -135,7 +131,8 @@ class ConvStrategyGenerator(StrategyGenerator):
             sharding_spec=sharding_spec_mapping["input"],
             communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
             logical_process_axis=mesh_dim_1,
-            comm_type=CommType.BEFORE)
+            comm_type=CommType.BEFORE,
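+            # BEFORE-type communication is applied to an input of the node, so its positional index is required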
+            arg_index=0)
         communication_action_mapping = {"input": input_comm_action}
 
         if self.is_param("other"):
@@ -223,8 +220,7 @@ class ConvStrategyGenerator(StrategyGenerator):
             sharding_spec_mapping["output"],
             communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
             logical_process_axis=mesh_dim_1,
-            comm_type=CommType.AFTER,
-            arg_index=0)
+            comm_type=CommType.AFTER)
 
         communication_action_mapping = {"output": output_comm_action}
 
@@ -277,8 +273,7 @@ class ConvStrategyGenerator(StrategyGenerator):
             sharding_spec_mapping["output"],
             communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
             logical_process_axis=mesh_dim_0,
-            comm_type=CommType.AFTER,
-            arg_index=0)
+            comm_type=CommType.AFTER)
         input_comm_action = self.get_communication_action(
             sharding_spec_mapping["input"],
             communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
@@ -316,8 +311,7 @@ class ConvStrategyGenerator(StrategyGenerator):
             sharding_spec_mapping["output"],
             communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
             logical_process_axis=mesh_dim_0,
-            comm_type=CommType.AFTER,
-            arg_index=0)
+            comm_type=CommType.AFTER)
 
         communication_action_mapping = {"output": output_comm_action}
 
@@ -351,7 +345,8 @@ class ConvStrategyGenerator(StrategyGenerator):
             sharding_spec_mapping["input"],
             communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
             logical_process_axis=mesh_dim_0,
-            comm_type=CommType.BEFORE)
+            comm_type=CommType.BEFORE,
+            arg_index=0)
 
         communication_action_mapping = {"input": input_comm_action}
 
@@ -441,8 +436,7 @@ class ConvStrategyGenerator(StrategyGenerator):
             sharding_spec_mapping["output"],
             communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
             logical_process_axis=[mesh_dim_0, mesh_dim_1],
-            comm_type=CommType.AFTER,
-            arg_index=0)
+            comm_type=CommType.AFTER)
 
         communication_action_mapping = {"output": output_comm_action}
 
diff --git a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/strategy_generator.py b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/strategy_generator.py
index 8f57ee6a0..b3903b9d7 100644
--- a/colossalai/auto_parallel/tensor_shard/node_handler/strategy/strategy_generator.py
+++ b/colossalai/auto_parallel/tensor_shard/node_handler/strategy/strategy_generator.py
@@ -109,7 +109,8 @@ class StrategyGenerator(ABC):
                                  communication_pattern: CollectiveCommPattern,
                                  logical_process_axis: Union[int, List[int]],
                                  comm_type: CommType,
-                                 arg_index: int = -1) -> CommAction:
+                                 arg_index: int = -1,
+                                 key_for_kwarg: any = None) -> CommAction:
         """
         A factory method to produce a CommAction object.
         """
@@ -117,7 +118,8 @@ class StrategyGenerator(ABC):
                                                                 communication_pattern=communication_pattern,
                                                                 logical_process_axis=logical_process_axis),
                           comm_type=comm_type,
-                          arg_index=arg_index)
+                          arg_index=arg_index,
+                          key_for_kwarg=key_for_kwarg)
 
     def update_communication_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
         """
diff --git a/colossalai/auto_parallel/tensor_shard/sharding_strategy.py b/colossalai/auto_parallel/tensor_shard/sharding_strategy.py
index 8dbb0014b..334fb10d7 100644
--- a/colossalai/auto_parallel/tensor_shard/sharding_strategy.py
+++ b/colossalai/auto_parallel/tensor_shard/sharding_strategy.py
@@ -115,6 +115,7 @@ class CommAction:
     comm_spec: CommSpec = None
     comm_type: CommType = None
     arg_index: int = -1
+    key_for_kwarg: any = None
 
 
 @dataclass
diff --git a/colossalai/device/device_mesh.py b/colossalai/device/device_mesh.py
index df010e7d7..403bbe4ae 100644
--- a/colossalai/device/device_mesh.py
+++ b/colossalai/device/device_mesh.py
@@ -1,5 +1,6 @@
-from functools import reduce
 import operator
+from functools import reduce
+
 import torch
 import torch.distributed as dist
 
@@ -11,7 +12,7 @@ class DeviceMesh:
     can be viewed as a 1x16 or a 4x4 logical mesh). Each mesh dimension has its
     own latency and bandwidth. We use alpha-beta model to model the
     communication cost.
-    
+
     Arguments:
         physical_mesh_id (torch.Tensor): physical view of the devices in global rank.
         mesh_shape (torch.Size): shape of logical view.
@@ -64,6 +65,18 @@ class DeviceMesh:
     def logical_mesh_id(self):
         return self._logical_mesh_id
 
+    def __deepcopy__(self, memo):
+        cls = self.__class__
+        result = cls.__new__(cls)
+        memo[id(self)] = result
+        for k, v in self.__dict__.items():
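+            # process group handles cannot be deep-copied, so 'process_groups_dict' is shared by reference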
+            if k != 'process_groups_dict':
+                setattr(result, k, __import__("copy").deepcopy(v, memo))
+            else:
+                setattr(result, k, v)
+
+        return result
+
     def flatten(self):
         """
         Flatten the logical mesh into an effective 1d logical mesh,
@@ -90,7 +103,7 @@ class DeviceMesh:
     def create_process_groups_for_logical_mesh(self):
         '''
         This method is used to initialize the logical process groups which will be used in communications
-        among logical device mesh. 
+        within the logical device mesh.
         Note: if init_process_group set to False, you have to call this method manually. Otherwise,
         the communication related function, such as ShapeConsistencyManager.apply will raise errors.
         '''
diff --git a/colossalai/tensor/shape_consistency.py b/colossalai/tensor/shape_consistency.py
index d96040817..4ec5ad9e9 100644
--- a/colossalai/tensor/shape_consistency.py
+++ b/colossalai/tensor/shape_consistency.py
@@ -28,6 +28,15 @@ class ShapeConsistencyOptions:
     pass
 
 
+def to_global(distributed_tensor: torch.Tensor, sharding_spec: ShardingSpec):
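+    """
+    Gather a sharded tensor back to its full, unsharded shape, i.e. the shape described by a
+    sharding spec with an empty dim_partition_dict.
+    """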
+    shape_consistency_manager = ShapeConsistencyManager()
+    global_sharding_spec = ShardingSpec(sharding_spec.device_mesh, sharding_spec.entire_shape, {})
+    with torch.no_grad():
+        global_tensor = shape_consistency_manager.apply_for_autoparallel_runtime(distributed_tensor, sharding_spec,
+                                                                                 global_sharding_spec)
+    return global_tensor
+
+
 def set_shape_consistency_options(options: ShapeConsistencyOptions):
     """
     Configure the shape consistency manager via function call.
diff --git a/colossalai/tensor/sharding_spec.py b/colossalai/tensor/sharding_spec.py
index fababb6e7..37d397885 100644
--- a/colossalai/tensor/sharding_spec.py
+++ b/colossalai/tensor/sharding_spec.py
@@ -6,7 +6,6 @@ from functools import reduce
 import torch
 
 from colossalai.device.device_mesh import DeviceMesh
-from colossalai.tensor.utils import (all_gather_simulator, all_to_all_simulator, shard_simulator)
 
 __all__ = ['_DimSpec', 'ShardingException', 'ShardingSpec']
 
@@ -23,7 +22,7 @@ class _DimSpec:
     This class is used internally in ShardingSpec.
 
     Argument:
-        shard_list(List[int]): if shard_list is None, the dim spec will be 'R' type. 
+        shard_list(List[int]): if shard_list is None, the dim spec will be 'R' type.
             Otherwise, the element in shard_list means the data will be sharded in that dimension.
     '''
 
@@ -62,7 +61,7 @@ class _DimSpec:
 
     def build_difference_2d_dict(self):
         '''
-        Build a difference maping for 2D device mesh case. It will be used to 
+        Build a difference mapping for the 2D device mesh case. It will be used to
         compute the difference between DimSpec pairs.
         '''
 
@@ -159,9 +158,9 @@ class ShardingNotDivisibleError(ShardingSpecException):
 class ShardingSpec:
     '''
     Sharding spec for a tensor, it contains info of the logical device mesh this tensor belong
-    to, the entire shape of the tensor before sharded, and the sharding sequence looks like 
+    to, the entire shape of the tensor before it is sharded, and the sharding sequence, which looks like
     [R, R, S0, S1].
-    
+
     Argument:
         device_mesh(DeviceMesh): A logical view of a physical mesh.
         entire_shape(torch.Size): The entire shape of tensor before sharded.
@@ -260,10 +259,10 @@ class ShardingSpec:
             #     device_mesh_shape: (4, 4)
             sharding_spec_to_compare = ShardingSpec(device_mesh, entire_shape, dim_partition_dict_to_compare)
             print(sharding_spec.sharding_sequence_difference(sharding_spec_to_compare))
-        
+
         Output:
             25
-        
+
         Argument:
             other(ShardingSpec): The ShardingSpec to compared with.
 
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_conv_handler.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_conv_handler.py
index 97025729c..dc86712f6 100644
--- a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_conv_handler.py
+++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_conv_handler.py
@@ -1,27 +1,44 @@
+from functools import partial
+
+import pytest
 import torch
+import torch.multiprocessing as mp
 import torch.nn as nn
 
 from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler, ConvModuleHandler
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.fx import ColoGraphModule, ColoTracer
-from colossalai.testing import parameterize
+from colossalai.initialize import launch
+from colossalai.logging import disable_existing_loggers
+from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
+from colossalai.testing.pytest_wrapper import run_on_environment_flag
+from colossalai.utils import free_port
+from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
 
 
-@parameterize('bias', [True, False])
-def test_conv_module_handler(bias):
-    model = nn.Sequential(nn.Conv2d(4, 16, 3, padding=1, bias=bias).to('meta'))
-    tracer = ColoTracer()
+def check_conv_module_handler(rank, bias, world_size, port):
+    disable_existing_loggers()
+    launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    model = nn.Sequential(nn.Conv2d(4, 16, 3, padding=1, bias=bias)).cuda()
     # graph():
     #     %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
     #     %_0 : [#users=1] = call_module[target=0](args = (%input_1,), kwargs = {})
     #     return _0
+    input = torch.rand(4, 4, 64, 64).cuda()
+
+    physical_mesh_id = torch.arange(0, 4)
+    mesh_shape = (2, 2)
+    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
+
+    # index of conv node in this graph
+    node_index = 1
+    # total number of conv strategies
+    strategy_number = 16
+    numerical_test_for_node_strategy(model, device_mesh, node_index, strategy_number, [input], ['input'])
+    tracer = ColoTracer()
     graph = tracer.trace(model, meta_args={"input": torch.rand(4, 4, 64, 64).to('meta')})
     gm = ColoGraphModule(model, graph)
-    physical_mesh_id = torch.arange(0, 4)
-
-    mesh_shape = (2, 2)
-    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
     conv_mod_node = list(graph.nodes)[1]
     strategies_vector = StrategiesVector(conv_mod_node)
 
@@ -38,26 +55,26 @@ def test_conv_module_handler(bias):
         assert op_data.data is not None
 
     assert mapping['input'].name == "input_1"
-    assert mapping['input'].data.is_meta
+    # assert mapping['input'].data.is_meta
     assert mapping['input'].data.shape == torch.Size([4, 4, 64, 64])
     assert mapping['input'].type == OperationDataType.ARG
     assert mapping['input'].logical_shape == torch.Size([4, 4, 64, 64])
 
     assert mapping['other'].name == "weight"
-    assert mapping['other'].data.is_meta
+    # assert mapping['other'].data.is_meta
     assert mapping['other'].data.shape == torch.Size([16, 4, 3, 3])
     assert mapping['other'].type == OperationDataType.PARAM
     assert mapping['other'].logical_shape == torch.Size([4, 16, 3, 3])
 
     if bias:
         assert mapping['bias'].name == "bias"
-        assert mapping['bias'].data.is_meta
+        # assert mapping['bias'].data.is_meta
         assert mapping['bias'].data.shape == torch.Size([16])
         assert mapping['bias'].type == OperationDataType.PARAM
         assert mapping['bias'].logical_shape == torch.Size([16])
 
     assert mapping['output'].name == "_0"
-    assert mapping['output'].data.is_meta
+    # assert mapping['output'].data.is_meta
     assert mapping['output'].data.shape == torch.Size([4, 16, 64, 64])
     assert mapping['output'].type == OperationDataType.OUTPUT
 
@@ -129,9 +146,28 @@ class ConvModel(nn.Module):
         return x
 
 
-@parameterize('bias', [True, False])
-def test_conv_function_handler(bias):
-    model = ConvModel()
+def check_conv_function_handler(rank, bias, world_size, port):
+    disable_existing_loggers()
+    launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    model = ConvModel().cuda()
+    physical_mesh_id = torch.arange(0, 4)
+    mesh_shape = (2, 2)
+    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
+    input = torch.rand(4, 4, 64, 64).cuda()
+    others = torch.rand(16, 4, 3, 3).cuda()
+    input_args = [input, others]
+    meta_arg_names = ['input', 'others']
+    input_kwargs = {}
+    # total number of conv strategies
+    strategy_number = 16
+    node_index = 2
+    if bias:
+        bias_tensor = torch.rand(16).cuda()
+        input_kwargs['bias'] = bias_tensor
+        node_index += 1
+    numerical_test_for_node_strategy(model, device_mesh, node_index, strategy_number, input_args, meta_arg_names,
+                                     input_kwargs)
+
     tracer = ColoTracer()
     # graph():
     #     %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
@@ -143,10 +179,6 @@ def test_conv_function_handler(bias):
         meta_args['bias'] = torch.rand(16).to('meta')
     graph = tracer.trace(model, meta_args=meta_args)
     gm = ColoGraphModule(model, graph)
-    physical_mesh_id = torch.arange(0, 4)
-
-    mesh_shape = (2, 2)
-    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
 
     if bias:
         conv_mod_node = list(graph.nodes)[3]
@@ -248,6 +280,26 @@ def test_conv_function_handler(bias):
             assert bias_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[1]
 
 
+@run_on_environment_flag(name='AUTO_PARALLEL')
+@pytest.mark.dist
+@parameterize('bias', [True, False])
+@rerun_if_address_is_in_use()
+def test_conv_module_handler(bias):
+    world_size = 4
+    run_func = partial(check_conv_module_handler, bias=bias, world_size=world_size, port=free_port())
+    mp.spawn(run_func, nprocs=world_size)
+
+
+@run_on_environment_flag(name='AUTO_PARALLEL')
+@pytest.mark.dist
+@parameterize('bias', [True, False])
+@rerun_if_address_is_in_use()
+def test_conv_function_handler(bias):
+    world_size = 4
+    run_func = partial(check_conv_function_handler, bias=bias, world_size=world_size, port=free_port())
+    mp.spawn(run_func, nprocs=world_size)
+
+
 if __name__ == '__main__':
     test_conv_module_handler()
     test_conv_function_handler()
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_node_handler/utils.py b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/utils.py
new file mode 100644
index 000000000..47ee6be79
--- /dev/null
+++ b/tests/test_auto_parallel/test_tensor_shard/test_node_handler/utils.py
@@ -0,0 +1,126 @@
+import copy
+from typing import Any, Dict, List
+
+import torch
+from torch.fx import GraphModule
+
+from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass
+from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass
+from colossalai.auto_parallel.tensor_shard.solver import SolverOptions, StrategiesConstructor
+from colossalai.device.device_mesh import DeviceMesh
+from colossalai.fx.tracer.tracer import ColoTracer
+from colossalai.tensor.shape_consistency import to_global
+from colossalai.testing.comparison import assert_close
+
+
+def _build_model_to_compare(model: torch.nn.Module, input_args: List[torch.Tensor],
+                            input_kwargs: Dict[str, torch.Tensor], grad_dict: Dict[Any, torch.Tensor]):
+
+    model_to_compare = copy.deepcopy(model)
+    args_to_compare = []
+    kwargs_to_compare = {}
+    for arg_index, input_tensor in enumerate(input_args):
+
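+        # bind the current index in a closure so each hook records its gradient under the right key in grad_dict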
+        def wrapper(param, index):
+
+            def hook_fn(grad):
+                grad_dict[index] = grad
+
+            param.register_hook(hook_fn)
+
+        arg_to_compare = copy.deepcopy(input_tensor)
+        arg_to_compare.requires_grad = True
+        wrapper(arg_to_compare, arg_index)
+        args_to_compare.append(arg_to_compare)
+
+    for name, input_kwarg in input_kwargs.items():
+
+        def wrapper(param, name):
+
+            def hook_fn(grad):
+                grad_dict[name] = grad
+
+            param.register_hook(hook_fn)
+
+        kwarg_to_compare = copy.deepcopy(input_kwarg)
+        kwarg_to_compare.requires_grad = True
+        wrapper(kwarg_to_compare, name)
+        kwargs_to_compare[name] = kwarg_to_compare
+
+    return model_to_compare, args_to_compare, kwargs_to_compare
+
+
+def numerical_test_for_node_strategy(model: torch.nn.Module,
+                                     device_mesh: DeviceMesh,
+                                     node_index: int,
+                                     strategy_number: int,
+                                     input_args: List[torch.Tensor],
+                                     meta_arg_names: List[str],
+                                     input_kwargs: Dict[str, torch.Tensor] = {}):
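+    """
+    For every strategy of the node at node_index, shard the traced model accordingly, run a
+    forward and backward pass, and compare outputs and gradients against an unsharded copy of the model.
+    """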
+    for strategy_index in range(strategy_number):
+        print(f'#strategy_index: {strategy_index}')
+        # We need to copy the model to avoid doing backward more than once on the same graph
+        grad_to_compare_dict = {}
+        grad_to_shard_dict = {}
+        model_to_compare, args_to_compare, kwargs_to_compare = _build_model_to_compare(
+            model, input_args, input_kwargs, grad_to_compare_dict)
+        model_to_shard, args_to_shard, kwargs_to_shard = _build_model_to_compare(model, input_args, input_kwargs,
+                                                                                 grad_to_shard_dict)
+
+        # 0-dim zero tensor used as a reference for the difference comparisons below
+        zero_tensor = torch.zeros(()).cuda()
+
+        tracer = ColoTracer()
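+        # trace the model symbolically with meta tensors shaped like the real inputs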
+        input_sample = {}
+        for input_arg, meta_arg_name in zip(input_args, meta_arg_names):
+            input_sample[meta_arg_name] = torch.rand(input_arg.shape).to('meta')
+        for meta_kwarg_name, input_kwarg in input_kwargs.items():
+            input_sample[meta_kwarg_name] = torch.rand(input_kwarg.shape).to('meta')
+        graph = tracer.trace(root=model_to_shard, meta_args=input_sample)
+        gm = GraphModule(model_to_shard, graph, model_to_shard.__class__.__name__)
+        solver_options = SolverOptions(fast=True)
+        strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
+        strategies_constructor.build_strategies_and_cost()
+        target_node = list(graph.nodes)[node_index]
+
+        # solution construction
+        solution_len = len(strategies_constructor.leaf_strategies)
+        solution = [0] * solution_len
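+        # every node takes its first strategy except the target node, which sweeps over strategy_index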
+        solution[node_index] = strategy_index
+        gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(
+            gm, solution, device_mesh)
+        gm = runtime_apply_pass(gm)
+        gm.recompile()
+
+        # forward result compare
+        output = gm(*args_to_shard,
+                    sharding_spec_convert_dict=sharding_spec_dict,
+                    origin_node_sharding_spec_dict=origin_spec_dict,
+                    comm_actions_dict=comm_actions_dict,
+                    **kwargs_to_shard)
+        output_to_compare = model_to_compare(*args_to_compare, **kwargs_to_compare)
+        assert_close((output - output_to_compare).sum(), zero_tensor)
+
+        # backward result compare
+        loss = output.sum()
+        loss_to_compare = output_to_compare.sum()
+        loss.backward()
+        loss_to_compare.backward()
+        for key in grad_to_shard_dict.keys():
+            grad_to_shard = grad_to_shard_dict[key]
+            grad_to_compare = grad_to_compare_dict[key]
+            assert_close((grad_to_shard - grad_to_compare).sum(), zero_tensor)
+
+        # extract the strategy used in this iter
+        strategy_in_use = target_node.strategies_vector[strategy_index]
+        param_to_shard_dict = dict(model_to_shard.named_parameters())
+        param_to_compare_dict = dict(model_to_compare.named_parameters())
+        for name in param_to_shard_dict.keys():
+            param_name = name.split('.')[-1]
+            param_sharding_spec = strategy_in_use.get_sharding_spec_by_name(param_name)
+            grad_sharded = param_to_shard_dict[name].grad
+            grad_to_compare = param_to_compare_dict[name].grad
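+            # gather the sharded gradient back to its global shape before comparing with the unsharded reference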
+            global_grad = to_global(grad_sharded, param_sharding_spec)
+            assert_close((global_grad - grad_to_compare).sum(), zero_tensor)