Browse Source

add chunk select class

pull/2364/head
oahzxl 2 years ago
parent
commit
1b8a066592
  1. 80
      chunk_codegen.py

80
chunk_codegen.py

@ -1368,12 +1368,60 @@ class MemoryEstimator(object):
return act_memory_peak_log, act_memory_after_node_log, active_node_list_log
class ChunkSelector(object):
def __init__(self, index_tracer: IndexTracer, stratge) -> None:
self.index_tracer = index_tracer
assert stratge in ['min_memory', 'fit_memory']
self.stratge = stratge
self.max_memory = 800 # MB
def _select_best_chunk_region(self, possible_chunk_regions, chunk_infos):
if self.stratge == 'min_memory':
best_region = self._select_min_memory_chunk_region(possible_chunk_regions, chunk_infos)
elif self.stratge == 'fit_memory':
pass
else:
raise RuntimeError()
return best_region
def _select_min_memory_chunk_region(self, possible_chunk_regions, chunk_infos):
max_region_range = 0
best_region = None
while len(possible_chunk_regions) > 0:
for i in possible_chunk_regions:
if i["region"][1] - i["region"][0] > max_region_range:
best_region = i
max_region_range = i["region"][1] - i["region"][0]
if self._is_legal_region(best_region, chunk_infos):
break
possible_chunk_regions.remove(i)
max_region_range = 0
best_region = None
return best_region
def _is_legal_region(self, cur_chunk_info, chunk_infos):
(chunk_region_start, chunk_region_end) = cur_chunk_info["region"]
if cur_chunk_info in chunk_infos:
return False
if chunk_region_end < chunk_region_start:
return False
for i in chunk_infos:
region = i["region"]
if not (
(chunk_region_start > region[1] and chunk_region_end > region[1])
or (chunk_region_start < region[0] and chunk_region_end < region[0])
):
return False
return True
class ChunkRegionSearch(object):
def __init__(self, gm) -> None:
self.gm = gm
self.index_tracer = IndexTracer(list(gm.graph.nodes))
self.index_tracer.trace_index()
self.memory_estimator = MemoryEstimator(self.index_tracer)
self.chunk_selector = ChunkSelector(self.index_tracer, stratge="min_memory")
def _find_peak_node(self, mem_peak):
max_value = max(mem_peak)
@ -1516,36 +1564,6 @@ class ChunkRegionSearch(object):
possible_chunk_region.extend(chunk_info)
return possible_chunk_region
def _search_best_chunk_region(self, possible_chunk_regions, chunk_infos):
max_region_range = 0
best_region = None
while len(possible_chunk_regions) > 0:
for i in possible_chunk_regions:
if i["region"][1] - i["region"][0] > max_region_range:
best_region = i
max_region_range = i["region"][1] - i["region"][0]
if self._is_legal_region(best_region, chunk_infos):
break
possible_chunk_regions.remove(i)
max_region_range = 0
best_region = None
return best_region
def _is_legal_region(self, cur_chunk_info, chunk_infos):
(chunk_region_start, chunk_region_end) = cur_chunk_info["region"]
if cur_chunk_info in chunk_infos:
return False
if chunk_region_end < chunk_region_start:
return False
for i in chunk_infos:
region = i["region"]
if not (
(chunk_region_start > region[1] and chunk_region_end > region[1])
or (chunk_region_start < region[0] and chunk_region_end < region[0])
):
return False
return True
def _step_search(self, mem_peak, active_node, chunk_regions):
peak_node = self._find_peak_node(mem_peak)
max_chunk_region = self._search_max_chunk_region(
@ -1556,7 +1574,7 @@ class ChunkRegionSearch(object):
possible_chunk_regions = self._search_possible_chunk_regions(
max_chunk_region, peak_node
)
best_chunk_region = self._search_best_chunk_region(
best_chunk_region = self.chunk_selector._select_best_chunk_region(
possible_chunk_regions, chunk_regions
)
best_chunk_region = self.index_tracer.reorder_all(best_chunk_region)

Loading…
Cancel
Save