diff --git a/chunk_codegen.py b/chunk_codegen.py index 21ecc343a..41fcb5a3c 100644 --- a/chunk_codegen.py +++ b/chunk_codegen.py @@ -1433,7 +1433,11 @@ class ChunkSelector(object): ): if self.stratge == "min_memory": best_region = self._select_min_memory_chunk_region( - possible_chunk_regions, chunk_infos + possible_chunk_regions, + chunk_infos, + peak_node, + max_chunk_region, + mem_peak, ) elif self.stratge == "fit_memory": best_region = self._select_fit_memory_chunk_region( @@ -1561,19 +1565,52 @@ class ChunkSelector(object): count += 1 return count - def _select_min_memory_chunk_region(self, possible_chunk_regions, chunk_infos): - max_region_range = 0 - best_region = None - while len(possible_chunk_regions) > 0: - for i in possible_chunk_regions: - if i["region"][1] - i["region"][0] > max_region_range: - best_region = i - max_region_range = i["region"][1] - i["region"][0] - if self._is_legal_region(best_region, chunk_infos): - break - possible_chunk_regions.remove(i) - max_region_range = 0 - best_region = None + def _select_min_memory_chunk_region( + self, possible_chunk_regions, chunk_infos, peak_node, max_chunk_region, mem_peak + ): + # remove illegal regions + illegal_regions = [] + for i in possible_chunk_regions: + if not self._is_legal_region(i, chunk_infos): + illegal_regions.append(i) + for i in illegal_regions: + if i in possible_chunk_regions: + possible_chunk_regions.remove(i) + + if len(possible_chunk_regions) == 0: + return None + + # get mem for chunk region + regions_dict = [] + for region in possible_chunk_regions: + cur_region = region.copy() + cur_node_list, cur_region = self.index_tracer.tmp_reorder( + self.index_tracer.node_list, cur_region + ) + cur_chunk_infos = chunk_infos + [cur_region] + cur_mem_peak = self.memory_estimator.estimate_chunk_inference_mem( + cur_node_list, cur_chunk_infos + )[0] + cur_chunk_region_peak = cur_mem_peak[ + max_chunk_region[0] : max_chunk_region[1] + 1 + ] + cur_chunk_region_max_peak = max(cur_chunk_region_peak) + regions_dict.append( + { + "chunk_info": region, + "chunk_max_mem": cur_chunk_region_max_peak, + "chunk_len": self._get_compute_node_num( + region["region"][0], region["region"][1] + ), + "reorder_chunk_info": cur_region, + "reorder_node_list": cur_node_list, + } + ) + + # select the min mem + chunk_max_mem = [i["chunk_max_mem"] for i in regions_dict] + best_region_idx = chunk_max_mem.index(min(chunk_max_mem)) + best_region = regions_dict[best_region_idx]["chunk_info"] if best_region is not None: best_region["chunk_size"] = 1 return best_region