refactor flow search

pull/2364/head
oahzxl 2 years ago
parent ded1005667
commit 774d34f1aa

@ -1004,7 +1004,7 @@ class FlowTracer(object):
# if already in node_info, arg dim must be same
if arg_node in all_node_info:
if all_node_info[arg_node] != arg_dim:
if all_node_info[arg_node]['chunk_dim'] != arg_dim:
return False
all_node_info[arg_node]["fix_dim"] = list(
set(all_node_info[arg_node]["fix_dim"] + arg_fix_dim)
@ -1128,14 +1128,68 @@ class FlowTracer(object):
"args": {},
}
# move useless nodes ahead of loop
# get all possible prepose nodes
maybe_prepose_nodes = []
for node, node_info in all_node_info.items():
if node_info['chunk_dim'] is None:
maybe_prepose_nodes.append(node)
maybe_prepose_nodes.sort(key=lambda x: _find_idx_by_name(x.name, index_tracer.nodes_list), reverse=True) # from last node to first node
prepose_nodes = []
# set every node as root, search its args, if all legal, turn root and args as prepose nodes
while len(maybe_prepose_nodes) > 0:
tmp_cur_prepose_nodes = [maybe_prepose_nodes[0]]
tmp_cur_related_prepose_nodes = []
prepose_flag = True
# loop cur node's all arg until out of chunk
while len(tmp_cur_prepose_nodes) > 0:
tmp_next_prepose_nodes = []
tmp_cur_related_prepose_nodes.extend(tmp_cur_prepose_nodes)
for cur_prepose_node in tmp_cur_prepose_nodes:
for cur_prepose_node_arg in cur_prepose_node.args:
if type(cur_prepose_node_arg) != type(cur_prepose_node):
continue
# out of loop
if not (start_idx <= _find_idx_by_name(cur_prepose_node_arg.name, self.node_list) < end_idx):
continue
# compute op in loop
elif cur_prepose_node_arg in all_node_info:
if all_node_info[cur_prepose_node_arg]['chunk_dim'] is None:
tmp_next_prepose_nodes.append(cur_prepose_node_arg)
else:
prepose_flag = False
break; break; break
# non compute op
else:
tmp_next_prepose_nodes.append(cur_prepose_node_arg)
tmp_cur_prepose_nodes = tmp_next_prepose_nodes
if prepose_flag == False:
maybe_prepose_nodes.remove(maybe_prepose_nodes[0])
continue
else:
for n in tmp_cur_related_prepose_nodes:
if n not in prepose_nodes:
prepose_nodes.append(n)
if n in maybe_prepose_nodes:
maybe_prepose_nodes.remove(n)
# sort by index
prepose_nodes.sort(key=lambda x: _find_idx_by_name(x.name, index_tracer.nodes_list))
chunk_info["args"]["prepose_nodes"] = prepose_nodes
# we need to log input nodes to avoid deleteing them in the loop
chunk_node_list = self.node_list[start_idx : end_idx + 1]
# also need to get some prepose node's arg out of non_chunk_inputs
for n in prepose_nodes:
chunk_node_list.remove(n)
non_chunk_inputs = _find_chunk_all_input_nodes(
self.node_list[start_idx : end_idx + 1]
chunk_node_list
)
for i in non_chunk_inputs:
if i not in chunk_info["inputs"]:
if i not in chunk_info["inputs"] and i not in prepose_nodes:
chunk_info["inputs_non_chunk"].append(i)
return chunk_info
@ -1541,16 +1595,6 @@ class ChunkRegionSearch(object):
continue
for start_node, start_trace in start_traces.items():
for start_dim, start_trace_idx in enumerate(start_trace["idx"]):
if (
start_idx == 199
and end_idx == 229
and start_dim == 2
and end_dim == 2
):
print(1)
self.flow_tracer.flow_search(
start_idx, start_dim, end_idx, end_dim, self.index_tracer
)
# dim size cannot be 1
if (
_get_node_shape(end_node)[end_dim] == 1
@ -1567,12 +1611,6 @@ class ChunkRegionSearch(object):
start_idx, end_dim, end_node, end_idx
):
continue
# detect flow meet
# flow_block, chunk_info = self.flow_tracer._detect_flow(
# start_idx, start_dim, end_idx, end_dim, self.index_tracer
# )
# if flow_block:
# continue
# flow search
chunk_info = self.flow_tracer.flow_search(
start_idx, start_dim, end_idx, end_dim, self.index_tracer

Loading…
Cancel
Save