From 884a228ea674b02998575776b0069b15de0b7a10 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Fri, 23 Dec 2022 17:06:07 +0800 Subject: [PATCH] reorder nodes --- chunk_codegen.py | 127 +++++++++++++++++++++++++++++++++++++---------- 1 file changed, 101 insertions(+), 26 deletions(-) diff --git a/chunk_codegen.py b/chunk_codegen.py index 6e772aa8a..4b3b04d93 100644 --- a/chunk_codegen.py +++ b/chunk_codegen.py @@ -71,6 +71,7 @@ class IndexTracer(object): self.idx_trace_equal = [] self.idx_view_list = [] self.idx_count = -1 + self.all_reorder_map = {i: i for i in range(len(self.idx_trace_list))} def _init_idx_trace_list(self): idx_trace_list = [] @@ -973,6 +974,91 @@ class IndexTracer(object): return chunk_info + def _get_reorder_map(self, chunk_info): + reorder_map = {i: i for i in range(len(self.node_list))} + + chunk_region_start = chunk_info["region"][0] + chunk_region_end = chunk_info["region"][1] + chunk_prepose_nodes = chunk_info["args"]["prepose_nodes"] + chunk_prepose_nodes_idx = [ + _find_idx_by_name(i.name, self.node_list) for i in chunk_prepose_nodes + ] + # put prepose nodes ahead + for idx, n in enumerate(chunk_prepose_nodes): + n_idx = chunk_prepose_nodes_idx[idx] + reorder_map[n_idx] = chunk_region_start + idx + # put other nodes after prepose nodes + for n in self.node_list[chunk_region_start : chunk_region_end + 1]: + if n in chunk_prepose_nodes: + continue + n_idx = _find_idx_by_name(n.name, self.node_list) + pos = sum([n_idx < i for i in chunk_prepose_nodes_idx]) + reorder_map[n_idx] = n_idx + pos + + return reorder_map + + def _reorder_chunk_info(self, chunk_info, reorder_map): + # update chunk info + chunk_info["region"] = ( + chunk_info["region"][0] + len(chunk_info["args"]["prepose_nodes"]), + chunk_info["region"][1], + ) + for idx, input_dim in enumerate(chunk_info["inputs_dim"]): + new_input_dim = {} + for k, v in input_dim.items(): + new_input_dim[reorder_map[k]] = v + chunk_info["inputs_dim"][idx] = new_input_dim + return chunk_info + + def _update_all_reorder_map(self, reorder_map): + for origin_idx, map_idx in self.all_reorder_map.items(): + self.all_reorder_map[origin_idx] = reorder_map[map_idx] + + def _reorder_self_node_list(self, reorder_map): + new_node_list = [None for _ in range(len(self.node_list))] + for old_idx, new_idx in reorder_map.items(): + new_node_list[new_idx] = self.node_list[old_idx] + self.node_list = new_node_list + + def _reorder_idx_trace(self, reorder_map): + # reorder list + new_idx_trace_list = [None for _ in range(len(self.idx_trace_list))] + for old_idx, new_idx in reorder_map.items(): + new_idx_trace_list[new_idx] = self.idx_trace_list[old_idx] + self.idx_trace_list = new_idx_trace_list + # update compute + for idx_trace in self.idx_trace_list: + compute = idx_trace["compute"] + for dim_compute in compute: + for idx, i in enumerate(dim_compute): + dim_compute[idx] = reorder_map[i] + # update source + for idx_trace in self.idx_trace_list: + source = idx_trace["source"] + for dim_idx, dim_source in enumerate(source): + new_dim_source = {} + for k, v in dim_source.items(): + new_dim_source[reorder_map[k]] = v + source[dim_idx] = new_dim_source + + def reorder_all(self, chunk_info): + if chunk_info is None: + return chunk_info + if len(chunk_info["args"]["prepose_nodes"]) == 0: + return chunk_info + reorder_map = self._get_reorder_map(chunk_info) + self._update_all_reorder_map(reorder_map) + self._reorder_idx_trace(reorder_map) + self._reorder_self_node_list(reorder_map) + chunk_info = self._reorder_chunk_info(chunk_info, reorder_map) + return chunk_info + + def reorder_node_list(self, node_list): + new_node_list = [None for _ in range(len(node_list))] + for old_idx, new_idx in self.all_reorder_map.items(): + new_node_list[new_idx] = node_list[old_idx] + return new_node_list + class MemoryEstimator(object): def __init__(self, index_tracer: IndexTracer) -> None: @@ -1476,6 +1562,7 @@ class ChunkRegionSearch(object): best_chunk_region = self._search_best_chunk_region( possible_chunk_regions, chunk_regions ) + best_chunk_region = self.index_tracer.reorder_all(best_chunk_region) return best_chunk_region def _stop_search(self, init_mem_peak, mem_peak): @@ -1670,8 +1757,7 @@ def emit_code_with_chunk( chunk_outputs = [i["outputs"][0] for i in chunk_search] chunk_outputs_dim = [i["outputs_dim"] for i in chunk_search] - chunk_prepose_nodes = [i["args"]["prepose_nodes"] for i in chunk_search] - + node_list = chunk_region_search.index_tracer.reorder_node_list(node_list) node_idx = 0 region_idx = 0 within_chunk_region = False @@ -1682,12 +1768,6 @@ def emit_code_with_chunk( if node_idx in chunk_starts: within_chunk_region = True region_idx = chunk_starts.index(node_idx) - # add prepose nodes - for i in chunk_prepose_nodes[region_idx]: - prepose_node = node_list[_find_idx_by_name(i.name, node_list)] - emit_node_func(prepose_node, body) - delete_unused_value_func(prepose_node, body, chunk_inputs_names) - # add for loop body.append( _gen_loop_start( chunk_inputs[region_idx], @@ -1697,24 +1777,19 @@ def emit_code_with_chunk( ) if within_chunk_region: - if any(node.name == i.name for i in chunk_prepose_nodes[region_idx]): - pass - else: - emit_node_func(node, body) - # replace input var with chunk var - for input_node_idx, input_node in enumerate(chunk_inputs[region_idx]): - for idx, dim in chunk_inputs_dim[region_idx][ - input_node_idx - ].items(): - if idx == node_idx: - chunk_slice = _gen_chunk_slice_dim( - dim, "chunk_idx", _get_node_shape(input_node) - ) - body[-1] = _replace_name( - body[-1], input_node.name, input_node.name + chunk_slice - ) - body[-1] = " " + body[-1] - delete_unused_value_func(node, body, chunk_inputs_names) + emit_node_func(node, body) + # replace input var with chunk var + for input_node_idx, input_node in enumerate(chunk_inputs[region_idx]): + for idx, dim in chunk_inputs_dim[region_idx][input_node_idx].items(): + if idx == node_idx: + chunk_slice = _gen_chunk_slice_dim( + dim, "chunk_idx", _get_node_shape(input_node) + ) + body[-1] = _replace_name( + body[-1], input_node.name, input_node.name + chunk_slice + ) + body[-1] = " " + body[-1] + delete_unused_value_func(node, body, chunk_inputs_names) else: emit_node_func(node, body) if node_idx not in chunk_inputs: