diff --git a/colossalai/autochunk/autochunk_codegen.py b/colossalai/autochunk/autochunk_codegen.py
index ceccb9a9f..de5e7356b 100644
--- a/colossalai/autochunk/autochunk_codegen.py
+++ b/colossalai/autochunk/autochunk_codegen.py
@@ -123,12 +123,13 @@ def _replace_name(context: str, name_from: str, name_to: str) -> str:
     """
     replace node name
     """
-    patterns = [(" ", " "), (" ", "."), (" ", ","), ("(", ")"), ("(", ","), (" ", ")")]
+    patterns = [(" ", " "), (" ", "."), (" ", ","), ("(", ")"), ("(", ","), (" ", ")"), (" ", ""), ("", " ")]
     for p in patterns:
         source = p[0] + name_from + p[1]
         target = p[0] + name_to + p[1]
         if source in context:
             context = context.replace(source, target)
+            break
     return context
 
 
@@ -138,8 +139,7 @@ def _replace_reshape_size(context: str, node_name: str, reshape_size_dict: Dict)
     """
     if node_name not in reshape_size_dict:
         return context
-    for size_name, size_value in reshape_size_dict[node_name].items():
-        context = context.replace(size_name, size_value)
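+    # reshape_size_dict[node_name] is an [origin_shape_str, new_shape_str] pair built in _reassgin_reshape_size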
+    context = context.replace(reshape_size_dict[node_name][0], reshape_size_dict[node_name][1])
     return context
 
 
diff --git a/colossalai/autochunk/estimate_memory.py b/colossalai/autochunk/estimate_memory.py
index e001423f1..d38625385 100644
--- a/colossalai/autochunk/estimate_memory.py
+++ b/colossalai/autochunk/estimate_memory.py
@@ -37,10 +37,10 @@ class EstimateMemory(object):
 
     def _add_active_node(self, n, active_list):
         new_active = self._get_output_node(n)[1]
-        if n.op == "placeholder":
+        if n.op == "placeholder" and get_node_shape(n) is not None:
             new_active.append(n.name)
         for i in new_active:
-            if i not in active_list:
+            if i not in active_list and get_node_shape(n) is not None:
                 active_list.append(i)
 
     def _get_delete_node(self, user, user_to_last_uses, to_keep=None):
@@ -77,15 +77,11 @@ class EstimateMemory(object):
             if i in active_list:
                 active_list.remove(i)
 
-    def _get_chunk_inputs_size(
-        self, chunk_inputs, chunk_inputs_non_chunk, node_list, chunk_end_idx
-    ):
+    def _get_chunk_inputs_size(self, chunk_inputs, chunk_inputs_non_chunk, node_list, chunk_end_idx):
         nodes_to_delete = []
         for chunk_input in chunk_inputs + chunk_inputs_non_chunk:
             chunk_input_users = chunk_input.users.keys()
-            chunk_input_users_idx = [
-                find_idx_by_name(i.name, node_list) for i in chunk_input_users
-            ]
+            chunk_input_users_idx = [find_idx_by_name(i.name, node_list) for i in chunk_input_users]
             if all(i <= chunk_end_idx for i in chunk_input_users_idx):
                 if chunk_input not in nodes_to_delete:
                     nodes_to_delete.append(chunk_input)
@@ -112,9 +108,7 @@ class EstimateMemory(object):
         not_contiguous_ops = ["permute"]
         inherit_contiguous_ops = ["transpose", "view"]
 
-        if node.op == "call_function" and any(
-            n in node.name for n in ["matmul", "reshape"]
-        ):
+        if node.op == "call_function" and any(n in node.name for n in ["matmul", "reshape"]):
             for n in node.args:
                 if n in not_contiguous_list:
                     # matmul won't change origin tensor, but create a tmp copy
@@ -125,9 +119,7 @@ class EstimateMemory(object):
                     # module will just make origin tensor to contiguous
                     if delete:
                         not_contiguous_list.remove(n)
-        elif node.op == "call_method" and any(
-            i in node.name for i in not_contiguous_ops
-        ):
+        elif node.op == "call_method" and any(i in node.name for i in not_contiguous_ops):
             if node not in not_contiguous_list:
                 not_contiguous_list.append(node)
         return mem
@@ -142,9 +134,7 @@ class EstimateMemory(object):
         else:
             return float(chunk_size) / node_shape[chunk_dim]
 
-    def _get_chunk_delete_node_size(
-        self, user, user_to_last_uses, chunk_ratio, chunk_inputs_names
-    ):
+    def _get_chunk_delete_node_size(self, user, user_to_last_uses, chunk_ratio, chunk_inputs_names):
         # if any(j in user.name for j in ['transpose', 'permute', 'view']):
         #     return 0
         if user.op in ("placeholder", "output"):
@@ -196,7 +186,7 @@ class EstimateMemory(object):
         Returns:
             act_memory_peak_log (List): peak memory of every node
             act_memory_after_node_log (List): memory after excuting every node
-            active_node_list_log (List): active nodes of every node. active nodes refer to 
+            active_node_list_log (List): active nodes of every node. active nodes refer to
                 nodes generated but not deleted.
         """
         act_memory = 0.0
@@ -212,7 +202,7 @@ class EstimateMemory(object):
         use_chunk = True if chunk_infos is not None else False
         chunk_within = False
         chunk_region_idx = None
-        chunk_ratio = 1  # use it to estimate chunk mem
+        chunk_ratio = 1    # used to estimate chunk memory
         chunk_inputs_names = []
 
         if use_chunk:
@@ -221,23 +211,18 @@ class EstimateMemory(object):
             chunk_ends = [i[1] for i in chunk_regions]
             chunk_inputs = [i["inputs"] for i in chunk_infos]
             chunk_inputs_non_chunk = [i["inputs_non_chunk"] for i in chunk_infos]
-            chunk_inputs_names = [j.name for i in chunk_inputs for j in i] + [
-                j.name for i in chunk_inputs_non_chunk for j in i
-            ]
+            chunk_inputs_names = [j.name for i in chunk_inputs for j in i
+                                 ] + [j.name for i in chunk_inputs_non_chunk for j in i]
             chunk_outputs = [i["outputs"][0] for i in chunk_infos]
             chunk_node_dim = [i["node_chunk_dim"] for i in chunk_infos]
-            chunk_sizes = [
-                i["chunk_size"] if "chunk_size" in i else 1 for i in chunk_infos
-            ]
+            chunk_sizes = [i["chunk_size"] if "chunk_size" in i else 1 for i in chunk_infos]
 
         for idx, node in enumerate(node_list):
             # if node in chunk start nodes, change chunk ratio and add chunk_tensor
             if use_chunk and idx in chunk_starts:
                 chunk_within = True
                 chunk_region_idx = chunk_starts.index(idx)
-                act_memory += self._get_output_node_size(
-                    chunk_outputs[chunk_region_idx]
-                ) / (1024**2)
+                act_memory += self._get_output_node_size(chunk_outputs[chunk_region_idx]) / (1024**2)
 
             # determine chunk ratio for current node
             if chunk_within:
@@ -262,22 +247,13 @@ class EstimateMemory(object):
             else:
                 # forward memory
                 # TODO: contiguous_memory still not accurate for matmul, view, reshape and transpose
-                act_memory += (
-                    self._get_contiguous_memory(node, not_contiguous_list)
-                    * chunk_ratio
-                    / (1024**2)
-                )
-                act_memory += (
-                    self._get_output_node_size(node) * chunk_ratio / (1024**2)
-                )
+                act_memory += (self._get_contiguous_memory(node, not_contiguous_list) * chunk_ratio / (1024**2))
+                act_memory += (self._get_output_node_size(node) * chunk_ratio / (1024**2))
                 # record max act memory
                 act_memory_peak_log.append(act_memory)
                 # delete useless memory
-                act_memory -= (
-                    self._get_contiguous_memory(node, not_contiguous_list, delete=True)
-                    * chunk_ratio
-                    / (1024**2)
-                )
+                act_memory -= (self._get_contiguous_memory(node, not_contiguous_list, delete=True) * chunk_ratio /
+                               (1024**2))
                 # delete unused vars not in chunk_input_list
                 # we can't delete input nodes until chunk ends
                 if chunk_within:
@@ -288,9 +264,8 @@ class EstimateMemory(object):
                         chunk_inputs_names,
                     ) / (1024**2)
                 else:
-                    act_memory -= self._get_delete_node_size(
-                        node, user_to_last_uses_no_free_var, chunk_inputs_names
-                    ) / (1024**2)
+                    act_memory -= self._get_delete_node_size(node, user_to_last_uses_no_free_var,
+                                                             chunk_inputs_names) / (1024**2)
 
             # log active node, only effective without chunk
             self._add_active_node(node, active_node_list)
@@ -298,9 +273,7 @@ class EstimateMemory(object):
 
             # if node in chunk end nodes, restore chunk settings
             if use_chunk and idx in chunk_ends:
-                act_memory -= (
-                    self._get_output_node_size(node) * chunk_ratio / (1024**2)
-                )
+                act_memory -= (self._get_output_node_size(node) * chunk_ratio / (1024**2))
                 act_memory -= self._get_chunk_inputs_size(
                     chunk_inputs[chunk_region_idx],
                     chunk_inputs_non_chunk[chunk_region_idx],
diff --git a/colossalai/autochunk/search_chunk.py b/colossalai/autochunk/search_chunk.py
index c9e5e5172..236f9697d 100644
--- a/colossalai/autochunk/search_chunk.py
+++ b/colossalai/autochunk/search_chunk.py
@@ -8,11 +8,7 @@ from .reorder_graph import ReorderGraph
 from .select_chunk import SelectChunk
 from .trace_flow import TraceFlow
 from .trace_indice import TraceIndice
-from .utils import (
-    get_node_shape,
-    is_non_compute_node,
-    is_non_compute_node_except_placeholder,
-)
+from .utils import get_node_shape, is_non_compute_node, is_non_compute_node_except_placeholder
 
 
 class SearchChunk(object):
@@ -73,13 +69,11 @@ class SearchChunk(object):
         """
         free_var_idx = []
         for idx, n in enumerate(self.trace_indice.node_list):
-            if n.op == "placeholder":
+            if n.op == "placeholder" and get_node_shape(n) is not None:
                 free_var_idx.append(idx)
         return free_var_idx
 
-    def _search_max_chunk_region(
-        self, active_node: List, peak_node: Node, chunk_regions: List
-    ) -> Tuple:
+    def _search_max_chunk_region(self, active_node: List, peak_node: Node, chunk_regions: List) -> Tuple:
         """
         Search max chunk region according to peak memory node
 
@@ -124,15 +118,9 @@ class SearchChunk(object):
             region = i["region"]
             if chunk_region_start >= region[0] and chunk_region_end <= region[1]:
                 return None
-            elif (
-                region[0] <= chunk_region_start <= region[1]
-                and chunk_region_end > region[1]
-            ):
+            elif (region[0] <= chunk_region_start <= region[1] and chunk_region_end > region[1]):
                 chunk_region_start = region[1] + 1
-            elif (
-                region[0] <= chunk_region_end <= region[1]
-                and chunk_region_start < region[0]
-            ):
+            elif (region[0] <= chunk_region_end <= region[1] and chunk_region_start < region[0]):
                 chunk_region_end = region[0] - 1
         return chunk_region_start, chunk_region_end
 
@@ -164,25 +152,16 @@ class SearchChunk(object):
             for start_node, start_trace in start_traces.items():
                 for start_dim, _ in enumerate(start_trace["indice"]):
                     # dim size cannot be 1
-                    if (
-                        get_node_shape(end_node)[end_dim] == 1
-                        or get_node_shape(start_node)[start_dim] == 1
-                    ):
+                    if (get_node_shape(end_node)[end_dim] == 1 or get_node_shape(start_node)[start_dim] == 1):
                         continue
                     # check index source align
-                    if not self.trace_flow.check_index_source(
-                        start_dim, start_node, start_idx, end_dim, end_node
-                    ):
+                    if not self.trace_flow.check_index_source(start_dim, start_node, start_idx, end_dim, end_node):
                         continue
                     # check index copmute
-                    if not self.trace_flow.check_index_compute(
-                        start_idx, end_dim, end_node, end_idx
-                    ):
+                    if not self.trace_flow.check_index_compute(start_idx, end_dim, end_node, end_idx):
                         continue
                     # flow search
-                    chunk_info = self.trace_flow.flow_search(
-                        start_idx, start_dim, end_idx, end_dim
-                    )
+                    chunk_info = self.trace_flow.flow_search(start_idx, start_dim, end_idx, end_dim)
                     if chunk_info is None:
                         continue
                     # check index copmute
@@ -191,9 +170,7 @@ class SearchChunk(object):
                     chunk_infos.append(chunk_info)
         return chunk_infos
 
-    def _search_possible_chunk_regions(
-        self, max_chunk_region: Tuple, peak_node: Node
-    ) -> List:
+    def _search_possible_chunk_regions(self, max_chunk_region: Tuple, peak_node: Node) -> List:
         """
         Search every possible region within the max chunk region.
 
@@ -206,28 +183,23 @@ class SearchChunk(object):
         """
         possible_chunk_region = []
         output_trace = copy.deepcopy(self.trace_indice.indice_trace_list)
-        input_trace = []  # trace of a node's input nodes
+        input_trace = []    # trace of a node's input nodes
         for _, n in enumerate(self.trace_indice.node_list):
             cur_trace = {}
             for arg in n.args:
-                if type(arg) == type(n) and not is_non_compute_node_except_placeholder(
-                    arg
-                ):
+                if type(arg) == type(n) and not is_non_compute_node_except_placeholder(arg):
                     cur_trace[arg] = self.trace_indice._find_trace_from_node(arg)
             input_trace.append(cur_trace)
 
         for start_idx in range(max_chunk_region[0], peak_node + 1):
             for end_idx in range(peak_node, max_chunk_region[1] + 1):
                 # skip non compute nodes
-                if is_non_compute_node(
-                    self.trace_indice.node_list[start_idx]
-                ) or is_non_compute_node(self.trace_indice.node_list[end_idx]):
+                if is_non_compute_node(self.trace_indice.node_list[start_idx]) or is_non_compute_node(
+                        self.trace_indice.node_list[end_idx]):
                     continue
 
                 # select free dim
-                chunk_info = self._find_chunk_info(
-                    input_trace, output_trace, start_idx, end_idx
-                )
+                chunk_info = self._find_chunk_info(input_trace, output_trace, start_idx, end_idx)
                 if len(chunk_info) > 0:
                     possible_chunk_region.extend(chunk_info)
         return possible_chunk_region
@@ -256,17 +228,12 @@ class SearchChunk(object):
             best_chunk_region (Dict)
         """
         peak_node = self._find_peak_node(mem_peak)
-        max_chunk_region = self._search_max_chunk_region(
-            active_node, peak_node, chunk_infos
-        )
+        max_chunk_region = self._search_max_chunk_region(active_node, peak_node, chunk_infos)
         if max_chunk_region == None:
             return None
-        possible_chunk_regions = self._search_possible_chunk_regions(
-            max_chunk_region, peak_node
-        )
-        best_chunk_region = self.select_chunk._select_best_chunk_region(
-            possible_chunk_regions, chunk_infos, peak_node, max_chunk_region, mem_peak
-        )
+        possible_chunk_regions = self._search_possible_chunk_regions(max_chunk_region, peak_node)
+        best_chunk_region = self.select_chunk._select_best_chunk_region(possible_chunk_regions, chunk_infos, peak_node,
+                                                                        max_chunk_region, mem_peak)
         best_chunk_region = self.reorder_graph.reorder_all(best_chunk_region)
         return best_chunk_region
 
@@ -291,9 +258,7 @@ class SearchChunk(object):
             init_mem_peak,
             _,
             active_node,
-        ) = self.estimate_memory.estimate_chunk_inference_mem(
-            self.trace_indice.node_list
-        )
+        ) = self.estimate_memory.estimate_chunk_inference_mem(self.trace_indice.node_list)
         mem_peak = init_mem_peak
 
         while True:
@@ -306,14 +271,10 @@ class SearchChunk(object):
                 mem_peak,
                 _,
                 active_node,
-            ) = self.estimate_memory.estimate_chunk_inference_mem(
-                self.trace_indice.node_list, chunk_infos
-            )
+            ) = self.estimate_memory.estimate_chunk_inference_mem(self.trace_indice.node_list, chunk_infos)
             if self._stop_search(init_mem_peak, mem_peak):
                 break
         if self.print_mem:
             self.print_mem = False
-            self.estimate_memory.estimate_chunk_inference_mem(
-                self.trace_indice.node_list, chunk_infos, print_mem=True
-            )
+            self.estimate_memory.estimate_chunk_inference_mem(self.trace_indice.node_list, chunk_infos, print_mem=True)
         return chunk_infos
diff --git a/colossalai/autochunk/trace_flow.py b/colossalai/autochunk/trace_flow.py
index ec1e012be..04fa2b3bb 100644
--- a/colossalai/autochunk/trace_flow.py
+++ b/colossalai/autochunk/trace_flow.py
@@ -1,8 +1,13 @@
+from typing import Dict, List, Tuple
+
+from torch.fx.node import Node
+
 from .trace_indice import TraceIndice
 from .utils import (
     find_chunk_all_input_nodes,
     find_chunk_compute_input_and_output_nodes,
     find_idx_by_name,
+    flat_list,
     get_node_shape,
     is_non_compute_node,
     is_non_compute_node_except_placeholder,
@@ -171,7 +176,7 @@ class TraceFlow(object):
                 # get cur node info
                 cur_node_chunk_dim = all_node_info[cur_node]["chunk_dim"]
                 cur_node_fix_dim = all_node_info[cur_node]["fix_dim"]
-                if cur_node_chunk_dim:
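+                # chunk_dim can legitimately be 0, so test against None instead of truthiness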
+                if cur_node_chunk_dim is not None:
                     cur_node_compute = self.trace_indice._find_compute_trace_from_node(cur_node)
                     cur_node_source = self.trace_indice._find_source_trace_from_node(cur_node)
                 else:
@@ -223,15 +228,32 @@ class TraceFlow(object):
             cur_node_list = next_node_list
         return all_node_info
 
-    def _get_input_nodes_dim(self, inputs, start_idx, end_idx, all_node_info):
+    def _get_input_nodes_dim(self, inputs: List[Node], start_idx: int, end_idx: int, all_node_info: Dict) -> Tuple:
+        """
+        For every input node, collect the chunk dim of each of its users inside the region, and remove inputs that are never chunked
+
+        Args:
+            inputs (List[Node]): input nodes
+            start_idx (int): chunk start idx
+            end_idx (int): chunk end idx
+            all_node_info (Dict): chunk dim and fix dim of every node
+
+        Returns:
+            inputs (List[Node]): filtered input nodes
+            inputs_dim (List): chunk dims of the remaining inputs
+        """
         inputs_dim = []
         remove_inputs = []
         for input_node in inputs:
             input_dict = {}
             input_node_idx = find_idx_by_name(input_node.name, self.trace_indice.node_list)
             for user in input_node.users.keys():
+                # skip non-compute users
                 if is_non_compute_node(user):
                     continue
+                # skip users that were never traced, mostly non-compute nodes
+                if user not in all_node_info:
+                    continue
                 user_idx = find_idx_by_name(user.name, self.trace_indice.node_list)
                 if start_idx <= user_idx <= end_idx:
                     chunk_dim = all_node_info[user]["chunk_dim"]
@@ -245,12 +267,24 @@ class TraceFlow(object):
                 remove_inputs.append(input_node)
             else:
                 inputs_dim.append(input_dict)
+        # remove unchunked inputs
         for i in remove_inputs:
             if i in inputs:
                 inputs.remove(i)
         return inputs, inputs_dim
 
-    def _get_prepose_nodes(self, all_node_info, start_idx, end_idx):
+    def _get_prepose_nodes(self, all_node_info: Dict, start_idx: int, end_idx: int) -> List[Node]:
+        """
+        Get the nodes inside the chunk region that are not needed within the chunk loop and prepose them before the region
+
+        Args:
+            all_node_info (Dict): chunk dim and fix dim of every node
+            start_idx (int): chunk start idx
+            end_idx (int): chunk end idx
+
+        Returns:
+            List[Node]: all nodes to be preposed
+        """
         # get all possible prepose nodes
         maybe_prepose_nodes = []
         for node, node_info in all_node_info.items():
@@ -276,7 +310,7 @@ class TraceFlow(object):
                 for cur_prepose_node in tmp_cur_prepose_nodes:
                     if prepose_flag == False:
                         break
-                    for cur_prepose_node_arg in cur_prepose_node.args:
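+                    # all_input_nodes covers nodes passed through kwargs as well, unlike .args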
+                    for cur_prepose_node_arg in cur_prepose_node.all_input_nodes:
                         if type(cur_prepose_node_arg) != type(cur_prepose_node):
                             continue
                         # out of loop
@@ -360,19 +394,28 @@ class TraceFlow(object):
         return chunk_info
 
     def _reassgin_reshape_size(self, chunk_info):
+        """
+        Some shape args of reshape/view ops may have changed due to chunking;
+        reassign those changed shapes
+        """
         chunk_region = chunk_info["region"]
         reshape_size = {}
         chunk_shape = get_node_shape(chunk_info["outputs"][0])[chunk_info["outputs_dim"]]
         for node in self.trace_indice.node_list[chunk_region[0]:chunk_region[1] + 1]:
             if any(i in node.name for i in ["reshape", "view"]):
-                reshape_args = node.args[1:]
-                reshape_log = self.trace_indice.indice_view_list[node]
+                reshape_args = flat_list(node.args[1:])
                 chunk_dim = chunk_info["node_chunk_dim"][node]["chunk_dim"]
-                reshape_size[node.name] = {}
+                new_shape = ""
                 for reshape_arg_dim, reshape_arg in enumerate(reshape_args):
-                    if reshape_arg_dim in reshape_log["dim_to"]:
-                        continue
                     if reshape_arg_dim == chunk_dim:
-                        reshape_size[node.name][reshape_arg.name] = ("min(chunk_size, %d - chunk_idx)" % chunk_shape)
+                        new_shape += "min(chunk_size, %d - chunk_idx), " % chunk_shape
+                    else:
+                        if isinstance(reshape_arg, int):
+                            new_shape += "%s, " % str(reshape_arg)
+                        else:
+                            new_shape += "%s, " % reshape_arg.name
+                new_shape = new_shape[:-2]
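+                # e.g. args [4, 64, x_1] with chunk dim 1 and chunk shape 64 become origin "4, 64, x_1" and new "4, min(chunk_size, 64 - chunk_idx), x_1"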
+                origin_shape = str(reshape_args)[1:-1]
+                reshape_size[node.name] = [origin_shape, new_shape]
         chunk_info["reshape_size"] = reshape_size
         return chunk_info
diff --git a/colossalai/autochunk/trace_indice.py b/colossalai/autochunk/trace_indice.py
index 5a5d15e0a..862cd6b99 100644
--- a/colossalai/autochunk/trace_indice.py
+++ b/colossalai/autochunk/trace_indice.py
@@ -3,7 +3,7 @@ from typing import Dict, List, Tuple
 
 from torch.fx.node import Node
 
-from .utils import find_first_tensor_arg, find_idx_by_name, get_node_shape, unflat_list
+from .utils import find_first_tensor_arg, find_idx_by_name, flat_list, get_node_shape
 
 
 class TraceIndice(object):
@@ -28,7 +28,7 @@ class TraceIndice(object):
         node_list (List)
     """
 
-    def __init__(self, node_list: List) -> None:
+    def __init__(self, node_list: List[Node]) -> None:
         self.node_list = node_list
         self.indice_trace_list = self._init_indice_trace_list()
         self.indice_view_list = {}
@@ -198,7 +198,7 @@ class TraceIndice(object):
         node_idx = find_idx_by_name(node.name, self.node_list)
         return self.indice_trace_list[node_idx]["compute"]
 
-    def _assign_indice_as_input(self, node, node_idx, input_node=None):
+    def _assign_indice_as_input(self, node: Node, node_idx: int, input_node=None):
         """
         Assign node's trace as its input node.
 
@@ -216,7 +216,7 @@ class TraceIndice(object):
 
         self._inherit_all_computation(input_node, node)
 
-    def _assign_all_indice(self, node, node_idx):
+    def _assign_all_indice(self, node: Node, node_idx: int):
         """
         Add new indice for all node's dims.
 
@@ -232,7 +232,7 @@ class TraceIndice(object):
             new_trace.append(self._add_indice())
         self.indice_trace_list[node_idx]["indice"] = new_trace
 
-    def _assign_transpose_indice(self, node, node_idx):
+    def _assign_transpose_indice(self, node: Node, node_idx: int):
         """
         Assign indice for transpose op.
         1. swap input's dim according to transpose args
@@ -249,7 +249,7 @@ class TraceIndice(object):
         self._inherit_indice(input_node, tranpose_dim[1], node, tranpose_dim[0])
         self._inherit_indice(input_node, tranpose_dim[0], node, tranpose_dim[1])
 
-    def _assign_permute_indice(self, node, node_idx):
+    def _assign_permute_indice(self, node: Node, node_idx: int):
         """
         Assign indice for permute op.
         1. swap input's dim according to permute args
@@ -259,14 +259,14 @@ class TraceIndice(object):
             node (node)
             node_idx (int)
         """
-        permute_dim = unflat_list(node.args[1:])
+        permute_dim = flat_list(node.args[1:])
         input_node = node.args[0]
 
         self._assign_indice_as_input(node, node_idx, input_node)
         for idx, d in enumerate(permute_dim):
             self._inherit_indice(input_node, d, node, idx)
 
-    def _assign_linear_indice(self, node, node_idx):
+    def _assign_linear_indice(self, node: Node, node_idx: int):
         """
         Assign indice for linear op.
         1. copy trace from input node and change last indice accroding to weight
@@ -287,7 +287,7 @@ class TraceIndice(object):
 
         self._mark_computation(node, node_idx, [-1])
 
-    def _assign_matmul_indice(self, node, node_idx):
+    def _assign_matmul_indice(self, node: Node, node_idx: int):
         """
         Assign indice for matmul op.
         1. copy trace from matmul_left and change last indice accroding to matmul_right. (assert they have same length)
@@ -393,7 +393,7 @@ class TraceIndice(object):
         self._assign_indice_as_input(node, idx)
         self._mark_computation(node, idx, [node.kwargs["dim"]])
 
-    def _assign_unsqueeze_indice(self, node, node_idx):
+    def _assign_unsqueeze_indice(self, node: Node, node_idx: int):
         """
         Assign indice for unsqueeze op.
         1. assign new indice for unsqueeze dim
@@ -404,9 +404,13 @@ class TraceIndice(object):
         """
         self._del_dim(node_idx, -1)
         self._assign_indice_as_input(node, node_idx)
-        self._add_dim(node_idx, node.args[1])
+        dim_idx = node.args[1]
+        # convert a negative dim to its positive index, e.g. unsqueeze(-1) on an n-dim output equals unsqueeze(n - 1)
+        if dim_idx < 0:
+            dim_idx = list(range(len(get_node_shape(node))))[dim_idx]
+        self._add_dim(node_idx, dim_idx)
 
-    def _assign_dropout_indice(self, node, node_idx):
+    def _assign_dropout_indice(self, node: Node, node_idx: int):
         """
         Assign indice for unsqueeze op.
         1. assign new indice for unsqueeze dim
@@ -417,7 +421,7 @@ class TraceIndice(object):
         """
         self._assign_indice_as_input(node, node_idx)
 
-    def _assign_ones_like_indice(self, node, node_idx):
+    def _assign_ones_like_indice(self, node: Node, node_idx: int):
         """
         Assign indice for oneslike op.
         1. assign new indice for all dim
@@ -428,7 +432,47 @@ class TraceIndice(object):
         """
         self._assign_all_indice(node, node_idx)
 
-    def _assign_view_reshape_indice(self, node, node_idx):
+    def _assign_getitem_indice(self, node: Node, node_idx: int):
+        """
+        Assign indice for getitem.
+        getitem can act like slice sometimes
+
+        Args:
+            node (node)
+            node_idx (int)
+        """
+        node_args = flat_list(node.args[1:])
+        if not any(i == str(node_arg) for i in ["None", "Ellipsis"] for node_arg in node_args):
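+            # a plain getitem with no None/Ellipsis does not change dims and is treated as non-compute, so skip it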
+            return
+
+        # node args should be like [Ellipsis, slice(start, stop, step), None]
+        node_shape = get_node_shape(node)
+        origin_idx_count = 0
+        new_idx_count = 0
+        new_dim_num = sum([1 if str(i) == "None" else 0 for i in node_args])
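+        # each None arg adds a dim to the output; drop that many dims before copying the input trace, then re-add them below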
+        for _ in range(new_dim_num):
+            self._del_dim(node_idx, 0)
+        self._assign_indice_as_input(node, node_idx)
+
+        for _, node_arg in enumerate(node_args):
+            node_arg_str = str(node_arg)
+            # Ellipsis means [..., ]
+            if "Ellipsis" == node_arg_str:
+                shape_gap = len(node_shape) - len(node_args) + 1
+                origin_idx_count += shape_gap
+                new_idx_count += shape_gap
+            # slice(None, None, None) means all indices; other slices are not supported
+            elif "slice(None, None, None)" == node_arg_str:
+                origin_idx_count += 1
+                new_idx_count += 1
+            # None means a new dim
+            elif "None" == node_arg_str:
+                self._add_dim(node_idx, new_idx_count)
+                new_idx_count += 1
+            else:
+                raise NotImplementedError()
+
+    def _assign_view_reshape_indice(self, node: Node, node_idx: int):
         """
         Assign indice for view and reshape op.
         1. get origin shape and target shape by meta info.
@@ -447,7 +491,7 @@ class TraceIndice(object):
         origin_node = node.args[0]
         origin_shape = origin_node.meta["tensor_meta"].shape
         target_shape = []
-        unflated_args = unflat_list(node.args)
+        unflated_args = flat_list(node.args)
         for i in range(1, len(unflated_args)):
             if isinstance(unflated_args[i], int):
                 target_shape.append(unflated_args[i])
@@ -544,6 +588,8 @@ class TraceIndice(object):
                     self._assign_einsum_indice(node, idx)
                 elif "layer_norm" in node.name:
                     self._assign_layernorm_indice(node, idx)
+                elif "getitem" in node.name:
+                    self._assign_getitem_indice(node, idx)
                 elif any(i in node.name for i in ["getattr", "getitem", "eq", "_assert"]):
                     continue
                 else:
diff --git a/colossalai/autochunk/utils.py b/colossalai/autochunk/utils.py
index 5f3ea3bf4..9c2363b54 100644
--- a/colossalai/autochunk/utils.py
+++ b/colossalai/autochunk/utils.py
@@ -3,14 +3,14 @@ from typing import Any, Callable, Dict, Iterable, List, Tuple
 from torch.fx.node import Node
 
 
-def unflat_list(inputs):
+def flat_list(inputs):
     """
-    unflat a list by recursion
+    flatten a nested list recursively
     """
     res = []
     for i in inputs:
         if isinstance(i, list) or isinstance(i, set) or isinstance(i, tuple):
-            res.extend(unflat_list(i))
+            res.extend(flat_list(i))
         else:
             res.append(i)
     return res
@@ -27,8 +27,13 @@ def find_first_tensor_arg(node):
 
 
 def is_non_compute_node(node):
-    if any(i in node.op for i in ["placeholder", "get_attr", "output"]) or any(
-            i in node.name for i in ["getitem", "getattr"]):
+    if any(i in node.op for i in ["placeholder", "get_attr", "output"]) or any(i in node.name for i in ["getattr"]):
+        return True
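+    # getitem that slices with None/Ellipsis changes dims and must be traced as a compute node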
+    if "getitem" in node.name:
+        node_args = flat_list(node.args[1:])
+        for node_arg in node_args:
+            if any(i == str(node_arg) for i in ["None", "Ellipsis"]):
+                return False
         return True
     return False
 
@@ -40,15 +45,15 @@ def get_node_shape(node):
 
 
 def is_non_compute_node_except_placeholder(node):
-    if any(i in node.op for i in ["get_attr", "output"]) or any(i in node.name for i in ["getitem", "getattr"]):
-        return True
-    return False
+    if "placeholder" in node.op:
+        return False
+    return is_non_compute_node(node)
 
 
 def is_non_compute_node_except_placeholder_output(node):
-    if any(i in node.op for i in ["get_attr"]) or any(i in node.name for i in ["getitem", "getattr"]):
-        return True
-    return False
+    if "output" in node.op:
+        return False
+    return is_non_compute_node_except_placeholder(node)
 
 
 def find_idx_by_name(name, nodes_list):
diff --git a/tests/test_autochunk/test_evoformer_codegen.py b/tests/test_autochunk/test_evoformer_codegen.py
index 1273bf2fe..c5a893eda 100644
--- a/tests/test_autochunk/test_evoformer_codegen.py
+++ b/tests/test_autochunk/test_evoformer_codegen.py
@@ -27,18 +27,17 @@ if CODEGEN_AVAILABLE and is_compatible_with_meta():
 
 def _test_fwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair, node_mask, pair_mask):
     # for memory test
+    # model = model.cuda()
     # torch.cuda.reset_peak_memory_stats()
     # now_mem = torch.cuda.memory_allocated() / 1024**2
     # with torch.no_grad():
     #     node1 = node.clone()
     #     pair1 = pair.clone()
-    #     gm(node1, pair1)
-    # new_now_mem = torch.cuda.memory_allocated() / 1024**2
+    #     node_mask1 = node_mask.clone()
+    #     pair_mask1 = pair_mask.clone()
+    #     gm(node1, pair1, node_mask1, pair_mask1)
     # new_max_mem = torch.cuda.max_memory_allocated() / 1024**2
-    # print(
-    #     "autochunk now mem:%.2f max mem:%.2f"
-    #     % (new_now_mem - now_mem, new_max_mem - now_mem)
-    # )
+    # print("autochunk max mem:%.2f"% (new_max_mem - now_mem))
 
     # test forward
     model = model.cuda()
@@ -113,7 +112,7 @@ def _test_evoformer_codegen(rank, msa_len, pair_len, max_memory):
         MetaTensor(node_mask, fake_device="cuda:0"),
         MetaTensor(pair_mask, fake_device="cuda:0"),
     )
-    # codegen = AutoChunkCodeGen(meta_graph, max_memory=max_memory)
+    codegen = AutoChunkCodeGen(meta_graph, max_memory=max_memory, print_mem=False)
 
     # trace and recompile
     # MetaInfoProp requires symbolic_trace but CodeGen requires ColoTracer
@@ -130,14 +129,14 @@ def _test_evoformer_codegen(rank, msa_len, pair_len, max_memory):
             "_mask_trans": True,
         },
     )
-    # graph.set_codegen(codegen)
+    graph.set_codegen(codegen)
     gm = ColoGraphModule(model, graph)
     gm.recompile()
 
     # assert we have inserted chunk
     code = graph.python_code("self").src
-    assert "chunk_size" in code
     # print(code)
+    assert "chunk_result = None;  chunk_size = None;" in code
 
     _test_fwd(model, gm, node, pair, node_mask, pair_mask)
     gpc.destroy()
@@ -147,7 +146,7 @@ def _test_evoformer_codegen(rank, msa_len, pair_len, max_memory):
     not (CODEGEN_AVAILABLE and is_compatible_with_meta() and HAS_REPO),
     reason="torch version is lower than 1.12.0",
 )
-@pytest.mark.parametrize("max_memory", [None, 20, 25, 30])
+@pytest.mark.parametrize("max_memory", [None, 24, 28, 32])
 @pytest.mark.parametrize("msa_len", [32])
 @pytest.mark.parametrize("pair_len", [64])
 def test_evoformer_codegen(msa_len, pair_len, max_memory):
@@ -161,4 +160,4 @@ def test_evoformer_codegen(msa_len, pair_len, max_memory):
 
 
 if __name__ == "__main__":
-    _test_evoformer_codegen(0, 32, 64, 25)
+    _test_evoformer_codegen(0, 32, 64, 24)
diff --git a/tests/test_autochunk/test_simple_evoformer_codegen.py b/tests/test_autochunk/test_simple_evoformer_codegen.py
index f1272330f..8ab77024c 100644
--- a/tests/test_autochunk/test_simple_evoformer_codegen.py
+++ b/tests/test_autochunk/test_simple_evoformer_codegen.py
@@ -13,7 +13,7 @@ except:
 
 import colossalai
 from colossalai.core import global_context as gpc
-from colossalai.fx import ColoTracer
+from colossalai.fx import ColoTracer, symbolic_trace
 from colossalai.fx._compatibility import is_compatible_with_meta
 from colossalai.fx.codegen.activation_checkpoint_codegen import CODEGEN_AVAILABLE
 from colossalai.fx.graph_module import ColoGraphModule
@@ -26,21 +26,6 @@ if CODEGEN_AVAILABLE and is_compatible_with_meta():
 
 
 def _test_fwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair):
-    # for memory test
-    # torch.cuda.reset_peak_memory_stats()
-    # now_mem = torch.cuda.memory_allocated() / 1024**2
-    # with torch.no_grad():
-    #     node1 = node.clone()
-    #     pair1 = pair.clone()
-    #     gm(node1, pair1)
-    # new_now_mem = torch.cuda.memory_allocated() / 1024**2
-    # new_max_mem = torch.cuda.max_memory_allocated() / 1024**2
-    # print(
-    #     "autochunk now mem:%.2f max mem:%.2f"
-    #     % (new_now_mem - now_mem, new_max_mem - now_mem)
-    # )
-
-    # test forward
     with torch.no_grad():
         non_fx_out = model(node, pair)
         fx_out = gm(node, pair)
@@ -69,6 +54,16 @@ def _test_simple_evoformer_codegen(rank, msa_len, pair_len, max_memory):
     node = torch.randn(1, msa_len, pair_len, 256).cuda()
     pair = torch.randn(1, pair_len, pair_len, 128).cuda()
 
+    # meta info prop
+    meta_graph = symbolic_trace(model,
+                                meta_args={
+                                    "node": node.to(torch.device("meta")),
+                                    "pair": pair.to(torch.device("meta")),
+                                })    # must use symbolic_trace
+    interp = MetaInfoProp(meta_graph)
+    interp.propagate(MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0"))
+    codegen = AutoChunkCodeGen(meta_graph, max_memory=max_memory)
+
     # trace the module and replace codegen
     graph = ColoTracer().trace(
         model,
@@ -77,24 +72,14 @@ def _test_simple_evoformer_codegen(rank, msa_len, pair_len, max_memory):
             "pair": pair.to(torch.device("meta")),
         },
     )
-    gm_prop = torch.fx.symbolic_trace(model)    # must use symbolic_trace
-    interp = MetaInfoProp(gm_prop)
-    interp.propagate(MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0"))
-
-    # now run it twice to get meta info in graph module, not necessary
-    gm = torch.fx.GraphModule(model, graph)
-    interp = MetaInfoProp(gm)
-    interp.propagate(MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0"))
-
-    codegen = AutoChunkCodeGen(gm_prop, max_memory=max_memory)
     graph.set_codegen(codegen)
     gm = ColoGraphModule(model, graph)
     gm.recompile()
 
     # assert we have inserted chunk
     code = graph.python_code("self").src
-    assert "chunk_size" in code
     # print(code)
+    assert "chunk_result = None;  chunk_size = None;" in code
 
     _test_fwd(model, gm, node, pair)
     gpc.destroy()
diff --git a/tests/test_autochunk/test_simple_evoformer_search.py b/tests/test_autochunk/test_simple_evoformer_search.py
index 04fb514fb..4c591c483 100644
--- a/tests/test_autochunk/test_simple_evoformer_search.py
+++ b/tests/test_autochunk/test_simple_evoformer_search.py
@@ -47,18 +47,18 @@ def assert_chunk_infos(chunk_infos, max_memory, msa_len, pair_len):
             str(target_regions),
         )
     for region in target_regions:
-        assert (region in found_regions), "region:%s not in found regions for msa:%d, pair:%d, maxmem:%d" % (
+        assert (region in found_regions), "region:%s not in found regions for msa:%d, pair:%d, maxmem:%s" % (
             str(region),
             msa_len,
             pair_len,
-            max_memory,
+            str(max_memory),
         )
     for region in found_regions:
-        assert (region in target_regions), "region:%s should not be found for msa:%d, pair:%d, maxmem:%d" % (
+        assert (region in target_regions), "region:%s should not be found for msa:%d, pair:%d, maxmem:%s" % (
             str(region),
             msa_len,
             pair_len,
-            max_memory,
+            str(max_memory),
         )