|
|
@ -839,36 +839,7 @@ class IndexTracer(object): |
|
|
|
inputs.remove(i) |
|
|
|
inputs.remove(i) |
|
|
|
return inputs, inputs_dim |
|
|
|
return inputs, inputs_dim |
|
|
|
|
|
|
|
|
|
|
|
def flow_search(self, start_idx, start_dim, end_idx, end_dim): |
|
|
|
def _set_prepose_nodes(self, all_node_info, start_idx, end_idx): |
|
|
|
inputs, outputs = find_chunk_compute_input_and_output_nodes( |
|
|
|
|
|
|
|
self.node_list[start_idx : end_idx + 1] |
|
|
|
|
|
|
|
) |
|
|
|
|
|
|
|
# only single ouput |
|
|
|
|
|
|
|
if len(outputs) > 1: |
|
|
|
|
|
|
|
return None |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# get every node's chunk dim and fix dim |
|
|
|
|
|
|
|
all_node_info = self._get_all_node_info(end_dim, start_idx, end_idx) |
|
|
|
|
|
|
|
if all_node_info is None: |
|
|
|
|
|
|
|
return None |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# get input nodes' chunk dim |
|
|
|
|
|
|
|
inputs, inputs_dim = self._get_input_nodes_dim(inputs, start_idx, end_idx, all_node_info) |
|
|
|
|
|
|
|
if inputs is None: |
|
|
|
|
|
|
|
return None |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
chunk_info = { |
|
|
|
|
|
|
|
"region": (start_idx, end_idx), |
|
|
|
|
|
|
|
"inputs": inputs, |
|
|
|
|
|
|
|
"inputs_non_chunk": [], |
|
|
|
|
|
|
|
"inputs_dim": inputs_dim, |
|
|
|
|
|
|
|
"outputs": outputs, |
|
|
|
|
|
|
|
"outputs_dim": end_dim, |
|
|
|
|
|
|
|
"node_chunk_dim": all_node_info, |
|
|
|
|
|
|
|
"args": {}, |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# move useless nodes ahead of loop |
|
|
|
|
|
|
|
# get all possible prepose nodes |
|
|
|
# get all possible prepose nodes |
|
|
|
maybe_prepose_nodes = [] |
|
|
|
maybe_prepose_nodes = [] |
|
|
|
for node, node_info in all_node_info.items(): |
|
|
|
for node, node_info in all_node_info.items(): |
|
|
@ -929,12 +900,45 @@ class IndexTracer(object): |
|
|
|
maybe_prepose_nodes.remove(n) |
|
|
|
maybe_prepose_nodes.remove(n) |
|
|
|
# sort by index |
|
|
|
# sort by index |
|
|
|
prepose_nodes.sort(key=lambda x: find_idx_by_name(x.name, self.node_list)) |
|
|
|
prepose_nodes.sort(key=lambda x: find_idx_by_name(x.name, self.node_list)) |
|
|
|
chunk_info["args"]["prepose_nodes"] = prepose_nodes |
|
|
|
|
|
|
|
|
|
|
|
return prepose_nodes |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def flow_search(self, start_idx, start_dim, end_idx, end_dim): |
|
|
|
|
|
|
|
inputs, outputs = find_chunk_compute_input_and_output_nodes( |
|
|
|
|
|
|
|
self.node_list[start_idx : end_idx + 1] |
|
|
|
|
|
|
|
) |
|
|
|
|
|
|
|
# only single ouput |
|
|
|
|
|
|
|
if len(outputs) > 1: |
|
|
|
|
|
|
|
return None |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# get every node's chunk dim and fix dim |
|
|
|
|
|
|
|
all_node_info = self._get_all_node_info(end_dim, start_idx, end_idx) |
|
|
|
|
|
|
|
if all_node_info is None: |
|
|
|
|
|
|
|
return None |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# get input nodes' chunk dim |
|
|
|
|
|
|
|
inputs, inputs_dim = self._get_input_nodes_dim(inputs, start_idx, end_idx, all_node_info) |
|
|
|
|
|
|
|
if inputs is None: |
|
|
|
|
|
|
|
return None |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
chunk_info = { |
|
|
|
|
|
|
|
"region": (start_idx, end_idx), |
|
|
|
|
|
|
|
"inputs": inputs, |
|
|
|
|
|
|
|
"inputs_non_chunk": [], |
|
|
|
|
|
|
|
"inputs_dim": inputs_dim, |
|
|
|
|
|
|
|
"outputs": outputs, |
|
|
|
|
|
|
|
"outputs_dim": end_dim, |
|
|
|
|
|
|
|
"node_chunk_dim": all_node_info, |
|
|
|
|
|
|
|
"args": {}, |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# move useless nodes ahead of loop |
|
|
|
|
|
|
|
chunk_info["args"]["prepose_nodes"] = self._set_prepose_nodes(all_node_info, start_idx, end_idx) |
|
|
|
|
|
|
|
|
|
|
|
# we need to log input nodes to avoid deleteing them in the loop |
|
|
|
# we need to log input nodes to avoid deleteing them in the loop |
|
|
|
chunk_node_list = self.node_list[start_idx : end_idx + 1] |
|
|
|
chunk_node_list = self.node_list[start_idx : end_idx + 1] |
|
|
|
# also need to get some prepose node's arg out of non_chunk_inputs |
|
|
|
# also need to get some prepose node's arg out of non_chunk_inputs |
|
|
|
for n in prepose_nodes: |
|
|
|
for n in chunk_info["args"]["prepose_nodes"]: |
|
|
|
chunk_node_list.remove(n) |
|
|
|
chunk_node_list.remove(n) |
|
|
|
non_chunk_inputs = find_chunk_all_input_nodes(chunk_node_list) |
|
|
|
non_chunk_inputs = find_chunk_all_input_nodes(chunk_node_list) |
|
|
|
for i in non_chunk_inputs: |
|
|
|
for i in non_chunk_inputs: |
|
|
|