diff --git a/colossalai/auto_parallel/solver/cost_graph.py b/colossalai/auto_parallel/solver/cost_graph.py
new file mode 100644
index 000000000..a67ac1c3f
--- /dev/null
+++ b/colossalai/auto_parallel/solver/cost_graph.py
@@ -0,0 +1,131 @@
+from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
+from typing import List
+from torch.fx.node import Node
+
+
+class CostGraph:
+    '''
+    A graph data structure to simplify the edge cost graph. It has two main functions:
+    1. To feed the quadratic resharding costs into the solver, we need to linearize them. We build
+    edge_cost in CostGraph, which stores the cost of every strategy combination for a src-dst node
+    pair in a flattened form.
+    2. To reduce the search space, we merge computationally-trivial operators, such as
+    element-wise operators, transpose, and reduction, into their following nodes. The merging
+    information is given by the StrategiesVector depending on the type of the target node and its
+    following nodes.
+
+    Argument:
+        leaf_strategies(List[StrategiesVector]): It stores the StrategiesVector of every node on the graph.
+        simplify(bool, optional): The generated cost graph will be simplified if it is true. (default to True)
+    '''
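+
+    # Illustrative layout (assumed sizes, for exposition only): if src_node has 3
+    # strategies and dst_node has 2, edge_costs[(src_node, dst_node)] holds an
+    # entry for each of the 6 (dst_strategy_index, src_strategy_index) pairs,
+    # flattening the 2D resharding-cost matrix for the linear solver.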
+
+    def __init__(self, leaf_strategies, simplify=True):
+        self.leaf_strategies = leaf_strategies
+        # stores number of strategies in each node
+        self.node_lens = {strategies_vector.node: len(strategies_vector) for strategies_vector in self.leaf_strategies}
+        # extra_node_costs will store the extra costs introduced by merging nodes
+        self.extra_node_costs = {}
+        self.simplify = simplify
+        self._build_cost_graph()
+
+    def _build_cost_graph(self):
+        '''
+        This method generates edge_cost for each adjacent node pair. Additionally, the 'parents' and
+        'children' attributes are set on each node.
+        '''
+        self.edge_costs = {}
+        if self.simplify:
+            self.merge_pair = []
+        for strategies_vector in self.leaf_strategies:
+            # build edge_cost
+            dst_node = strategies_vector.node
+            for src_node in strategies_vector.predecessor_nodes:
+                node_pair = (src_node, dst_node)
+                src_index = strategies_vector.predecessor_nodes.index(src_node)
+                edge_cost = {}
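+                # keyed by (dst_strategy_index, src_strategy_index), covering every
+                # strategy combination for this src-dst pair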
+                for i in range(len(strategies_vector)):
+                    for j in range(len(src_node.strategies_vector)):
+                        edge_cost[(i, j)] = strategies_vector[i].resharding_costs[src_index][j]
+                self.edge_costs[node_pair] = edge_cost
+            # add parents and children attribute to node
+            setattr(dst_node, 'parents', strategies_vector.predecessor_nodes)
+            setattr(dst_node, 'children', strategies_vector.successor_nodes)
+
+            if self.simplify and strategies_vector.check_merge():
+                for following_node in strategies_vector.successor_nodes:
+                    self.merge_pair.append((dst_node, following_node))
+
+    def get_edge_cost(self, src_node, dst_node):
+        return self.edge_costs[(src_node, dst_node)]
+
+    def merge_node(self, src_node, dst_node):
+        '''
+        To merge src_node into dst_node, we proceed in the following steps:
+
+        1. For each strategy in dst_node, we need to pick an appropriate strategy
+        of src_node to merge with. This matters because the resharding costs
+        between the parent nodes of src_node and the merged node depend on which
+        src_node strategy is dispatched. For example, for the graph 0->1->2, after
+        merging node 1 into node 2, edge_costs[(node 0, node 2)][(0, 0)] =
+        edge_costs[(node 0, node 1)][(x, 0)], where x is the strategy of node 1
+        picked for strategy 0 of node 2.
+
+        2. We need to accumulate the extra costs introduced by merging nodes. The
+        extra costs consist of two parts: the resharding cost between the picked
+        src_node strategy and the dst_node strategy, and the extra costs already
+        accumulated in the src_node strategy.
+
+        3. Build connections between the new node pairs, and remove src_node after
+        all consumer nodes have been detached from it.
+
+        Argument:
+            src_node(Node): The node to be merged into dst_node.
+            dst_node(Node): The node into which src_node is integrated.
+        '''
+        src_node_index = dst_node.parents.index(src_node)
+        # build merge_map
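+        # merge_map[dst_strategy_index] -> index of the src_node strategy that is
+        # cheapest to reshard into that dst strategy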
+        merge_map = {}
+        for dst_strate_index, strategy in enumerate(dst_node.strategies_vector):
+            resharding_costs = strategy.resharding_costs
+            resharding_cost_for_src = resharding_costs[src_node_index]
+            lowest_cost_index = resharding_cost_for_src.index(min(resharding_cost_for_src))
+            merge_map[dst_strate_index] = lowest_cost_index
+
+        # extra_node_costs for dst node
+        self.extra_node_costs[dst_node] = [0.0 for _ in range(self.node_lens[dst_node])]
+        for dst_strate_index, strategy in enumerate(dst_node.strategies_vector):
+            target_strate_index = merge_map[dst_strate_index]
+            self.extra_node_costs[dst_node][dst_strate_index] += strategy.resharding_costs[src_node_index][
+                target_strate_index]
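+            # if src_node was itself produced by an earlier merge, carry its
+            # accumulated extra costs forward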
+            if src_node in self.extra_node_costs:
+                self.extra_node_costs[dst_node][dst_strate_index] += self.extra_node_costs[src_node][
+                    target_strate_index]
+
+        # connect dst node and parents of src node
+        dst_node.parents.remove(src_node)
+        src_node.children.remove(dst_node)
+        self.edge_costs.pop((src_node, dst_node))
+        for parent_node in src_node.parents:
+            if parent_node not in dst_node.parents:
+                dst_node.parents.append(parent_node)
+            if dst_node not in parent_node.children:
+                parent_node.children.append(dst_node)
+            # remove src node from cost graph when src node has no consumer.
+            if len(src_node.children) == 0:
+                parent_node.children.remove(src_node)
+                node_pair = (parent_node, src_node)
+                self.edge_costs.pop(node_pair)
+
+        # add new node pair to cost graph
+        for parent_node in src_node.parents:
+            new_node_pair = (parent_node, dst_node)
+            old_node_pair = (parent_node, src_node)
+            if new_node_pair in self.edge_costs:
+                continue
+            edge_cost = {}
+            for i in range(self.node_lens[dst_node]):
+                for j in range(self.node_lens[parent_node]):
+                    src_strate_index = merge_map[i]
+                    edge_cost[(i, j)] = self.edge_costs[old_node_pair][(src_strate_index, j)]
+            self.edge_costs[new_node_pair] = edge_cost
+
+    def simplify_graph(self):
+        if not self.simplify:
+            return
+        for (src_node, dst_node) in self.merge_pair:
+            self.merge_node(src_node, dst_node)
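+
+
+# A minimal usage sketch (assumed driver code; `leaf_strategies` would come from a
+# strategies constructor that builds one StrategiesVector per traced graph node):
+#
+#     cost_graph = CostGraph(leaf_strategies, simplify=True)
+#     cost_graph.simplify_graph()
+#     edge_cost = cost_graph.get_edge_cost(src_node, dst_node)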
diff --git a/colossalai/auto_parallel/solver/operator_handler.py b/colossalai/auto_parallel/solver/operator_handler.py
index 63d17e6cb..1cacc9324 100644
--- a/colossalai/auto_parallel/solver/operator_handler.py
+++ b/colossalai/auto_parallel/solver/operator_handler.py
@@ -84,6 +84,7 @@ class OperatorHandler(ABC):
         for input_node, input_spec in zip(self.predecessor_node, sharding_spec_for_input):
             resharding_costs[input_node] = []
             for strategy in input_node.strategies_vector:
-                _, _, resharding_cost = self.shape_consistency_manager.shape_consistency(strategy, input_spec)
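+                # shape_consistency expects two ShardingSpec objects, so pass the
+                # strategy's output sharding spec rather than the ShardingStrategy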
+                _, _, resharding_cost = self.shape_consistency_manager.shape_consistency(
+                    strategy.output_sharding_spec, input_spec)
                 resharding_costs[input_node].append(resharding_cost)
-        return resharding_cost
+        return resharding_costs
diff --git a/colossalai/auto_parallel/solver/sharding_strategy.py b/colossalai/auto_parallel/solver/sharding_strategy.py
index b6eb2e220..870d7e8dd 100644
--- a/colossalai/auto_parallel/solver/sharding_strategy.py
+++ b/colossalai/auto_parallel/solver/sharding_strategy.py
@@ -47,7 +47,7 @@ class StrategiesVector(list):
         self.node = node
         # fetch its input and output nodes
         self.predecessor_nodes = list(node._input_nodes.keys())
-        self.successor_ndoes = list(node.users.keys())
+        self.successor_nodes = list(node.users.keys())
 
     def check_merge(self):
         pass
diff --git a/colossalai/fx/tracer/meta_patch/patched_function/arithmetic.py b/colossalai/fx/tracer/meta_patch/patched_function/arithmetic.py
index f15621477..00c434e03 100644
--- a/colossalai/fx/tracer/meta_patch/patched_function/arithmetic.py
+++ b/colossalai/fx/tracer/meta_patch/patched_function/arithmetic.py
@@ -15,7 +15,7 @@ def torch_matmul(input, other, *, out=None):
         shape = (input.size(0), other.size(1))
     elif d1 == 1 and d2 == 2:
         shape = (other.size(1),)
-    elif d1 == 2 and d1 == 1:
+    elif d1 == 2 and d2 == 1:
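+        # 2D @ 1D is a matrix-vector product; the contracted dim is dropped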
         shape = (input.size(0),)
     else:
         max_length = max(input.dim(), other.dim())
diff --git a/tests/test_auto_parallel/test_conv_handler.py b/tests/test_auto_parallel/test_conv_handler.py
index 13ec9a16f..3cda3bd80 100644
--- a/tests/test_auto_parallel/test_conv_handler.py
+++ b/tests/test_auto_parallel/test_conv_handler.py
@@ -70,7 +70,9 @@ def test_conv_handler():
             sharding_spec = ShardingSpec(device_mesh=device_mesh,
                                          entire_shape=entire_shape,
                                          sharding_sequence=sharding_sequence)
-            strategies_vector_for_input.append(sharding_spec)
+            strategy_name = str(sharding_spec.sharding_sequence)
+            sharding_strategy = ShardingStrategy(name=strategy_name, output_sharding_spec=sharding_spec)
+            strategies_vector_for_input.append(sharding_strategy)
     setattr(nodes[1], 'strategies_vector', strategies_vector_for_input)
 
     # generate conv strategy
diff --git a/tests/test_auto_parallel/test_dot_handler.py b/tests/test_auto_parallel/test_dot_handler.py
index f85546b15..4cc41178d 100644
--- a/tests/test_auto_parallel/test_dot_handler.py
+++ b/tests/test_auto_parallel/test_dot_handler.py
@@ -69,7 +69,9 @@ def test_dot_handler():
             sharding_spec = ShardingSpec(device_mesh=device_mesh,
                                          entire_shape=entire_shape,
                                          sharding_sequence=sharding_sequence)
-            strategies_vector_for_input.append(sharding_spec)
+            strategy_name = str(sharding_spec.sharding_sequence)
+            sharding_strategy = ShardingStrategy(name=strategy_name, output_sharding_spec=sharding_spec)
+            strategies_vector_for_input.append(sharding_strategy)
     setattr(nodes[1], 'strategies_vector', strategies_vector_for_input)
 
     # generate dot strategy