seperate non chunk input

pull/2364/head
oahzxl 2023-01-06 15:53:24 +08:00
parent f856611d21
commit 6685a9d022
1 changed files with 22 additions and 13 deletions

View File

@ -839,7 +839,7 @@ class IndexTracer(object):
inputs.remove(i) inputs.remove(i)
return inputs, inputs_dim return inputs, inputs_dim
def _set_prepose_nodes(self, all_node_info, start_idx, end_idx): def _get_prepose_nodes(self, all_node_info, start_idx, end_idx):
# get all possible prepose nodes # get all possible prepose nodes
maybe_prepose_nodes = [] maybe_prepose_nodes = []
for node, node_info in all_node_info.items(): for node, node_info in all_node_info.items():
@ -903,6 +903,18 @@ class IndexTracer(object):
return prepose_nodes return prepose_nodes
def _get_non_chunk_inputs(self, chunk_info, start_idx, end_idx):
# we need to log input nodes to avoid deleteing them in the loop
chunk_node_list = self.node_list[start_idx : end_idx + 1]
# also need to get some prepose node's arg out of non_chunk_inputs
for n in chunk_info["args"]["prepose_nodes"]:
chunk_node_list.remove(n)
non_chunk_inputs = find_chunk_all_input_nodes(chunk_node_list)
for i in non_chunk_inputs:
if i not in chunk_info["inputs"]:
chunk_info["inputs_non_chunk"].append(i)
return chunk_info
def flow_search(self, start_idx, start_dim, end_idx, end_dim): def flow_search(self, start_idx, start_dim, end_idx, end_dim):
inputs, outputs = find_chunk_compute_input_and_output_nodes( inputs, outputs = find_chunk_compute_input_and_output_nodes(
self.node_list[start_idx : end_idx + 1] self.node_list[start_idx : end_idx + 1]
@ -917,7 +929,9 @@ class IndexTracer(object):
return None return None
# get input nodes' chunk dim # get input nodes' chunk dim
inputs, inputs_dim = self._get_input_nodes_dim(inputs, start_idx, end_idx, all_node_info) inputs, inputs_dim = self._get_input_nodes_dim(
inputs, start_idx, end_idx, all_node_info
)
if inputs is None: if inputs is None:
return None return None
@ -933,17 +947,12 @@ class IndexTracer(object):
} }
# move useless nodes ahead of loop # move useless nodes ahead of loop
chunk_info["args"]["prepose_nodes"] = self._set_prepose_nodes(all_node_info, start_idx, end_idx) chunk_info["args"]["prepose_nodes"] = self._get_prepose_nodes(
all_node_info, start_idx, end_idx
)
# we need to log input nodes to avoid deleteing them in the loop # find non chunk inputs
chunk_node_list = self.node_list[start_idx : end_idx + 1] chunk_info = self._get_non_chunk_inputs(chunk_info, start_idx, end_idx)
# also need to get some prepose node's arg out of non_chunk_inputs
for n in chunk_info["args"]["prepose_nodes"]:
chunk_node_list.remove(n)
non_chunk_inputs = find_chunk_all_input_nodes(chunk_node_list)
for i in non_chunk_inputs:
if i not in chunk_info["inputs"]:
chunk_info["inputs_non_chunk"].append(i)
# reassgin reshape size, some size may have changed due to chunk # reassgin reshape size, some size may have changed due to chunk
chunk_info = self._reassgin_reshape_size(chunk_info) chunk_info = self._reassgin_reshape_size(chunk_info)