code style

pull/2364/head
oahzxl 2022-12-23 15:38:37 +08:00
parent 4f5e105af3
commit fa5e6fbf96
1 changed files with 9 additions and 16 deletions

View File

@ -65,9 +65,8 @@ def _is_non_compute_node_except_placeholder_output(node):
class IndexTracer(object):
def __init__(self, gm) -> None:
self.gm = gm
self.node_list = list(gm.graph.nodes)
def __init__(self, node_list) -> None:
self.node_list = node_list
self.idx_trace_list = self._init_idx_trace_list()
self.idx_trace_equal = []
self.idx_view_list = []
@ -797,9 +796,7 @@ class IndexTracer(object):
next_node_list.append(arg_node)
return True
def flow_search(
self, start_idx, start_dim, end_idx, end_dim
):
def flow_search(self, start_idx, start_dim, end_idx, end_dim):
inputs, outputs = _find_chunk_compute_input_and_output_nodes(
self.node_list[start_idx : end_idx + 1]
)
@ -819,12 +816,8 @@ class IndexTracer(object):
cur_node_fix_dim = all_node_info[cur_node]["fix_dim"]
cur_node_idx = _find_idx_by_name(cur_node.name, self.node_list)
if cur_node_chunk_dim:
cur_node_compute = self._find_compute_trace_from_node(
cur_node
)
cur_node_source = self._find_source_trace_from_node(
cur_node
)
cur_node_compute = self._find_compute_trace_from_node(cur_node)
cur_node_source = self._find_source_trace_from_node(cur_node)
else:
cur_node_compute = cur_node_source = None
@ -965,9 +958,7 @@ class IndexTracer(object):
if n in maybe_prepose_nodes:
maybe_prepose_nodes.remove(n)
# sort by index
prepose_nodes.sort(
key=lambda x: _find_idx_by_name(x.name, self.node_list)
)
prepose_nodes.sort(key=lambda x: _find_idx_by_name(x.name, self.node_list))
chunk_info["args"]["prepose_nodes"] = prepose_nodes
# we need to log input nodes to avoid deleteing them in the loop
@ -1295,7 +1286,9 @@ class ChunkRegionSearch(object):
def __init__(self, gm) -> None:
self.gm = gm
self.node_list = list(gm.graph.nodes)
self.index_tracer = IndexTracer(gm)
self.index_tracer = IndexTracer(
self.node_list
) # node list shared in index tracer
self.index_tracer.trace_index()
self.memory_estimator = MemoryEstimator(self.index_tracer)