support check_index_duplicate

pull/2364/head
oahzxl 2022-12-13 10:01:30 +08:00
parent 8754fa2553
commit 1e0fd11bc1
1 changed files with 16 additions and 15 deletions

View File

@ -179,7 +179,12 @@ class FlowTracer(object):
"outputs_dim": end_dim,
"args": {},
}
flow_flag = False
flow_block = False
# TODO don't allow multi outputs now
if len(outputs) > 1:
flow_block = True
return flow_block, chunk_info
for idx in range(start_idx, end_idx + 1):
node = self.node_list[idx]
@ -199,10 +204,10 @@ class FlowTracer(object):
self.node_list[end_idx], end_dim, node
)
if mix_flow_node_dim is None:
flow_flag = True
flow_block = True
break
if _get_node_shape(mix_flow_node)[mix_flow_node_dim] == 1:
flow_flag = False
flow_block = False
for i in self._get_same_flow_node(
chunk_info["inputs"], mix_flow_node
):
@ -210,11 +215,15 @@ class FlowTracer(object):
# else, we need to chunk mix var as well
else:
# TODO chunk another value
flow_flag = True
flow_block = True
break
else:
raise NotImplementedError("%s not implemented" % node.name)
if flow_block:
flow_block = True
return flow_block, chunk_info
inputs_dim = []
remove_inputs = []
for input_node in chunk_info["inputs"]:
@ -250,7 +259,7 @@ class FlowTracer(object):
if i not in chunk_info["inputs"]:
chunk_info["inputs_non_chunk"].append(i)
return flow_flag, chunk_info
return flow_block, chunk_info
class IndexTracer(object):
@ -869,14 +878,6 @@ class IndexTracer(object):
if any(start_idx <= i <= end_idx for i in end_node_compute):
return False
return True
# end_node_trace_source = end_node_trace['source'][end_dim]
# for node_idx, node_dim in end_node_trace_source.items():
# if node_idx < start_node_idx or node_idx > end_node_idx:
# continue
# compute_list = self.idx_trace_list[node_idx]['compute'][node_dim]
# if any(start_node_idx <= i <= end_node_idx for i in compute_list):
# return False
# return True
def _get_node_chunk_dim(self, node_from, node_from_dim, node_to):
node_from_source = self._find_source_trace_from_node(node_from)
@ -1240,10 +1241,10 @@ class ChunkRegionSearch(object):
):
continue
# detect flow meet
flow_flag, chunk_info = self.flow_tracer._detect_flow(
flow_block, chunk_info = self.flow_tracer._detect_flow(
start_idx, start_dim, end_idx, end_dim, self.index_tracer
)
if flow_flag:
if flow_block:
continue
chunk_infos.append(chunk_info)
chunk_infos = self._check_duplicate_map(chunk_infos)