optimise search

pull/2364/head
oahzxl 2022-12-16 15:06:39 +08:00
parent e83e3c6154
commit e66a18a0bf
1 changed files with 47 additions and 20 deletions

View File

@ -958,6 +958,8 @@ class MemoryEstimator(object):
def _add_active_node(self, n, active_list):
new_active = self._get_output_node(n)[1]
if n.op == 'placeholder':
new_active.append(n.name)
for i in new_active:
if i not in active_list:
active_list.append(i)
@ -965,7 +967,7 @@ class MemoryEstimator(object):
def _get_delete_node(self, user, user_to_last_uses, to_keep=None):
delete_size = 0
delete_node = []
if user.op not in ("placeholder", "output"):
if user.op not in ("output",):
nodes_to_delete = user_to_last_uses.get(user, [])
if to_keep is not None:
keep_list = []
@ -1258,24 +1260,30 @@ class ChunkRegionSearch(object):
def _search_max_chunk_region(self, active_node, peak_node, chunk_regions):
free_vars = self._get_free_var()
min_var = self._get_min_free_var(active_node, free_vars)
free_var_num = len(free_vars)
active_node_num = [len(i) for i in active_node]
min_active_node_num = min(active_node_num[free_var_num:])
threshold = max(free_var_num, min_active_node_num)
# from peak_node to free_var
chunk_region_start = len(free_vars)
inside_flag = False
chunk_region_start = free_var_num
for i in range(peak_node, -1, -1):
if len(active_node[i]) == min_var:
if active_node_num[i] <= threshold:
inside_flag = True
if inside_flag and active_node_num[i] > threshold:
chunk_region_start = i + 1
break
if i in free_vars or i == 0:
raise RuntimeError()
# from peak_node to len-2
inside_flag = False
chunk_region_end = len(active_node) - 1
for i in range(peak_node, len(active_node)):
if len(active_node[i]) == min_var:
if active_node_num[i] <= threshold:
inside_flag = True
if inside_flag and active_node_num[i] > threshold:
chunk_region_end = i
break
if i in free_vars or i == 0:
raise RuntimeError()
for i in chunk_regions:
region = i["region"]
@ -1374,15 +1382,34 @@ class ChunkRegionSearch(object):
possible_chunk_region.extend(chunk_info)
return possible_chunk_region
def _search_best_chunk_region(self, possible_chunk_regions):
def _search_best_chunk_region(self, possible_chunk_regions, chunk_infos):
max_region_range = 0
best_regions = None
for i in possible_chunk_regions:
if i["region"][1] - i["region"][0] > max_region_range:
best_regions = i
max_region_range = i["region"][1] - i["region"][0]
return best_regions
best_region = None
while len(possible_chunk_regions) > 0:
for i in possible_chunk_regions:
if i["region"][1] - i["region"][0] > max_region_range:
best_region = i
max_region_range = i["region"][1] - i["region"][0]
if self._is_legal_region(best_region, chunk_infos):
break
possible_chunk_regions.remove(i)
max_region_range = 0
best_region = None
return best_region
def _is_legal_region(self, cur_chunk_info, chunk_infos):
(chunk_region_start, chunk_region_end) = cur_chunk_info["region"]
if cur_chunk_info in chunk_infos:
return False
if chunk_region_end < chunk_region_start:
return False
for i in chunk_infos:
region = i["region"]
if not ((chunk_region_start > region[1] and chunk_region_end > region[1])
or (chunk_region_start < region[0] and chunk_region_end < region[0])):
return False
return True
def _step_search(self, mem_peak, active_node, chunk_regions):
peak_node = self._find_peak_node(mem_peak)
max_chunk_region = self._search_max_chunk_region(
@ -1393,7 +1420,7 @@ class ChunkRegionSearch(object):
possible_chunk_regions = self._search_possible_chunk_regions(
max_chunk_region, peak_node
)
best_chunk_region = self._search_best_chunk_region(possible_chunk_regions)
best_chunk_region = self._search_best_chunk_region(possible_chunk_regions, chunk_regions)
return best_chunk_region
def _stop_search(self, init_mem_peak, mem_peak):
@ -1919,5 +1946,5 @@ if CODEGEN_AVAILABLE:
{prologue}
{code}"""
print(fn_code)
# print(fn_code)
return PythonCode(fn_code, globals_)