|
|
|
@ -17,7 +17,6 @@ class IndexTracer(object):
|
|
|
|
|
self.idx_trace_equal = [] |
|
|
|
|
self.idx_view_list = {} |
|
|
|
|
self.idx_count = -1 |
|
|
|
|
self.all_reorder_map = {i: i for i in range(len(self.idx_trace_list))} |
|
|
|
|
|
|
|
|
|
def _init_idx_trace_list(self): |
|
|
|
|
idx_trace_list = [] |
|
|
|
@ -981,24 +980,30 @@ class IndexTracer(object):
|
|
|
|
|
chunk_info["reshape_size"] = reshape_size |
|
|
|
|
return chunk_info |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ReorderGraph(object): |
|
|
|
|
def __init__(self, index_tracer: IndexTracer) -> None: |
|
|
|
|
self.index_tracer = index_tracer |
|
|
|
|
self.all_reorder_map = {i: i for i in range(len(self.index_tracer.idx_trace_list))} |
|
|
|
|
|
|
|
|
|
def _get_reorder_map(self, chunk_info): |
|
|
|
|
reorder_map = {i: i for i in range(len(self.node_list))} |
|
|
|
|
reorder_map = {i: i for i in range(len(self.index_tracer.node_list))} |
|
|
|
|
|
|
|
|
|
chunk_region_start = chunk_info["region"][0] |
|
|
|
|
chunk_region_end = chunk_info["region"][1] |
|
|
|
|
chunk_prepose_nodes = chunk_info["args"]["prepose_nodes"] |
|
|
|
|
chunk_prepose_nodes_idx = [ |
|
|
|
|
find_idx_by_name(i.name, self.node_list) for i in chunk_prepose_nodes |
|
|
|
|
find_idx_by_name(i.name, self.index_tracer.node_list) for i in chunk_prepose_nodes |
|
|
|
|
] |
|
|
|
|
# put prepose nodes ahead |
|
|
|
|
for idx, n in enumerate(chunk_prepose_nodes): |
|
|
|
|
n_idx = chunk_prepose_nodes_idx[idx] |
|
|
|
|
reorder_map[n_idx] = chunk_region_start + idx |
|
|
|
|
# put other nodes after prepose nodes |
|
|
|
|
for n in self.node_list[chunk_region_start : chunk_region_end + 1]: |
|
|
|
|
for n in self.index_tracer.node_list[chunk_region_start : chunk_region_end + 1]: |
|
|
|
|
if n in chunk_prepose_nodes: |
|
|
|
|
continue |
|
|
|
|
n_idx = find_idx_by_name(n.name, self.node_list) |
|
|
|
|
n_idx = find_idx_by_name(n.name, self.index_tracer.node_list) |
|
|
|
|
pos = sum([n_idx < i for i in chunk_prepose_nodes_idx]) |
|
|
|
|
reorder_map[n_idx] = n_idx + pos |
|
|
|
|
|
|
|
|
@ -1024,25 +1029,25 @@ class IndexTracer(object):
|
|
|
|
|
self.all_reorder_map[origin_idx] = reorder_map[map_idx] |
|
|
|
|
|
|
|
|
|
def _reorder_self_node_list(self, reorder_map): |
|
|
|
|
new_node_list = [None for _ in range(len(self.node_list))] |
|
|
|
|
new_node_list = [None for _ in range(len(self.index_tracer.node_list))] |
|
|
|
|
for old_idx, new_idx in reorder_map.items(): |
|
|
|
|
new_node_list[new_idx] = self.node_list[old_idx] |
|
|
|
|
self.node_list = new_node_list |
|
|
|
|
new_node_list[new_idx] = self.index_tracer.node_list[old_idx] |
|
|
|
|
self.index_tracer.node_list = new_node_list |
|
|
|
|
|
|
|
|
|
def _reorder_idx_trace(self, reorder_map): |
|
|
|
|
# reorder list |
|
|
|
|
new_idx_trace_list = [None for _ in range(len(self.idx_trace_list))] |
|
|
|
|
new_idx_trace_list = [None for _ in range(len(self.index_tracer.idx_trace_list))] |
|
|
|
|
for old_idx, new_idx in reorder_map.items(): |
|
|
|
|
new_idx_trace_list[new_idx] = self.idx_trace_list[old_idx] |
|
|
|
|
self.idx_trace_list = new_idx_trace_list |
|
|
|
|
new_idx_trace_list[new_idx] = self.index_tracer.idx_trace_list[old_idx] |
|
|
|
|
self.index_tracer.idx_trace_list = new_idx_trace_list |
|
|
|
|
# update compute |
|
|
|
|
for idx_trace in self.idx_trace_list: |
|
|
|
|
for idx_trace in self.index_tracer.idx_trace_list: |
|
|
|
|
compute = idx_trace["compute"] |
|
|
|
|
for dim_compute in compute: |
|
|
|
|
for idx, i in enumerate(dim_compute): |
|
|
|
|
dim_compute[idx] = reorder_map[i] |
|
|
|
|
# update source |
|
|
|
|
for idx_trace in self.idx_trace_list: |
|
|
|
|
for idx_trace in self.index_tracer.idx_trace_list: |
|
|
|
|
source = idx_trace["source"] |
|
|
|
|
for dim_idx, dim_source in enumerate(source): |
|
|
|
|
new_dim_source = {} |
|
|
|
|