diff --git a/chunk_codegen.py b/chunk_codegen.py index f87a3a132..cdd0b1077 100644 --- a/chunk_codegen.py +++ b/chunk_codegen.py @@ -1368,12 +1368,60 @@ class MemoryEstimator(object): return act_memory_peak_log, act_memory_after_node_log, active_node_list_log +class ChunkSelector(object): + def __init__(self, index_tracer: IndexTracer, stratge) -> None: + self.index_tracer = index_tracer + assert stratge in ['min_memory', 'fit_memory'] + self.stratge = stratge + self.max_memory = 800 # MB + + def _select_best_chunk_region(self, possible_chunk_regions, chunk_infos): + if self.stratge == 'min_memory': + best_region = self._select_min_memory_chunk_region(possible_chunk_regions, chunk_infos) + elif self.stratge == 'fit_memory': + pass + else: + raise RuntimeError() + return best_region + + def _select_min_memory_chunk_region(self, possible_chunk_regions, chunk_infos): + max_region_range = 0 + best_region = None + while len(possible_chunk_regions) > 0: + for i in possible_chunk_regions: + if i["region"][1] - i["region"][0] > max_region_range: + best_region = i + max_region_range = i["region"][1] - i["region"][0] + if self._is_legal_region(best_region, chunk_infos): + break + possible_chunk_regions.remove(i) + max_region_range = 0 + best_region = None + return best_region + + def _is_legal_region(self, cur_chunk_info, chunk_infos): + (chunk_region_start, chunk_region_end) = cur_chunk_info["region"] + if cur_chunk_info in chunk_infos: + return False + if chunk_region_end < chunk_region_start: + return False + for i in chunk_infos: + region = i["region"] + if not ( + (chunk_region_start > region[1] and chunk_region_end > region[1]) + or (chunk_region_start < region[0] and chunk_region_end < region[0]) + ): + return False + return True + + class ChunkRegionSearch(object): def __init__(self, gm) -> None: self.gm = gm self.index_tracer = IndexTracer(list(gm.graph.nodes)) self.index_tracer.trace_index() self.memory_estimator = MemoryEstimator(self.index_tracer) + self.chunk_selector = ChunkSelector(self.index_tracer, stratge="min_memory") def _find_peak_node(self, mem_peak): max_value = max(mem_peak) @@ -1516,36 +1564,6 @@ class ChunkRegionSearch(object): possible_chunk_region.extend(chunk_info) return possible_chunk_region - def _search_best_chunk_region(self, possible_chunk_regions, chunk_infos): - max_region_range = 0 - best_region = None - while len(possible_chunk_regions) > 0: - for i in possible_chunk_regions: - if i["region"][1] - i["region"][0] > max_region_range: - best_region = i - max_region_range = i["region"][1] - i["region"][0] - if self._is_legal_region(best_region, chunk_infos): - break - possible_chunk_regions.remove(i) - max_region_range = 0 - best_region = None - return best_region - - def _is_legal_region(self, cur_chunk_info, chunk_infos): - (chunk_region_start, chunk_region_end) = cur_chunk_info["region"] - if cur_chunk_info in chunk_infos: - return False - if chunk_region_end < chunk_region_start: - return False - for i in chunk_infos: - region = i["region"] - if not ( - (chunk_region_start > region[1] and chunk_region_end > region[1]) - or (chunk_region_start < region[0] and chunk_region_end < region[0]) - ): - return False - return True - def _step_search(self, mem_peak, active_node, chunk_regions): peak_node = self._find_peak_node(mem_peak) max_chunk_region = self._search_max_chunk_region( @@ -1556,7 +1574,7 @@ class ChunkRegionSearch(object): possible_chunk_regions = self._search_possible_chunk_regions( max_chunk_region, peak_node ) - best_chunk_region = self._search_best_chunk_region( + best_chunk_region = self.chunk_selector._select_best_chunk_region( possible_chunk_regions, chunk_regions ) best_chunk_region = self.index_tracer.reorder_all(best_chunk_region)