pull/2364/head
oahzxl 2 years ago
parent 774d34f1aa
commit 522f017418

@ -1004,7 +1004,7 @@ class FlowTracer(object):
# if already in node_info, arg dim must be same
if arg_node in all_node_info:
if all_node_info[arg_node]['chunk_dim'] != arg_dim:
if all_node_info[arg_node]["chunk_dim"] != arg_dim:
return False
all_node_info[arg_node]["fix_dim"] = list(
set(all_node_info[arg_node]["fix_dim"] + arg_fix_dim)
@ -1132,16 +1132,19 @@ class FlowTracer(object):
# get all possible prepose nodes
maybe_prepose_nodes = []
for node, node_info in all_node_info.items():
if node_info['chunk_dim'] is None:
if node_info["chunk_dim"] is None:
maybe_prepose_nodes.append(node)
maybe_prepose_nodes.sort(key=lambda x: _find_idx_by_name(x.name, index_tracer.nodes_list), reverse=True) # from last node to first node
maybe_prepose_nodes.sort(
key=lambda x: _find_idx_by_name(x.name, index_tracer.nodes_list),
reverse=True,
) # from last node to first node
prepose_nodes = []
# set every node as root, search its args, if all legal, turn root and args as prepose nodes
while len(maybe_prepose_nodes) > 0:
tmp_cur_prepose_nodes = [maybe_prepose_nodes[0]]
tmp_cur_related_prepose_nodes = []
prepose_flag = True
# loop cur node's all arg until out of chunk
while len(tmp_cur_prepose_nodes) > 0:
tmp_next_prepose_nodes = []
@ -1151,20 +1154,28 @@ class FlowTracer(object):
if type(cur_prepose_node_arg) != type(cur_prepose_node):
continue
# out of loop
if not (start_idx <= _find_idx_by_name(cur_prepose_node_arg.name, self.node_list) < end_idx):
if not (
start_idx
<= _find_idx_by_name(
cur_prepose_node_arg.name, self.node_list
)
< end_idx
):
continue
# compute op in loop
elif cur_prepose_node_arg in all_node_info:
if all_node_info[cur_prepose_node_arg]['chunk_dim'] is None:
if all_node_info[cur_prepose_node_arg]["chunk_dim"] is None:
tmp_next_prepose_nodes.append(cur_prepose_node_arg)
else:
prepose_flag = False
break; break; break
break
break
break
# non compute op
else:
tmp_next_prepose_nodes.append(cur_prepose_node_arg)
tmp_cur_prepose_nodes = tmp_next_prepose_nodes
if prepose_flag == False:
maybe_prepose_nodes.remove(maybe_prepose_nodes[0])
continue
@ -1175,21 +1186,21 @@ class FlowTracer(object):
if n in maybe_prepose_nodes:
maybe_prepose_nodes.remove(n)
# sort by index
prepose_nodes.sort(key=lambda x: _find_idx_by_name(x.name, index_tracer.nodes_list))
prepose_nodes.sort(
key=lambda x: _find_idx_by_name(x.name, index_tracer.nodes_list)
)
chunk_info["args"]["prepose_nodes"] = prepose_nodes
# we need to log input nodes to avoid deleteing them in the loop
chunk_node_list = self.node_list[start_idx : end_idx + 1]
# also need to get some prepose node's arg out of non_chunk_inputs
for n in prepose_nodes:
chunk_node_list.remove(n)
non_chunk_inputs = _find_chunk_all_input_nodes(
chunk_node_list
)
non_chunk_inputs = _find_chunk_all_input_nodes(chunk_node_list)
for i in non_chunk_inputs:
if i not in chunk_info["inputs"] and i not in prepose_nodes:
chunk_info["inputs_non_chunk"].append(i)
return chunk_info

Loading…
Cancel
Save