From fa5e6fbf96448ebff1dc682e749a3f73a5a9c2b5 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Fri, 23 Dec 2022 15:38:37 +0800 Subject: [PATCH] code style --- chunk_codegen.py | 25 +++++++++---------------- 1 file changed, 9 insertions(+), 16 deletions(-) diff --git a/chunk_codegen.py b/chunk_codegen.py index 838f53949..e80b0fd9b 100644 --- a/chunk_codegen.py +++ b/chunk_codegen.py @@ -65,9 +65,8 @@ def _is_non_compute_node_except_placeholder_output(node): class IndexTracer(object): - def __init__(self, gm) -> None: - self.gm = gm - self.node_list = list(gm.graph.nodes) + def __init__(self, node_list) -> None: + self.node_list = node_list self.idx_trace_list = self._init_idx_trace_list() self.idx_trace_equal = [] self.idx_view_list = [] @@ -797,9 +796,7 @@ class IndexTracer(object): next_node_list.append(arg_node) return True - def flow_search( - self, start_idx, start_dim, end_idx, end_dim - ): + def flow_search(self, start_idx, start_dim, end_idx, end_dim): inputs, outputs = _find_chunk_compute_input_and_output_nodes( self.node_list[start_idx : end_idx + 1] ) @@ -819,12 +816,8 @@ class IndexTracer(object): cur_node_fix_dim = all_node_info[cur_node]["fix_dim"] cur_node_idx = _find_idx_by_name(cur_node.name, self.node_list) if cur_node_chunk_dim: - cur_node_compute = self._find_compute_trace_from_node( - cur_node - ) - cur_node_source = self._find_source_trace_from_node( - cur_node - ) + cur_node_compute = self._find_compute_trace_from_node(cur_node) + cur_node_source = self._find_source_trace_from_node(cur_node) else: cur_node_compute = cur_node_source = None @@ -965,9 +958,7 @@ class IndexTracer(object): if n in maybe_prepose_nodes: maybe_prepose_nodes.remove(n) # sort by index - prepose_nodes.sort( - key=lambda x: _find_idx_by_name(x.name, self.node_list) - ) + prepose_nodes.sort(key=lambda x: _find_idx_by_name(x.name, self.node_list)) chunk_info["args"]["prepose_nodes"] = prepose_nodes # we need to log input nodes to avoid deleteing them in the loop @@ -1295,7 +1286,9 @@ class ChunkRegionSearch(object): def __init__(self, gm) -> None: self.gm = gm self.node_list = list(gm.graph.nodes) - self.index_tracer = IndexTracer(gm) + self.index_tracer = IndexTracer( + self.node_list + ) # node list shared in index tracer self.index_tracer.trace_index() self.memory_estimator = MemoryEstimator(self.index_tracer)