improve reorder efficeincy

pull/2364/head
oahzxl 2022-12-31 13:44:46 +08:00
parent 966e4ea0cb
commit 80efd70c72
1 changed files with 17 additions and 16 deletions

View File

@ -1486,6 +1486,8 @@ class ChunkSelector(object):
"chunk_len": self._get_compute_node_num( "chunk_len": self._get_compute_node_num(
region["region"][0], region["region"][1] region["region"][0], region["region"][1]
), ),
"reorder_chunk_info": cur_region,
"reorder_node_list": cur_node_list
} }
) )
# no region found # no region found
@ -1495,48 +1497,47 @@ class ChunkSelector(object):
# select the min chunk len # select the min chunk len
chunk_len = [i["chunk_len"] for i in regions_dict] chunk_len = [i["chunk_len"] for i in regions_dict]
best_region_idx = chunk_len.index(min(chunk_len)) best_region_idx = chunk_len.index(min(chunk_len))
best_region = regions_dict[best_region_idx]["chunk_info"] best_region = regions_dict[best_region_idx]
# get max chunk size # get max chunk size
best_region = self._get_fit_chunk_size(best_region, chunk_infos) best_region = self._get_fit_chunk_size(best_region, chunk_infos)
return best_region return best_region
def _get_fit_chunk_size(self, chunk_info, chunk_infos): def _get_fit_chunk_size(self, chunk_region_dict, chunk_infos):
chunk_size = 1 chunk_size = 1
chunk_info["chunk_size"] = chunk_size reorder_chunk_info = chunk_region_dict['reorder_chunk_info']
reorder_chunk_info["chunk_size"] = chunk_size
cur_chunk_max_mem = 0 cur_chunk_max_mem = 0
# search a region # search a region
while cur_chunk_max_mem < self.max_memory: while cur_chunk_max_mem < self.max_memory:
chunk_size *= 2 chunk_size *= 2
chunk_info["chunk_size"] = chunk_size reorder_chunk_info["chunk_size"] = chunk_size
cur_chunk_info = chunk_info.copy() cur_chunk_infos = chunk_infos + [reorder_chunk_info]
cur_node_list, cur_chunk_info = self.index_tracer.tmp_reorder(self.index_tracer.node_list, cur_chunk_info)
cur_chunk_infos = chunk_infos + [cur_chunk_info]
cur_mem_peak = self.memory_estimator.estimate_chunk_inference_mem( cur_mem_peak = self.memory_estimator.estimate_chunk_inference_mem(
cur_node_list, cur_chunk_infos chunk_region_dict['reorder_node_list'], cur_chunk_infos
)[0] )[0]
cur_chunk_max_mem = max( cur_chunk_max_mem = max(
cur_mem_peak[chunk_info["region"][0] : chunk_info["region"][1] + 1] cur_mem_peak[reorder_chunk_info["region"][0] : reorder_chunk_info["region"][1] + 1]
) )
# search exact size # search exact size
chunk_info = chunk_region_dict["chunk_info"]
chunk_info["chunk_size"] = self._chunk_size_binary_search( chunk_info["chunk_size"] = self._chunk_size_binary_search(
chunk_size // 2, chunk_size, chunk_info, chunk_infos chunk_size // 2, chunk_size, chunk_region_dict, chunk_infos
) )
return chunk_info return chunk_info
def _chunk_size_binary_search(self, l, r, chunk_info, chunk_infos): def _chunk_size_binary_search(self, l, r, chunk_region_dict, chunk_infos):
if l >= 16: if l >= 16:
gap = 4 gap = 4
else: else:
gap = 1 gap = 1
chunk_info = chunk_region_dict['reorder_chunk_info']
while r >= l + gap: while r >= l + gap:
mid = int((l + r) / 2 + 0.5) mid = int((l + r) / 2 + 0.5)
chunk_info["chunk_size"] = mid chunk_info["chunk_size"] = mid
cur_chunk_info = chunk_info.copy() cur_chunk_infos = chunk_infos + [chunk_info]
cur_node_list, cur_chunk_info = self.index_tracer.tmp_reorder(self.index_tracer.node_list, cur_chunk_info)
cur_chunk_infos = chunk_infos + [cur_chunk_info]
cur_mem_peak = self.memory_estimator.estimate_chunk_inference_mem( cur_mem_peak = self.memory_estimator.estimate_chunk_inference_mem(
cur_node_list, cur_chunk_infos chunk_region_dict['reorder_node_list'], cur_chunk_infos
)[0] )[0]
cur_chunk_max_mem = max( cur_chunk_max_mem = max(
cur_mem_peak[chunk_info["region"][0] : chunk_info["region"][1] + 1] cur_mem_peak[chunk_info["region"][0] : chunk_info["region"][1] + 1]
@ -1904,7 +1905,7 @@ def _find_idx_by_name(name, nodes_list):
def _replace_name(context, name_from, name_to): def _replace_name(context, name_from, name_to):
patterns = [(" ", " "), (" ", "."), (" ", ","), ("(", ")"), ("(", ",")] patterns = [(" ", " "), (" ", "."), (" ", ","), ("(", ")"), ("(", ","), (" ", ")")]
for p in patterns: for p in patterns:
source = p[0] + name_from + p[1] source = p[0] + name_from + p[1]
target = p[0] + name_to + p[1] target = p[0] + name_to + p[1]