seperate input node dim search

pull/2364/head
oahzxl 2023-01-06 15:36:17 +08:00
parent ae27a8b26d
commit f4a1607e56
1 changed files with 21 additions and 14 deletions

View File

@ -812,19 +812,7 @@ class IndexTracer(object):
cur_node_list = next_node_list
return all_node_info
def flow_search(self, start_idx, start_dim, end_idx, end_dim):
inputs, outputs = find_chunk_compute_input_and_output_nodes(
self.node_list[start_idx : end_idx + 1]
)
# only single ouput
if len(outputs) > 1:
return None
# get every node's chunk dim and fix dim
all_node_info = self._get_all_node_info(end_dim, start_idx, end_idx)
if all_node_info is None:
return None
def _get_input_nodes_dim(self, inputs, start_idx, end_idx, all_node_info):
inputs_dim = []
remove_inputs = []
for input_node in inputs:
@ -841,7 +829,7 @@ class IndexTracer(object):
if input_node_idx in user_source:
input_dict[user_idx] = user_source[input_node_idx]
else:
return None
return None, None
if len(input_dict) == 0:
remove_inputs.append(input_node)
else:
@ -849,6 +837,25 @@ class IndexTracer(object):
for i in remove_inputs:
if i in inputs:
inputs.remove(i)
return inputs, inputs_dim
def flow_search(self, start_idx, start_dim, end_idx, end_dim):
inputs, outputs = find_chunk_compute_input_and_output_nodes(
self.node_list[start_idx : end_idx + 1]
)
# only single ouput
if len(outputs) > 1:
return None
# get every node's chunk dim and fix dim
all_node_info = self._get_all_node_info(end_dim, start_idx, end_idx)
if all_node_info is None:
return None
# get input nodes' chunk dim
inputs, inputs_dim = self._get_input_nodes_dim(inputs, start_idx, end_idx, all_node_info)
if inputs is None:
return None
chunk_info = {
"region": (start_idx, end_idx),