diff --git a/chunk_codegen.py b/chunk_codegen.py index 18d9a0c8d..5e2130ee7 100644 --- a/chunk_codegen.py +++ b/chunk_codegen.py @@ -958,6 +958,8 @@ class MemoryEstimator(object): def _add_active_node(self, n, active_list): new_active = self._get_output_node(n)[1] + if n.op == 'placeholder': + new_active.append(n.name) for i in new_active: if i not in active_list: active_list.append(i) @@ -965,7 +967,7 @@ class MemoryEstimator(object): def _get_delete_node(self, user, user_to_last_uses, to_keep=None): delete_size = 0 delete_node = [] - if user.op not in ("placeholder", "output"): + if user.op not in ("output",): nodes_to_delete = user_to_last_uses.get(user, []) if to_keep is not None: keep_list = [] @@ -1258,24 +1260,30 @@ class ChunkRegionSearch(object): def _search_max_chunk_region(self, active_node, peak_node, chunk_regions): free_vars = self._get_free_var() - min_var = self._get_min_free_var(active_node, free_vars) - + free_var_num = len(free_vars) + active_node_num = [len(i) for i in active_node] + min_active_node_num = min(active_node_num[free_var_num:]) + threshold = max(free_var_num, min_active_node_num) + # from peak_node to free_var - chunk_region_start = len(free_vars) + inside_flag = False + chunk_region_start = free_var_num for i in range(peak_node, -1, -1): - if len(active_node[i]) == min_var: + if active_node_num[i] <= threshold: + inside_flag = True + if inside_flag and active_node_num[i] > threshold: chunk_region_start = i + 1 break - if i in free_vars or i == 0: - raise RuntimeError() + # from peak_node to len-2 + inside_flag = False chunk_region_end = len(active_node) - 1 for i in range(peak_node, len(active_node)): - if len(active_node[i]) == min_var: + if active_node_num[i] <= threshold: + inside_flag = True + if inside_flag and active_node_num[i] > threshold: chunk_region_end = i break - if i in free_vars or i == 0: - raise RuntimeError() for i in chunk_regions: region = i["region"] @@ -1374,15 +1382,34 @@ class ChunkRegionSearch(object): possible_chunk_region.extend(chunk_info) return possible_chunk_region - def _search_best_chunk_region(self, possible_chunk_regions): + def _search_best_chunk_region(self, possible_chunk_regions, chunk_infos): max_region_range = 0 - best_regions = None - for i in possible_chunk_regions: - if i["region"][1] - i["region"][0] > max_region_range: - best_regions = i - max_region_range = i["region"][1] - i["region"][0] - return best_regions - + best_region = None + while len(possible_chunk_regions) > 0: + for i in possible_chunk_regions: + if i["region"][1] - i["region"][0] > max_region_range: + best_region = i + max_region_range = i["region"][1] - i["region"][0] + if self._is_legal_region(best_region, chunk_infos): + break + possible_chunk_regions.remove(i) + max_region_range = 0 + best_region = None + return best_region + + def _is_legal_region(self, cur_chunk_info, chunk_infos): + (chunk_region_start, chunk_region_end) = cur_chunk_info["region"] + if cur_chunk_info in chunk_infos: + return False + if chunk_region_end < chunk_region_start: + return False + for i in chunk_infos: + region = i["region"] + if not ((chunk_region_start > region[1] and chunk_region_end > region[1]) + or (chunk_region_start < region[0] and chunk_region_end < region[0])): + return False + return True + def _step_search(self, mem_peak, active_node, chunk_regions): peak_node = self._find_peak_node(mem_peak) max_chunk_region = self._search_max_chunk_region( @@ -1393,7 +1420,7 @@ class ChunkRegionSearch(object): possible_chunk_regions = self._search_possible_chunk_regions( max_chunk_region, peak_node ) - best_chunk_region = self._search_best_chunk_region(possible_chunk_regions) + best_chunk_region = self._search_best_chunk_region(possible_chunk_regions, chunk_regions) return best_chunk_region def _stop_search(self, init_mem_peak, mem_peak): @@ -1919,5 +1946,5 @@ if CODEGEN_AVAILABLE: {prologue} {code}""" - print(fn_code) + # print(fn_code) return PythonCode(fn_code, globals_)