From f5515e9978564bddc0ff97c06c7a6933668e7cef Mon Sep 17 00:00:00 2001 From: oahzxl Date: Thu, 29 Dec 2022 16:55:47 +0800 Subject: [PATCH] use max_mem to control stratge --- chunk_codegen.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/chunk_codegen.py b/chunk_codegen.py index 033db50db..1c8be65d4 100644 --- a/chunk_codegen.py +++ b/chunk_codegen.py @@ -1398,14 +1398,18 @@ class MemoryEstimator(object): class ChunkSelector(object): def __init__( - self, index_tracer: IndexTracer, memory_estimator: MemoryEstimator, stratge, max_memory=None + self, + index_tracer: IndexTracer, + memory_estimator: MemoryEstimator, + max_memory=None, ): self.index_tracer = index_tracer self.memory_estimator = memory_estimator - assert stratge in ["min_memory", "fit_memory"] - assert (stratge == "fit_memory" and max_memory is not None) or stratge != "fit_memory" - self.stratge = stratge - self.max_memory = max_memory # MB + if max_memory is not None: + self.stratge = "fit_memory" + self.max_memory = max_memory # MB + else: + self.stratge = "min_memory" def _select_best_chunk_region( self, possible_chunk_regions, chunk_infos, peak_node, max_chunk_region, mem_peak @@ -1538,6 +1542,8 @@ class ChunkSelector(object): possible_chunk_regions.remove(i) max_region_range = 0 best_region = None + if best_region is not None: + best_region["chunk_size"] = 2 return best_region def _is_legal_region(self, cur_chunk_info, chunk_infos): @@ -1563,7 +1569,7 @@ class ChunkRegionSearch(object): self.index_tracer.trace_index() self.memory_estimator = MemoryEstimator(self.index_tracer) self.chunk_selector = ChunkSelector( - self.index_tracer, self.memory_estimator, stratge="fit_memory", max_memory=max_memory + self.index_tracer, self.memory_estimator, max_memory=max_memory ) def _find_peak_node(self, mem_peak): @@ -2233,7 +2239,7 @@ if CODEGEN_AVAILABLE: delete_unused_values, self.meta_node, self.meta_graph, - self.max_memory + self.max_memory, ) if len(body) == 0: