def _is_non_compute_node_except_placeholder_output(node):
    """Return True if ``node`` is an attribute fetch or a getitem/getattr node.

    A node matches when the string ``"get_attr"`` occurs in ``node.op`` or
    when ``node.name`` contains ``"getitem"`` or ``"getattr"``. Both checks
    are substring tests, so a name such as ``"getitem_3"`` also matches.

    Args:
        node: Object exposing ``op`` and ``name`` string attributes
            (presumably a ``torch.fx.Node`` -- confirm against callers).

    Returns:
        bool: True when either check matches, False otherwise.
    """
    # Return the condition directly instead of branching to True/False.
    return "get_attr" in node.op or any(
        marker in node.name for marker in ("getitem", "getattr")
    )