improve reorder efficiency

pull/2364/head
oahzxl 2022-12-31 13:44:46 +08:00
parent 966e4ea0cb
commit 80efd70c72
1 changed file with 17 additions and 16 deletions


@@ -1486,6 +1486,8 @@ class ChunkSelector(object):
                     "chunk_len": self._get_compute_node_num(
                         region["region"][0], region["region"][1]
                     ),
+                    "reorder_chunk_info": cur_region,
+                    "reorder_node_list": cur_node_list
                 }
             )
         # no region found
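
Note: the two added keys cache the output of the reorder pass at region-search time. Previously, the size search below re-ran index_tracer.tmp_reorder for every probed chunk size; with the reordered chunk info and node list stored in regions_dict, each reorder now runs once per candidate region, which is the efficiency gain in the commit title. A minimal, self-contained sketch of the caching pattern, with expensive_reorder and regions as hypothetical stand-ins for index_tracer.tmp_reorder and the region search:

def expensive_reorder(region):
    # Hypothetical stand-in for index_tracer.tmp_reorder; pretend
    # this is the costly call we want to run only once per region.
    return list(range(region[0], region[1] + 1)), {"region": region}

regions = [(0, 3), (2, 6)]  # hypothetical candidate regions
regions_dict = []
for region in regions:
    node_list, info = expensive_reorder(region)  # computed once, here
    regions_dict.append({
        "region": region,
        "chunk_len": region[1] - region[0] + 1,
        "reorder_chunk_info": info,      # reused by the size search
        "reorder_node_list": node_list,  # reused by the size search
    })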
@@ -1495,48 +1497,47 @@ class ChunkSelector(object):
         # select the min chunk len
         chunk_len = [i["chunk_len"] for i in regions_dict]
         best_region_idx = chunk_len.index(min(chunk_len))
-        best_region = regions_dict[best_region_idx]["chunk_info"]
+        best_region = regions_dict[best_region_idx]
         # get max chunk size
         best_region = self._get_fit_chunk_size(best_region, chunk_infos)
         return best_region

-    def _get_fit_chunk_size(self, chunk_info, chunk_infos):
+    def _get_fit_chunk_size(self, chunk_region_dict, chunk_infos):
         chunk_size = 1
-        chunk_info["chunk_size"] = chunk_size
+        reorder_chunk_info = chunk_region_dict['reorder_chunk_info']
+        reorder_chunk_info["chunk_size"] = chunk_size
         cur_chunk_max_mem = 0
         # search a region
         while cur_chunk_max_mem < self.max_memory:
             chunk_size *= 2
-            chunk_info["chunk_size"] = chunk_size
-            cur_chunk_info = chunk_info.copy()
-            cur_node_list, cur_chunk_info = self.index_tracer.tmp_reorder(self.index_tracer.node_list, cur_chunk_info)
-            cur_chunk_infos = chunk_infos + [cur_chunk_info]
+            reorder_chunk_info["chunk_size"] = chunk_size
+            cur_chunk_infos = chunk_infos + [reorder_chunk_info]
             cur_mem_peak = self.memory_estimator.estimate_chunk_inference_mem(
-                cur_node_list, cur_chunk_infos
+                chunk_region_dict['reorder_node_list'], cur_chunk_infos
             )[0]
             cur_chunk_max_mem = max(
-                cur_mem_peak[chunk_info["region"][0] : chunk_info["region"][1] + 1]
+                cur_mem_peak[reorder_chunk_info["region"][0] : reorder_chunk_info["region"][1] + 1]
             )
         # search exact size
+        chunk_info = chunk_region_dict["chunk_info"]
         chunk_info["chunk_size"] = self._chunk_size_binary_search(
-            chunk_size // 2, chunk_size, chunk_info, chunk_infos
+            chunk_size // 2, chunk_size, chunk_region_dict, chunk_infos
         )
         return chunk_info

-    def _chunk_size_binary_search(self, l, r, chunk_info, chunk_infos):
+    def _chunk_size_binary_search(self, l, r, chunk_region_dict, chunk_infos):
         if l >= 16:
             gap = 4
         else:
             gap = 1
+        chunk_info = chunk_region_dict['reorder_chunk_info']
         while r >= l + gap:
             mid = int((l + r) / 2 + 0.5)
             chunk_info["chunk_size"] = mid
-            cur_chunk_info = chunk_info.copy()
-            cur_node_list, cur_chunk_info = self.index_tracer.tmp_reorder(self.index_tracer.node_list, cur_chunk_info)
-            cur_chunk_infos = chunk_infos + [cur_chunk_info]
+            cur_chunk_infos = chunk_infos + [chunk_info]
             cur_mem_peak = self.memory_estimator.estimate_chunk_inference_mem(
-                cur_node_list, cur_chunk_infos
+                chunk_region_dict['reorder_node_list'], cur_chunk_infos
             )[0]
             cur_chunk_max_mem = max(
                 cur_mem_peak[chunk_info["region"][0] : chunk_info["region"][1] + 1]
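
Note: the size search itself is a two-phase scheme — exponential doubling of chunk_size until the estimated peak memory exceeds the budget, then a binary search between chunk_size // 2 and chunk_size, accepting a coarser gap of 4 once l >= 16. A minimal sketch, assuming a hypothetical mem_peak(size) that stands in for the estimate_chunk_inference_mem call and is non-decreasing in size:

def find_max_chunk_size(mem_peak, max_memory):
    # Phase 1: double until the memory budget is exceeded
    # (mirrors the while loop in _get_fit_chunk_size).
    size = 1
    while mem_peak(size) < max_memory:
        size *= 2
    # Phase 2: binary search in (size // 2, size]; a coarser gap is
    # accepted for large sizes, mirroring the gap = 4 shortcut above.
    l, r = size // 2, size
    gap = 4 if l >= 16 else 1
    while r >= l + gap:
        mid = (l + r + 1) // 2  # same rounding as int((l + r) / 2 + 0.5)
        if mem_peak(mid) < max_memory:
            l = mid
        else:
            r = mid - 1
    return l  # 0 means not even chunk_size 1 fits the budget

# Example: with a linear cost model and a budget of 100, the largest
# fitting chunk size is 9 (90 < 100, while 10 * 10 is over budget).
print(find_max_chunk_size(lambda s: 10 * s, 100))  # -> 9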
@@ -1904,7 +1905,7 @@ def _find_idx_by_name(name, nodes_list):

 def _replace_name(context, name_from, name_to):
-    patterns = [(" ", " "), (" ", "."), (" ", ","), ("(", ")"), ("(", ",")]
+    patterns = [(" ", " "), (" ", "."), (" ", ","), ("(", ")"), ("(", ","), (" ", ")")]
     for p in patterns:
         source = p[0] + name_from + p[1]
         target = p[0] + name_to + p[1]
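
Note: _replace_name swaps a variable name in generated source only when it sits between one of these delimiter pairs, which prevents substring collisions (renaming x must not touch x_1). The newly added (" ", ")") pair covers a name preceded by a space and followed directly by a closing parenthesis, e.g. a trailing call argument like add(a, x). A minimal, runnable sketch mirroring the updated patterns list (replace_name here is a stand-alone copy, not the module function):

def replace_name(context, name_from, name_to):
    patterns = [(" ", " "), (" ", "."), (" ", ","), ("(", ")"), ("(", ","), (" ", ")")]
    for prefix, suffix in patterns:
        # Only delimited occurrences are rewritten.
        context = context.replace(prefix + name_from + suffix,
                                  prefix + name_to + suffix)
    return context

print(replace_name("out = add(a, x)", "x", "y"))    # new pair catches " x)"
print(replace_name("out = add(x, x_1)", "x", "y"))  # x_1 is left untouched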