From 80efd70c725b00c236b80b68393c0d13ec457b0b Mon Sep 17 00:00:00 2001 From: oahzxl Date: Sat, 31 Dec 2022 13:44:46 +0800 Subject: [PATCH] improve reorder efficiency --- chunk_codegen.py | 33 +++++++++++++++++---------------- 1 file changed, 17 insertions(+), 16 deletions(-) diff --git a/chunk_codegen.py b/chunk_codegen.py index e20d151da..7c334c617 100644 --- a/chunk_codegen.py +++ b/chunk_codegen.py @@ -1486,6 +1486,8 @@ class ChunkSelector(object): "chunk_len": self._get_compute_node_num( region["region"][0], region["region"][1] ), + "reorder_chunk_info": cur_region, + "reorder_node_list": cur_node_list } ) # no region found @@ -1495,48 +1497,47 @@ class ChunkSelector(object): # select the min chunk len chunk_len = [i["chunk_len"] for i in regions_dict] best_region_idx = chunk_len.index(min(chunk_len)) - best_region = regions_dict[best_region_idx]["chunk_info"] + best_region = regions_dict[best_region_idx] # get max chunk size best_region = self._get_fit_chunk_size(best_region, chunk_infos) return best_region - def _get_fit_chunk_size(self, chunk_info, chunk_infos): + def _get_fit_chunk_size(self, chunk_region_dict, chunk_infos): chunk_size = 1 - chunk_info["chunk_size"] = chunk_size + reorder_chunk_info = chunk_region_dict['reorder_chunk_info'] + reorder_chunk_info["chunk_size"] = chunk_size cur_chunk_max_mem = 0 # search a region while cur_chunk_max_mem < self.max_memory: chunk_size *= 2 - chunk_info["chunk_size"] = chunk_size - cur_chunk_info = chunk_info.copy() - cur_node_list, cur_chunk_info = self.index_tracer.tmp_reorder(self.index_tracer.node_list, cur_chunk_info) - cur_chunk_infos = chunk_infos + [cur_chunk_info] + reorder_chunk_info["chunk_size"] = chunk_size + cur_chunk_infos = chunk_infos + [reorder_chunk_info] cur_mem_peak = self.memory_estimator.estimate_chunk_inference_mem( - cur_node_list, cur_chunk_infos + chunk_region_dict['reorder_node_list'], cur_chunk_infos )[0] cur_chunk_max_mem = max( - cur_mem_peak[chunk_info["region"][0] : 
chunk_info["region"][1] + 1] + cur_mem_peak[reorder_chunk_info["region"][0] : reorder_chunk_info["region"][1] + 1] ) # search exact size + chunk_info = chunk_region_dict["chunk_info"] chunk_info["chunk_size"] = self._chunk_size_binary_search( - chunk_size // 2, chunk_size, chunk_info, chunk_infos + chunk_size // 2, chunk_size, chunk_region_dict, chunk_infos ) return chunk_info - def _chunk_size_binary_search(self, l, r, chunk_info, chunk_infos): + def _chunk_size_binary_search(self, l, r, chunk_region_dict, chunk_infos): if l >= 16: gap = 4 else: gap = 1 + chunk_info = chunk_region_dict['reorder_chunk_info'] while r >= l + gap: mid = int((l + r) / 2 + 0.5) chunk_info["chunk_size"] = mid - cur_chunk_info = chunk_info.copy() - cur_node_list, cur_chunk_info = self.index_tracer.tmp_reorder(self.index_tracer.node_list, cur_chunk_info) - cur_chunk_infos = chunk_infos + [cur_chunk_info] + cur_chunk_infos = chunk_infos + [chunk_info] cur_mem_peak = self.memory_estimator.estimate_chunk_inference_mem( - cur_node_list, cur_chunk_infos + chunk_region_dict['reorder_node_list'], cur_chunk_infos )[0] cur_chunk_max_mem = max( cur_mem_peak[chunk_info["region"][0] : chunk_info["region"][1] + 1] @@ -1904,7 +1905,7 @@ def _find_idx_by_name(name, nodes_list): def _replace_name(context, name_from, name_to): - patterns = [(" ", " "), (" ", "."), (" ", ","), ("(", ")"), ("(", ",")] + patterns = [(" ", " "), (" ", "."), (" ", ","), ("(", ")"), ("(", ","), (" ", ")")] for p in patterns: source = p[0] + name_from + p[1] target = p[0] + name_to + p[1]