mirror of https://github.com/hpcaitech/ColossalAI
improve reorder efficeincy
parent
966e4ea0cb
commit
80efd70c72
|
@ -1486,6 +1486,8 @@ class ChunkSelector(object):
|
|||
"chunk_len": self._get_compute_node_num(
|
||||
region["region"][0], region["region"][1]
|
||||
),
|
||||
"reorder_chunk_info": cur_region,
|
||||
"reorder_node_list": cur_node_list
|
||||
}
|
||||
)
|
||||
# no region found
|
||||
|
@ -1495,48 +1497,47 @@ class ChunkSelector(object):
|
|||
# select the min chunk len
|
||||
chunk_len = [i["chunk_len"] for i in regions_dict]
|
||||
best_region_idx = chunk_len.index(min(chunk_len))
|
||||
best_region = regions_dict[best_region_idx]["chunk_info"]
|
||||
best_region = regions_dict[best_region_idx]
|
||||
|
||||
# get max chunk size
|
||||
best_region = self._get_fit_chunk_size(best_region, chunk_infos)
|
||||
return best_region
|
||||
|
||||
def _get_fit_chunk_size(self, chunk_info, chunk_infos):
|
||||
def _get_fit_chunk_size(self, chunk_region_dict, chunk_infos):
|
||||
chunk_size = 1
|
||||
chunk_info["chunk_size"] = chunk_size
|
||||
reorder_chunk_info = chunk_region_dict['reorder_chunk_info']
|
||||
reorder_chunk_info["chunk_size"] = chunk_size
|
||||
cur_chunk_max_mem = 0
|
||||
# search a region
|
||||
while cur_chunk_max_mem < self.max_memory:
|
||||
chunk_size *= 2
|
||||
chunk_info["chunk_size"] = chunk_size
|
||||
cur_chunk_info = chunk_info.copy()
|
||||
cur_node_list, cur_chunk_info = self.index_tracer.tmp_reorder(self.index_tracer.node_list, cur_chunk_info)
|
||||
cur_chunk_infos = chunk_infos + [cur_chunk_info]
|
||||
reorder_chunk_info["chunk_size"] = chunk_size
|
||||
cur_chunk_infos = chunk_infos + [reorder_chunk_info]
|
||||
cur_mem_peak = self.memory_estimator.estimate_chunk_inference_mem(
|
||||
cur_node_list, cur_chunk_infos
|
||||
chunk_region_dict['reorder_node_list'], cur_chunk_infos
|
||||
)[0]
|
||||
cur_chunk_max_mem = max(
|
||||
cur_mem_peak[chunk_info["region"][0] : chunk_info["region"][1] + 1]
|
||||
cur_mem_peak[reorder_chunk_info["region"][0] : reorder_chunk_info["region"][1] + 1]
|
||||
)
|
||||
# search exact size
|
||||
chunk_info = chunk_region_dict["chunk_info"]
|
||||
chunk_info["chunk_size"] = self._chunk_size_binary_search(
|
||||
chunk_size // 2, chunk_size, chunk_info, chunk_infos
|
||||
chunk_size // 2, chunk_size, chunk_region_dict, chunk_infos
|
||||
)
|
||||
return chunk_info
|
||||
|
||||
def _chunk_size_binary_search(self, l, r, chunk_info, chunk_infos):
|
||||
def _chunk_size_binary_search(self, l, r, chunk_region_dict, chunk_infos):
|
||||
if l >= 16:
|
||||
gap = 4
|
||||
else:
|
||||
gap = 1
|
||||
chunk_info = chunk_region_dict['reorder_chunk_info']
|
||||
while r >= l + gap:
|
||||
mid = int((l + r) / 2 + 0.5)
|
||||
chunk_info["chunk_size"] = mid
|
||||
cur_chunk_info = chunk_info.copy()
|
||||
cur_node_list, cur_chunk_info = self.index_tracer.tmp_reorder(self.index_tracer.node_list, cur_chunk_info)
|
||||
cur_chunk_infos = chunk_infos + [cur_chunk_info]
|
||||
cur_chunk_infos = chunk_infos + [chunk_info]
|
||||
cur_mem_peak = self.memory_estimator.estimate_chunk_inference_mem(
|
||||
cur_node_list, cur_chunk_infos
|
||||
chunk_region_dict['reorder_node_list'], cur_chunk_infos
|
||||
)[0]
|
||||
cur_chunk_max_mem = max(
|
||||
cur_mem_peak[chunk_info["region"][0] : chunk_info["region"][1] + 1]
|
||||
|
@ -1904,7 +1905,7 @@ def _find_idx_by_name(name, nodes_list):
|
|||
|
||||
|
||||
def _replace_name(context, name_from, name_to):
|
||||
patterns = [(" ", " "), (" ", "."), (" ", ","), ("(", ")"), ("(", ",")]
|
||||
patterns = [(" ", " "), (" ", "."), (" ", ","), ("(", ")"), ("(", ","), (" ", ")")]
|
||||
for p in patterns:
|
||||
source = p[0] + name_from + p[1]
|
||||
target = p[0] + name_to + p[1]
|
||||
|
|
Loading…
Reference in New Issue