From 966e4ea0cbf1cd17696aa90b6b9bd4a6999cfba4 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Sat, 31 Dec 2022 02:20:07 +0800 Subject: [PATCH] add reorder in mem estimator --- chunk_codegen.py | 43 ++++++++++++++++++++++++++++++++----------- 1 file changed, 32 insertions(+), 11 deletions(-) diff --git a/chunk_codegen.py b/chunk_codegen.py index de58a61b9..e20d151da 100644 --- a/chunk_codegen.py +++ b/chunk_codegen.py @@ -1040,11 +1040,13 @@ class IndexTracer(object): chunk_info["region"][0] + len(chunk_info["args"]["prepose_nodes"]), chunk_info["region"][1], ) + new_inputs_dim = [] for idx, input_dim in enumerate(chunk_info["inputs_dim"]): new_input_dim = {} for k, v in input_dim.items(): new_input_dim[reorder_map[k]] = v - chunk_info["inputs_dim"][idx] = new_input_dim + new_inputs_dim.append(new_input_dim) + chunk_info["inputs_dim"] = new_inputs_dim return chunk_info def _update_all_reorder_map(self, reorder_map): @@ -1095,11 +1097,24 @@ class IndexTracer(object): for old_idx, new_idx in self.all_reorder_map.items(): new_node_list[new_idx] = node_list[old_idx] return new_node_list + + def tmp_reorder(self, node_list, chunk_info): + if len(chunk_info["args"]["prepose_nodes"]) == 0: + return node_list, chunk_info + reorder_map = self._get_reorder_map(chunk_info) + + # new tmp node list + new_node_list = [None for _ in range(len(node_list))] + for old_idx, new_idx in reorder_map.items(): + new_node_list[new_idx] = node_list[old_idx] + + chunk_info = self._reorder_chunk_info(chunk_info, reorder_map) + return new_node_list, chunk_info class MemoryEstimator(object): def __init__(self, index_tracer: IndexTracer) -> None: - self.index_tracer = index_tracer + pass def _get_meta_node_size(self, x): x = x.meta["tensor_meta"] @@ -1453,9 +1468,11 @@ class ChunkSelector(object): # get mem for chunk region regions_dict = [] for region in possible_chunk_regions: - cur_chunk_infos = chunk_infos + [region] + cur_region = region.copy() + cur_node_list, cur_region = self.index_tracer.tmp_reorder(self.index_tracer.node_list, cur_region) + cur_chunk_infos = chunk_infos + [cur_region] cur_mem_peak = self.memory_estimator.estimate_chunk_inference_mem( - self.index_tracer.node_list, cur_chunk_infos + cur_node_list, cur_chunk_infos )[0] cur_chunk_region_peak = cur_mem_peak[ max_chunk_region[0] : max_chunk_region[1] + 1 @@ -1492,9 +1509,11 @@ class ChunkSelector(object): while cur_chunk_max_mem < self.max_memory: chunk_size *= 2 chunk_info["chunk_size"] = chunk_size - cur_chunk_infos = chunk_infos + [chunk_info] + cur_chunk_info = chunk_info.copy() + cur_node_list, cur_chunk_info = self.index_tracer.tmp_reorder(self.index_tracer.node_list, cur_chunk_info) + cur_chunk_infos = chunk_infos + [cur_chunk_info] cur_mem_peak = self.memory_estimator.estimate_chunk_inference_mem( - self.index_tracer.node_list, cur_chunk_infos + cur_node_list, cur_chunk_infos )[0] cur_chunk_max_mem = max( cur_mem_peak[chunk_info["region"][0] : chunk_info["region"][1] + 1] @@ -1511,11 +1530,13 @@ class ChunkSelector(object): else: gap = 1 while r >= l + gap: - mid = int(l + (r - l) / 2) + mid = int((l + r) / 2 + 0.5) chunk_info["chunk_size"] = mid - cur_chunk_infos = chunk_infos + [chunk_info] + cur_chunk_info = chunk_info.copy() + cur_node_list, cur_chunk_info = self.index_tracer.tmp_reorder(self.index_tracer.node_list, cur_chunk_info) + cur_chunk_infos = chunk_infos + [cur_chunk_info] cur_mem_peak = self.memory_estimator.estimate_chunk_inference_mem( - self.index_tracer.node_list, cur_chunk_infos + cur_node_list, cur_chunk_infos )[0] cur_chunk_max_mem = max( cur_mem_peak[chunk_info["region"][0] : chunk_info["region"][1] + 1] @@ -1529,7 +1550,7 @@ class ChunkSelector(object): def _get_compute_node_num(self, start, end): count = 0 for i in self.index_tracer.node_list[start : end + 1]: - if _is_non_compute_node(i): + if not _is_non_compute_node(i): count += 1 return count @@ -1547,7 +1568,7 @@ class ChunkSelector(object): max_region_range = 0 best_region = None if best_region is not None: - best_region["chunk_size"] = 2 + best_region["chunk_size"] = 1 return best_region def _is_legal_region(self, cur_chunk_info, chunk_infos):