mirror of https://github.com/hpcaitech/ColossalAI
optimise search
parent
e83e3c6154
commit
e66a18a0bf
|
@ -958,6 +958,8 @@ class MemoryEstimator(object):
|
|||
|
||||
def _add_active_node(self, n, active_list):
|
||||
new_active = self._get_output_node(n)[1]
|
||||
if n.op == 'placeholder':
|
||||
new_active.append(n.name)
|
||||
for i in new_active:
|
||||
if i not in active_list:
|
||||
active_list.append(i)
|
||||
|
@ -965,7 +967,7 @@ class MemoryEstimator(object):
|
|||
def _get_delete_node(self, user, user_to_last_uses, to_keep=None):
|
||||
delete_size = 0
|
||||
delete_node = []
|
||||
if user.op not in ("placeholder", "output"):
|
||||
if user.op not in ("output",):
|
||||
nodes_to_delete = user_to_last_uses.get(user, [])
|
||||
if to_keep is not None:
|
||||
keep_list = []
|
||||
|
@ -1258,24 +1260,30 @@ class ChunkRegionSearch(object):
|
|||
|
||||
def _search_max_chunk_region(self, active_node, peak_node, chunk_regions):
|
||||
free_vars = self._get_free_var()
|
||||
min_var = self._get_min_free_var(active_node, free_vars)
|
||||
|
||||
free_var_num = len(free_vars)
|
||||
active_node_num = [len(i) for i in active_node]
|
||||
min_active_node_num = min(active_node_num[free_var_num:])
|
||||
threshold = max(free_var_num, min_active_node_num)
|
||||
|
||||
# from peak_node to free_var
|
||||
chunk_region_start = len(free_vars)
|
||||
inside_flag = False
|
||||
chunk_region_start = free_var_num
|
||||
for i in range(peak_node, -1, -1):
|
||||
if len(active_node[i]) == min_var:
|
||||
if active_node_num[i] <= threshold:
|
||||
inside_flag = True
|
||||
if inside_flag and active_node_num[i] > threshold:
|
||||
chunk_region_start = i + 1
|
||||
break
|
||||
if i in free_vars or i == 0:
|
||||
raise RuntimeError()
|
||||
|
||||
# from peak_node to len-2
|
||||
inside_flag = False
|
||||
chunk_region_end = len(active_node) - 1
|
||||
for i in range(peak_node, len(active_node)):
|
||||
if len(active_node[i]) == min_var:
|
||||
if active_node_num[i] <= threshold:
|
||||
inside_flag = True
|
||||
if inside_flag and active_node_num[i] > threshold:
|
||||
chunk_region_end = i
|
||||
break
|
||||
if i in free_vars or i == 0:
|
||||
raise RuntimeError()
|
||||
|
||||
for i in chunk_regions:
|
||||
region = i["region"]
|
||||
|
@ -1374,15 +1382,34 @@ class ChunkRegionSearch(object):
|
|||
possible_chunk_region.extend(chunk_info)
|
||||
return possible_chunk_region
|
||||
|
||||
def _search_best_chunk_region(self, possible_chunk_regions):
|
||||
def _search_best_chunk_region(self, possible_chunk_regions, chunk_infos):
|
||||
max_region_range = 0
|
||||
best_regions = None
|
||||
for i in possible_chunk_regions:
|
||||
if i["region"][1] - i["region"][0] > max_region_range:
|
||||
best_regions = i
|
||||
max_region_range = i["region"][1] - i["region"][0]
|
||||
return best_regions
|
||||
|
||||
best_region = None
|
||||
while len(possible_chunk_regions) > 0:
|
||||
for i in possible_chunk_regions:
|
||||
if i["region"][1] - i["region"][0] > max_region_range:
|
||||
best_region = i
|
||||
max_region_range = i["region"][1] - i["region"][0]
|
||||
if self._is_legal_region(best_region, chunk_infos):
|
||||
break
|
||||
possible_chunk_regions.remove(i)
|
||||
max_region_range = 0
|
||||
best_region = None
|
||||
return best_region
|
||||
|
||||
def _is_legal_region(self, cur_chunk_info, chunk_infos):
|
||||
(chunk_region_start, chunk_region_end) = cur_chunk_info["region"]
|
||||
if cur_chunk_info in chunk_infos:
|
||||
return False
|
||||
if chunk_region_end < chunk_region_start:
|
||||
return False
|
||||
for i in chunk_infos:
|
||||
region = i["region"]
|
||||
if not ((chunk_region_start > region[1] and chunk_region_end > region[1])
|
||||
or (chunk_region_start < region[0] and chunk_region_end < region[0])):
|
||||
return False
|
||||
return True
|
||||
|
||||
def _step_search(self, mem_peak, active_node, chunk_regions):
|
||||
peak_node = self._find_peak_node(mem_peak)
|
||||
max_chunk_region = self._search_max_chunk_region(
|
||||
|
@ -1393,7 +1420,7 @@ class ChunkRegionSearch(object):
|
|||
possible_chunk_regions = self._search_possible_chunk_regions(
|
||||
max_chunk_region, peak_node
|
||||
)
|
||||
best_chunk_region = self._search_best_chunk_region(possible_chunk_regions)
|
||||
best_chunk_region = self._search_best_chunk_region(possible_chunk_regions, chunk_regions)
|
||||
return best_chunk_region
|
||||
|
||||
def _stop_search(self, init_mem_peak, mem_peak):
|
||||
|
@ -1919,5 +1946,5 @@ if CODEGEN_AVAILABLE:
|
|||
|
||||
{prologue}
|
||||
{code}"""
|
||||
print(fn_code)
|
||||
# print(fn_code)
|
||||
return PythonCode(fn_code, globals_)
|
||||
|
|
Loading…
Reference in New Issue