seperate flow tracer

pull/2364/head
oahzxl 2023-01-06 14:57:33 +08:00
parent fd87d78a28
commit ae27a8b26d
1 changed files with 15 additions and 9 deletions

View File

@ -745,14 +745,7 @@ class IndexTracer(object):
next_node_list.append(arg_node)
return True
def flow_search(self, start_idx, start_dim, end_idx, end_dim):
inputs, outputs = find_chunk_compute_input_and_output_nodes(
self.node_list[start_idx : end_idx + 1]
)
# only single ouput
if len(outputs) > 1:
return None
def _get_all_node_info(self, end_dim, start_idx, end_idx):
cur_node_list = [self.node_list[end_idx]] # start from the last node
all_node_info = {cur_node_list[0]: {"chunk_dim": end_dim, "fix_dim": []}}
@ -763,7 +756,6 @@ class IndexTracer(object):
# get cur node info
cur_node_chunk_dim = all_node_info[cur_node]["chunk_dim"]
cur_node_fix_dim = all_node_info[cur_node]["fix_dim"]
cur_node_idx = find_idx_by_name(cur_node.name, self.node_list)
if cur_node_chunk_dim:
cur_node_compute = self._find_compute_trace_from_node(cur_node)
cur_node_source = self._find_source_trace_from_node(cur_node)
@ -818,6 +810,20 @@ class IndexTracer(object):
else:
raise NotImplementedError()
cur_node_list = next_node_list
return all_node_info
def flow_search(self, start_idx, start_dim, end_idx, end_dim):
inputs, outputs = find_chunk_compute_input_and_output_nodes(
self.node_list[start_idx : end_idx + 1]
)
# only single ouput
if len(outputs) > 1:
return None
# get every node's chunk dim and fix dim
all_node_info = self._get_all_node_info(end_dim, start_idx, end_idx)
if all_node_info is None:
return None
inputs_dim = []
remove_inputs = []