diff --git a/colossalai/auto_parallel/passes/runtime_preparation_pass.py b/colossalai/auto_parallel/passes/runtime_preparation_pass.py
index 92916118b..0b898a43e 100644
--- a/colossalai/auto_parallel/passes/runtime_preparation_pass.py
+++ b/colossalai/auto_parallel/passes/runtime_preparation_pass.py
@@ -230,7 +230,12 @@ def _size_value_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
                 new_slice_items = []
 
                 for slice_item in getitem_index:
+                    if slice_item is None:
+                        new_slice_items.append(None)
+                        continue
+
                     new_start, new_stop, new_step = slice_item.start, slice_item.stop, slice_item.step
+
                     if slice_item.start in node_pairs:
                         new_start = node_pairs[slice_item.start]
                     elif slice_item.stop in node_pairs:
@@ -355,7 +360,10 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
     for node in nodes:
         if node.op == 'call_module':
             target_module = node.graph.owning_module.get_submodule(node.target)
-
+            # TODO: we need to do more actions to take care of the shared parameters.
+            if hasattr(target_module, 'processed') and target_module.processed:
+                continue
+            setattr(target_module, 'processed', True)
             for name, param in target_module.named_parameters():
                 target_sharding_spec = node.best_strategy.get_sharding_spec_by_name(name)
                 # apply the sharding spec of parameters
@@ -404,7 +412,9 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
                 target_module = root
                 target = getattr(root, atoms[0])
             else:
-                target_module = root.get_submodule(atoms[-2])
+                target_module = root
+                for atom in atoms[:-1]:
+                    target_module = getattr(target_module, atom)
                 target = getattr(target_module, atoms[-1])
 
             target_sharding_spec = node.sharding_spec
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_gpt/__init__.py b/tests/test_auto_parallel/test_tensor_shard/test_gpt/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_solver_with_gpt_related_module.py b/tests/test_auto_parallel/test_tensor_shard/test_gpt/gpt_modules.py
similarity index 64%
rename from tests/test_auto_parallel/test_tensor_shard/test_solver_with_gpt_related_module.py
rename to tests/test_auto_parallel/test_tensor_shard/test_gpt/gpt_modules.py
index 82accebdb..b66ad1949 100644
--- a/tests/test_auto_parallel/test_tensor_shard/test_solver_with_gpt_related_module.py
+++ b/tests/test_auto_parallel/test_tensor_shard/test_gpt/gpt_modules.py
@@ -2,32 +2,30 @@ from typing import Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
-import transformers
-from torch.fx import GraphModule
-from transformers.models.gpt2.modeling_gpt2 import (
-    GPT2MLP,
-    BaseModelOutputWithPastAndCrossAttentions,
-    GPT2PreTrainedModel,
-)
+from transformers.activations import ACT2FN
+from transformers.models.gpt2.modeling_gpt2 import BaseModelOutputWithPastAndCrossAttentions, GPT2PreTrainedModel
 from transformers.pytorch_utils import Conv1D
 
-from colossalai.auto_parallel.tensor_shard.constants import BATCHNORM_MODULE_OP
-from colossalai.auto_parallel.tensor_shard.solver import (
-    CostGraph,
-    GraphAnalyser,
-    Solver,
-    SolverOptions,
-    StrategiesConstructor,
-)
-from colossalai.device.device_mesh import DeviceMesh
-from colossalai.fx.tracer.tracer import ColoTracer
-from colossalai.tensor.shape_consistency import ShapeConsistencyManager
-from colossalai.testing import parameterize
-from colossalai.testing.pytest_wrapper import run_on_environment_flag
 
-BATCH_SIZE = 1
-SEQ_LENGTH = 32
-HIDDEN_DIM = 768
+class GPT2MLP(nn.Module):
+
+    def __init__(self, intermediate_size, config):
+        super().__init__()
+        embed_dim = config.hidden_size
+        self.c_fc = Conv1D(intermediate_size, embed_dim)
+        self.c_proj = Conv1D(embed_dim, intermediate_size)
+        self.act = ACT2FN[config.activation_function]
+        # We temporarily banned the Dropout layer because the rng state need
+        # to process to get the correct result.
+        # self.dropout = nn.Dropout(config.resid_pdrop)
+
+    def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor:
+        hidden_states = self.c_fc(hidden_states)
+        hidden_states = self.act(hidden_states)
+        hidden_states = self.c_proj(hidden_states)
+        # TODO: the rng state need to be fixed for distributed runtime
+        # hidden_states = self.dropout(hidden_states)
+        return hidden_states
 
 
 # The reason Why we don't import GPT2Attention from transformers directly is that:
@@ -89,7 +87,7 @@ class GPT2Attention(nn.Module):
 
         # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise
         attn_weights = attn_weights.type(value.dtype)
-        attn_weights = self.attn_dropout(attn_weights)
+        # attn_weights = self.attn_dropout(attn_weights)
 
         # Mask heads if we want to
         if head_mask is not None:
@@ -125,15 +123,10 @@ class GPT2Attention(nn.Module):
         present = (key, value)
 
         attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
-
         attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
         attn_output = self.c_proj(attn_output)
-        attn_output = self.resid_dropout(attn_output)
-
-        outputs = (attn_output, present)
-        outputs += (attn_weights,)
-
-        return outputs    # a, present, (attentions)
+        # attn_output = self.resid_dropout(attn_output)
+        return attn_output
 
 
 class GPT2Block(nn.Module):
@@ -161,19 +154,15 @@ class GPT2Block(nn.Module):
             attention_mask=attention_mask,
             head_mask=head_mask,
         )
-        attn_output = attn_outputs[0]    # output_attn: a, present, (attentions)
-        outputs = attn_outputs[1:]
         # residual connection
-        hidden_states = attn_output + residual
+        hidden_states = attn_outputs + residual
         residual = hidden_states
         hidden_states = self.ln_2(hidden_states)
         feed_forward_hidden_states = self.mlp(hidden_states)
         # residual connection
         hidden_states = residual + feed_forward_hidden_states
 
-        outputs = (hidden_states,) + outputs[1:]
-
-        return outputs    # hidden_states, present, (attentions, cross_attentions)
+        return hidden_states
 
 
 class GPT2Model(GPT2PreTrainedModel):
@@ -228,103 +217,25 @@ class GPT2Model(GPT2PreTrainedModel):
         # attention_probs has shape bsz x n_heads x N x N
         # head_mask has shape n_layer x batch x n_heads x N x N
         head_mask = self.get_head_mask(head_mask, self.config.n_layer)
-
         inputs_embeds = self.wte(input_ids)
         position_embeds = self.wpe(position_ids)
+
         # add_2
         hidden_states = inputs_embeds + position_embeds
 
         token_type_embeds = self.wte(token_type_ids)
         hidden_states = hidden_states + token_type_embeds
 
-        # transformer_drop
-        hidden_states = self.drop(hidden_states)
         # comment to run pipeline
         # add_3
         output_shape = input_shape + (hidden_states.size(-1),)
 
-        presents = None
-        all_self_attentions = None
-        all_cross_attentions = None
-        all_hidden_states = None
         for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
             outputs = block(hidden_states, attention_mask=attention_mask, head_mask=head_mask[i])
-            hidden_states = outputs[0]
+            hidden_states = outputs
 
         hidden_states = self.ln_f(hidden_states)
         # comment to run pipeline
         hidden_states = hidden_states.view(output_shape)
 
-        return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions]
-                     if v is not None)
-
-
-@run_on_environment_flag(name='AUTO_PARALLEL')
-@parameterize('model_cls', [GPT2Block, GPT2Attention, GPT2MLP, GPT2Model])
-def test_self_attention_block(model_cls):
-    config = transformers.GPT2Config(n_position=64, n_layer=4, n_head=16, n_embd=HIDDEN_DIM)
-    if model_cls == GPT2MLP:
-        model = model_cls(intermediate_size=4 * config.hidden_size, config=config)
-    else:
-        model = model_cls(config=config)
-    physical_mesh_id = torch.arange(0, 4)
-    mesh_shape = (2, 2)
-    # [[0, 1]
-    #  [2, 3]]
-    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
-    shape_consistency_manager = ShapeConsistencyManager()
-
-    tracer = ColoTracer()
-    if model_cls == GPT2MLP:
-        input_sample = {
-            'hidden_states': torch.rand(BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM).to('meta'),
-        }
-    elif model_cls in (GPT2Attention, GPT2Block):
-        input_sample = {
-            'hidden_states': torch.rand(BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM).to('meta'),
-            'attention_mask': torch.rand(1, SEQ_LENGTH).to('meta'),
-        }
-    else:
-        input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
-        token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
-        attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
-        kwargs = dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
-        input_sample = {k: v.to('meta') for k, v in kwargs.items()}
-
-    graph = tracer.trace(root=model, meta_args=input_sample)
-
-    gm = GraphModule(model, graph, model.__class__.__name__)
-    print(gm.graph)
-    gm.recompile()
-    graph_analyser = GraphAnalyser(gm)
-    liveness_list = graph_analyser.liveness_analysis()
-    solver_options = SolverOptions()
-    strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
-    strategies_constructor.build_strategies_and_cost()
-
-    cost_graph = CostGraph(strategies_constructor.leaf_strategies)
-    cost_graph.simplify_graph()
-    solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser, memory_budget=-1)
-    ret = solver.call_solver_serialized_args()
-    strategies_list = solver.last_s_val
-    nodes = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies]
-
-    computation_cost = 0
-    communication_cost = 0
-    memory_cost = 0
-    for index, node in enumerate(nodes):
-        print(node.name, node.strategies_vector[strategies_list[index]].name)
-        computation_cost += node.strategies_vector[strategies_list[index]].compute_cost.total
-        communication_cost += node.strategies_vector[strategies_list[index]].communication_cost.total
-        node_memory_cost = node.strategies_vector[strategies_list[index]].memory_cost.total
-        if isinstance(node_memory_cost, tuple):
-            node_memory_cost = node_memory_cost[0]
-        memory_cost += node_memory_cost.activation + node_memory_cost.parameter
-
-    print(f'computation cost is {computation_cost}')
-    print(f'communication cost is {communication_cost}')
-    print(f'memory cost is {memory_cost}')
-
-
-if __name__ == '__main__':
-    test_self_attention_block()
+        return hidden_states
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_gptmlp_runtime.py b/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_runtime_with_gpt_modules.py
similarity index 51%
rename from tests/test_auto_parallel/test_tensor_shard/test_gptmlp_runtime.py
rename to tests/test_auto_parallel/test_tensor_shard/test_gpt/test_runtime_with_gpt_modules.py
index d573c6590..361c22d26 100644
--- a/tests/test_auto_parallel/test_tensor_shard/test_gptmlp_runtime.py
+++ b/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_runtime_with_gpt_modules.py
@@ -1,7 +1,7 @@
 import copy
 import random
 from functools import partial
-from typing import Optional, Tuple, Union
+from typing import Dict, Optional, Tuple, Union
 
 import numpy as np
 import pytest
@@ -10,13 +10,11 @@ import torch.multiprocessing as mp
 import torch.nn as nn
 import transformers
 from torch.fx import GraphModule
-from transformers.activations import ACT2FN
-from transformers.models.gpt2.modeling_gpt2 import GPT2MLP
-from transformers.pytorch_utils import Conv1D
 
 from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass
 from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass
 from colossalai.auto_parallel.tensor_shard.constants import BATCHNORM_MODULE_OP
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import ShardingSpec
 from colossalai.auto_parallel.tensor_shard.solver import (
     CostGraph,
     GraphAnalyser,
@@ -32,6 +30,7 @@ from colossalai.tensor.shape_consistency import ShapeConsistencyManager, to_glob
 from colossalai.testing import assert_close, assert_close_loose, parameterize, rerun_if_address_is_in_use
 from colossalai.testing.pytest_wrapper import run_on_environment_flag
 from colossalai.utils import free_port
+from tests.test_auto_parallel.test_tensor_shard.test_gpt.gpt_modules import GPT2MLP, GPT2Attention, GPT2Block, GPT2Model
 
 BATCH_SIZE = 1
 SEQ_LENGTH = 32
@@ -46,36 +45,73 @@ torch.backends.cudnn.deterministic = True
 torch.backends.cudnn.benchmark = False
 
 
-class GPT2MLP(nn.Module):
+def _check_module_grad(module: torch.nn.Module, origin_param_dict: Dict[str, torch.Tensor],
+                       best_sharding_spec_dict: Dict[str, ShardingSpec]):
+    for name, param in module.named_parameters():
+        param_grad = param.grad
+        origin_param_grad = origin_param_dict[name].grad
+        atoms = name.split('.')
+        new_name = '_'.join(atoms)
+        if new_name in best_sharding_spec_dict:
+            param_sharding_spec = best_sharding_spec_dict[new_name]
+            grad_to_compare = copy.deepcopy(param_grad)
+            param_grad_global = to_global(grad_to_compare, param_sharding_spec)
 
-    def __init__(self, intermediate_size, config):
-        super().__init__()
-        embed_dim = config.hidden_size
-        self.c_fc = Conv1D(intermediate_size, embed_dim)
-        self.c_proj = Conv1D(embed_dim, intermediate_size)
-        self.act = ACT2FN[config.activation_function]
-        # We temporarily banned the Dropout layer because the rng state need
-        # to process to get the correct result.
-        # self.dropout = nn.Dropout(config.resid_pdrop)
-
-    def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor:
-        hidden_states = self.c_fc(hidden_states)
-        hidden_states = self.act(hidden_states)
-        hidden_states = self.c_proj(hidden_states)
-        # TODO: the rng state need to be fixed for distributed runtime
-        # hidden_states = self.dropout(hidden_states)
-        return hidden_states
+            try:
+                assert_close_loose(param_grad_global, origin_param_grad, rtol=1e-03, atol=1e-03)
+            except:
+                difference = param_grad_global - origin_param_grad
+                avg_diff = difference.abs().sum() / difference.numel()
+                assert avg_diff < 0.001
+                print(f'{name} param has {avg_diff} average difference')
 
 
-def check_mlp_layer(rank, model_cls, world_size, port):
+def check_attention_layer(rank, model_cls, world_size, port):
     disable_existing_loggers()
     launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
 
-    config = transformers.GPT2Config(n_position=64, n_layer=4, n_head=16, n_embd=HIDDEN_DIM)
-    model = model_cls(intermediate_size=4 * config.hidden_size, config=config).to('cuda')
-    input = torch.rand(BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM).to('cuda')
+    config = transformers.GPT2Config(n_position=64, n_layer=1, n_head=16, n_embd=HIDDEN_DIM)
+
+    if model_cls == GPT2MLP:
+        model = model_cls(intermediate_size=4 * config.hidden_size, config=config).to('cuda')
+    else:
+        model = model_cls(config=config).to('cuda')
     test_model = copy.deepcopy(model)
-    test_input = copy.deepcopy(input)
+
+    input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+    token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+    attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+    hidden_states = torch.rand((BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM), dtype=torch.float32)
+
+    if model_cls == GPT2MLP:
+        input_sample = (hidden_states.to('cuda'),)
+        test_input_sample = copy.deepcopy(input_sample)
+        meta_input_sample = {
+            'hidden_states': hidden_states.to('meta'),
+        }
+    elif model_cls in (GPT2Attention, GPT2Block):
+        input_sample = (
+            hidden_states.to('cuda'),
+            attention_mask.to('cuda'),
+        )
+        test_input_sample = copy.deepcopy(input_sample)
+        meta_input_sample = {
+            'hidden_states': hidden_states.to('meta'),
+            'attention_mask': attention_mask.to('meta'),
+        }
+    else:
+        input_sample = (
+            input_ids.to('cuda'),
+            token_type_ids.to('cuda'),
+            attention_mask.to('cuda'),
+        )
+        test_input_sample = copy.deepcopy(input_sample)
+        meta_input_sample = {
+            'input_ids': input_ids.to('meta'),
+            'token_type_ids': token_type_ids.to('meta'),
+            'attention_mask': attention_mask.to('meta'),
+        }
+
     physical_mesh_id = torch.arange(0, 4)
     mesh_shape = (2, 2)
     # [[0, 1]
@@ -85,15 +121,10 @@ def check_mlp_layer(rank, model_cls, world_size, port):
 
     tracer = ColoTracer()
 
-    input_sample = {
-        'hidden_states': torch.rand(BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM).to('meta'),
-    }
-
-    graph = tracer.trace(root=model, meta_args=input_sample)
-    print(graph)
+    graph = tracer.trace(root=model, meta_args=meta_input_sample)
     gm = GraphModule(model, graph, model.__class__.__name__)
     gm.recompile()
-    print(gm)
+
     graph_analyser = GraphAnalyser(gm)
     liveness_list = graph_analyser.liveness_analysis()
     solver_options = SolverOptions()
@@ -110,71 +141,35 @@ def check_mlp_layer(rank, model_cls, world_size, port):
         gm, solution, device_mesh, strategies_constructor)
     gm = runtime_apply_pass(gm)
     gm.recompile()
+    nodes = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies]
+    best_sharding_spec_dict = {}
+    for index, node in enumerate(nodes):
+        best_sharding_spec_dict[node.name] = node.sharding_spec
+
     cuda_rng_state = torch.cuda.get_rng_state()
     cpu_rng_state = torch.get_rng_state()
-    origin_output = test_model(test_input)
+    origin_output = test_model(*test_input_sample)
     torch.cuda.set_rng_state(cuda_rng_state)
     torch.set_rng_state(cpu_rng_state)
-    output = gm(input, sharding_spec_dict, origin_spec_dict, comm_actions_dict)
-    assert_close(output, origin_output, rtol=1e-03, atol=1e-04)
+    output = gm(*input_sample, sharding_spec_dict, origin_spec_dict, comm_actions_dict)
+    assert_close(output, origin_output, rtol=1e-03, atol=1e-03)
 
     #*******************backward starting*******************
     cuda_rng_state = torch.cuda.get_rng_state()
+    cpu_rng_state = torch.get_rng_state()
     output.sum().backward()
+    torch.set_rng_state(cpu_rng_state)
     torch.cuda.set_rng_state(cuda_rng_state)
     origin_output.sum().backward()
     origin_param_dict = dict(test_model.named_parameters())
+
     if rank == 0:
         print("*******************backward starting*******************")
-        for name, param in model.named_parameters():
-            param_grad = param.grad
-            origin_param_grad = origin_param_dict[name].grad
-            origin_param_size = origin_param_grad.shape[-1]
-            print(name, param_grad, origin_param_grad)
-            if name == 'c_fc.bias':
-                assert_close_loose(param_grad,
-                                   origin_param_grad.narrow(0, 0, origin_param_size // 2),
-                                   rtol=1e-03,
-                                   atol=1e-03)
-            else:
-                assert_close_loose(param_grad, origin_param_grad, rtol=1e-03, atol=1e-03)
+
+    _check_module_grad(gm, origin_param_dict, best_sharding_spec_dict)
+
+    if rank == 0:
         print("*******************backward finished*******************")
-    if rank == 1:
-        for name, param in model.named_parameters():
-            param_grad = param.grad
-            origin_param_grad = origin_param_dict[name].grad
-            origin_param_size = origin_param_grad.shape[-1]
-            if name == 'c_fc.bias':
-                assert_close_loose(param_grad,
-                                   origin_param_grad.narrow(0, origin_param_size // 2, origin_param_size // 2),
-                                   rtol=1e-03,
-                                   atol=1e-03)
-            else:
-                assert_close_loose(param_grad, origin_param_grad, rtol=1e-03, atol=1e-03)
-    if rank == 2:
-        for name, param in model.named_parameters():
-            param_grad = param.grad
-            origin_param_grad = origin_param_dict[name].grad
-            origin_param_size = origin_param_grad.shape[-1]
-            if name == 'c_fc.bias':
-                assert_close_loose(param_grad,
-                                   origin_param_grad.narrow(0, 0, origin_param_size // 2),
-                                   rtol=1e-03,
-                                   atol=1e-03)
-            else:
-                assert_close_loose(param_grad, origin_param_grad, rtol=1e-03, atol=1e-03)
-    if rank == 3:
-        for name, param in model.named_parameters():
-            param_grad = param.grad
-            origin_param_grad = origin_param_dict[name].grad
-            origin_param_size = origin_param_grad.shape[-1]
-            if name == 'c_fc.bias':
-                assert_close_loose(param_grad,
-                                   origin_param_grad.narrow(0, origin_param_size // 2, origin_param_size // 2),
-                                   rtol=1e-03,
-                                   atol=1e-03)
-            else:
-                assert_close_loose(param_grad, origin_param_grad, rtol=1e-03, atol=1e-03)
 
     #*******************backward finished*******************
 
@@ -202,11 +197,11 @@ def check_mlp_layer(rank, model_cls, world_size, port):
 
 @run_on_environment_flag(name='AUTO_PARALLEL')
 @pytest.mark.dist
-@parameterize('model_cls', [GPT2MLP])
+@parameterize('model_cls', [GPT2MLP, GPT2Block, GPT2Attention, GPT2Model])
 @rerun_if_address_is_in_use()
 def test_mlp_layer(model_cls):
     world_size = 4
-    run_func = partial(check_mlp_layer, model_cls=model_cls, world_size=world_size, port=free_port())
+    run_func = partial(check_attention_layer, model_cls=model_cls, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)
 
 
diff --git a/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_solver_with_gpt_module.py b/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_solver_with_gpt_module.py
new file mode 100644
index 000000000..478b77e76
--- /dev/null
+++ b/tests/test_auto_parallel/test_tensor_shard/test_gpt/test_solver_with_gpt_module.py
@@ -0,0 +1,94 @@
+import torch
+import torch.nn as nn
+import transformers
+from torch.fx import GraphModule
+
+from colossalai.auto_parallel.tensor_shard.constants import BATCHNORM_MODULE_OP
+from colossalai.auto_parallel.tensor_shard.solver import (
+    CostGraph,
+    GraphAnalyser,
+    Solver,
+    SolverOptions,
+    StrategiesConstructor,
+)
+from colossalai.device.device_mesh import DeviceMesh
+from colossalai.fx.tracer.tracer import ColoTracer
+from colossalai.tensor.shape_consistency import ShapeConsistencyManager
+from colossalai.testing import parameterize
+from colossalai.testing.pytest_wrapper import run_on_environment_flag
+from tests.test_auto_parallel.test_tensor_shard.test_gpt.gpt_modules import GPT2MLP, GPT2Attention, GPT2Block, GPT2Model
+
+BATCH_SIZE = 1
+SEQ_LENGTH = 32
+HIDDEN_DIM = 768
+
+
+@run_on_environment_flag(name='AUTO_PARALLEL')
+@parameterize('model_cls', [GPT2Block, GPT2Attention, GPT2MLP, GPT2Model])
+def test_self_attention_block(model_cls):
+    config = transformers.GPT2Config(n_position=64, n_layer=4, n_head=16, n_embd=HIDDEN_DIM)
+    if model_cls == GPT2MLP:
+        model = model_cls(intermediate_size=4 * config.hidden_size, config=config)
+    else:
+        model = model_cls(config=config)
+    physical_mesh_id = torch.arange(0, 4)
+    mesh_shape = (2, 2)
+    # [[0, 1]
+    #  [2, 3]]
+    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
+    shape_consistency_manager = ShapeConsistencyManager()
+
+    tracer = ColoTracer()
+    if model_cls == GPT2MLP:
+        input_sample = {
+            'hidden_states': torch.rand(BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM).to('meta'),
+        }
+    elif model_cls in (GPT2Attention, GPT2Block):
+        input_sample = {
+            'hidden_states': torch.rand(BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM).to('meta'),
+            'attention_mask': torch.rand(1, SEQ_LENGTH).to('meta'),
+        }
+    else:
+        input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+        token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+        attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+        kwargs = dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
+        input_sample = {k: v.to('meta') for k, v in kwargs.items()}
+
+    graph = tracer.trace(root=model, meta_args=input_sample)
+
+    gm = GraphModule(model, graph, model.__class__.__name__)
+    print(gm.graph)
+    gm.recompile()
+    graph_analyser = GraphAnalyser(gm)
+    liveness_list = graph_analyser.liveness_analysis()
+    solver_options = SolverOptions()
+    strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
+    strategies_constructor.build_strategies_and_cost()
+
+    cost_graph = CostGraph(strategies_constructor.leaf_strategies)
+    cost_graph.simplify_graph()
+    solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser, memory_budget=-1)
+    ret = solver.call_solver_serialized_args()
+    strategies_list = solver.last_s_val
+    nodes = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies]
+
+    computation_cost = 0
+    communication_cost = 0
+    memory_cost = 0
+    for index, node in enumerate(nodes):
+        print(node.name, node.strategies_vector[strategies_list[index]].name)
+        computation_cost += node.strategies_vector[strategies_list[index]].compute_cost.total
+        communication_cost += node.strategies_vector[strategies_list[index]].communication_cost.total
+        node_memory_cost = node.strategies_vector[strategies_list[index]].memory_cost.total
+        if isinstance(node_memory_cost, tuple):
+            node_memory_cost = node_memory_cost[0]
+        memory_cost += node_memory_cost.activation + node_memory_cost.parameter
+
+    print(f'computation cost is {computation_cost}')
+    print(f'communication cost is {communication_cost}')
+    print(f'memory cost is {memory_cost}')
+
+
+if __name__ == '__main__':
+    test_self_attention_block()