mirror of https://github.com/hpcaitech/ColossalAI
optimise search
parent
e83e3c6154
commit
e66a18a0bf
|
@ -958,6 +958,8 @@ class MemoryEstimator(object):
|
||||||
|
|
||||||
def _add_active_node(self, n, active_list):
|
def _add_active_node(self, n, active_list):
|
||||||
new_active = self._get_output_node(n)[1]
|
new_active = self._get_output_node(n)[1]
|
||||||
|
if n.op == 'placeholder':
|
||||||
|
new_active.append(n.name)
|
||||||
for i in new_active:
|
for i in new_active:
|
||||||
if i not in active_list:
|
if i not in active_list:
|
||||||
active_list.append(i)
|
active_list.append(i)
|
||||||
|
@ -965,7 +967,7 @@ class MemoryEstimator(object):
|
||||||
def _get_delete_node(self, user, user_to_last_uses, to_keep=None):
|
def _get_delete_node(self, user, user_to_last_uses, to_keep=None):
|
||||||
delete_size = 0
|
delete_size = 0
|
||||||
delete_node = []
|
delete_node = []
|
||||||
if user.op not in ("placeholder", "output"):
|
if user.op not in ("output",):
|
||||||
nodes_to_delete = user_to_last_uses.get(user, [])
|
nodes_to_delete = user_to_last_uses.get(user, [])
|
||||||
if to_keep is not None:
|
if to_keep is not None:
|
||||||
keep_list = []
|
keep_list = []
|
||||||
|
@ -1258,24 +1260,30 @@ class ChunkRegionSearch(object):
|
||||||
|
|
||||||
def _search_max_chunk_region(self, active_node, peak_node, chunk_regions):
|
def _search_max_chunk_region(self, active_node, peak_node, chunk_regions):
|
||||||
free_vars = self._get_free_var()
|
free_vars = self._get_free_var()
|
||||||
min_var = self._get_min_free_var(active_node, free_vars)
|
free_var_num = len(free_vars)
|
||||||
|
active_node_num = [len(i) for i in active_node]
|
||||||
|
min_active_node_num = min(active_node_num[free_var_num:])
|
||||||
|
threshold = max(free_var_num, min_active_node_num)
|
||||||
|
|
||||||
# from peak_node to free_var
|
# from peak_node to free_var
|
||||||
chunk_region_start = len(free_vars)
|
inside_flag = False
|
||||||
|
chunk_region_start = free_var_num
|
||||||
for i in range(peak_node, -1, -1):
|
for i in range(peak_node, -1, -1):
|
||||||
if len(active_node[i]) == min_var:
|
if active_node_num[i] <= threshold:
|
||||||
|
inside_flag = True
|
||||||
|
if inside_flag and active_node_num[i] > threshold:
|
||||||
chunk_region_start = i + 1
|
chunk_region_start = i + 1
|
||||||
break
|
break
|
||||||
if i in free_vars or i == 0:
|
|
||||||
raise RuntimeError()
|
|
||||||
# from peak_node to len-2
|
# from peak_node to len-2
|
||||||
|
inside_flag = False
|
||||||
chunk_region_end = len(active_node) - 1
|
chunk_region_end = len(active_node) - 1
|
||||||
for i in range(peak_node, len(active_node)):
|
for i in range(peak_node, len(active_node)):
|
||||||
if len(active_node[i]) == min_var:
|
if active_node_num[i] <= threshold:
|
||||||
|
inside_flag = True
|
||||||
|
if inside_flag and active_node_num[i] > threshold:
|
||||||
chunk_region_end = i
|
chunk_region_end = i
|
||||||
break
|
break
|
||||||
if i in free_vars or i == 0:
|
|
||||||
raise RuntimeError()
|
|
||||||
|
|
||||||
for i in chunk_regions:
|
for i in chunk_regions:
|
||||||
region = i["region"]
|
region = i["region"]
|
||||||
|
@ -1374,14 +1382,33 @@ class ChunkRegionSearch(object):
|
||||||
possible_chunk_region.extend(chunk_info)
|
possible_chunk_region.extend(chunk_info)
|
||||||
return possible_chunk_region
|
return possible_chunk_region
|
||||||
|
|
||||||
def _search_best_chunk_region(self, possible_chunk_regions):
|
def _search_best_chunk_region(self, possible_chunk_regions, chunk_infos):
|
||||||
max_region_range = 0
|
max_region_range = 0
|
||||||
best_regions = None
|
best_region = None
|
||||||
|
while len(possible_chunk_regions) > 0:
|
||||||
for i in possible_chunk_regions:
|
for i in possible_chunk_regions:
|
||||||
if i["region"][1] - i["region"][0] > max_region_range:
|
if i["region"][1] - i["region"][0] > max_region_range:
|
||||||
best_regions = i
|
best_region = i
|
||||||
max_region_range = i["region"][1] - i["region"][0]
|
max_region_range = i["region"][1] - i["region"][0]
|
||||||
return best_regions
|
if self._is_legal_region(best_region, chunk_infos):
|
||||||
|
break
|
||||||
|
possible_chunk_regions.remove(i)
|
||||||
|
max_region_range = 0
|
||||||
|
best_region = None
|
||||||
|
return best_region
|
||||||
|
|
||||||
|
def _is_legal_region(self, cur_chunk_info, chunk_infos):
|
||||||
|
(chunk_region_start, chunk_region_end) = cur_chunk_info["region"]
|
||||||
|
if cur_chunk_info in chunk_infos:
|
||||||
|
return False
|
||||||
|
if chunk_region_end < chunk_region_start:
|
||||||
|
return False
|
||||||
|
for i in chunk_infos:
|
||||||
|
region = i["region"]
|
||||||
|
if not ((chunk_region_start > region[1] and chunk_region_end > region[1])
|
||||||
|
or (chunk_region_start < region[0] and chunk_region_end < region[0])):
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
def _step_search(self, mem_peak, active_node, chunk_regions):
|
def _step_search(self, mem_peak, active_node, chunk_regions):
|
||||||
peak_node = self._find_peak_node(mem_peak)
|
peak_node = self._find_peak_node(mem_peak)
|
||||||
|
@ -1393,7 +1420,7 @@ class ChunkRegionSearch(object):
|
||||||
possible_chunk_regions = self._search_possible_chunk_regions(
|
possible_chunk_regions = self._search_possible_chunk_regions(
|
||||||
max_chunk_region, peak_node
|
max_chunk_region, peak_node
|
||||||
)
|
)
|
||||||
best_chunk_region = self._search_best_chunk_region(possible_chunk_regions)
|
best_chunk_region = self._search_best_chunk_region(possible_chunk_regions, chunk_regions)
|
||||||
return best_chunk_region
|
return best_chunk_region
|
||||||
|
|
||||||
def _stop_search(self, init_mem_peak, mem_peak):
|
def _stop_search(self, init_mem_peak, mem_peak):
|
||||||
|
@ -1919,5 +1946,5 @@ if CODEGEN_AVAILABLE:
|
||||||
|
|
||||||
{prologue}
|
{prologue}
|
||||||
{code}"""
|
{code}"""
|
||||||
print(fn_code)
|
# print(fn_code)
|
||||||
return PythonCode(fn_code, globals_)
|
return PythonCode(fn_code, globals_)
|
||||||
|
|
Loading…
Reference in New Issue