mirror of https://github.com/hpcaitech/ColossalAI
improve reorder efficeincy
parent
966e4ea0cb
commit
80efd70c72
|
@ -1486,6 +1486,8 @@ class ChunkSelector(object):
|
||||||
"chunk_len": self._get_compute_node_num(
|
"chunk_len": self._get_compute_node_num(
|
||||||
region["region"][0], region["region"][1]
|
region["region"][0], region["region"][1]
|
||||||
),
|
),
|
||||||
|
"reorder_chunk_info": cur_region,
|
||||||
|
"reorder_node_list": cur_node_list
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
# no region found
|
# no region found
|
||||||
|
@ -1495,48 +1497,47 @@ class ChunkSelector(object):
|
||||||
# select the min chunk len
|
# select the min chunk len
|
||||||
chunk_len = [i["chunk_len"] for i in regions_dict]
|
chunk_len = [i["chunk_len"] for i in regions_dict]
|
||||||
best_region_idx = chunk_len.index(min(chunk_len))
|
best_region_idx = chunk_len.index(min(chunk_len))
|
||||||
best_region = regions_dict[best_region_idx]["chunk_info"]
|
best_region = regions_dict[best_region_idx]
|
||||||
|
|
||||||
# get max chunk size
|
# get max chunk size
|
||||||
best_region = self._get_fit_chunk_size(best_region, chunk_infos)
|
best_region = self._get_fit_chunk_size(best_region, chunk_infos)
|
||||||
return best_region
|
return best_region
|
||||||
|
|
||||||
def _get_fit_chunk_size(self, chunk_info, chunk_infos):
|
def _get_fit_chunk_size(self, chunk_region_dict, chunk_infos):
|
||||||
chunk_size = 1
|
chunk_size = 1
|
||||||
chunk_info["chunk_size"] = chunk_size
|
reorder_chunk_info = chunk_region_dict['reorder_chunk_info']
|
||||||
|
reorder_chunk_info["chunk_size"] = chunk_size
|
||||||
cur_chunk_max_mem = 0
|
cur_chunk_max_mem = 0
|
||||||
# search a region
|
# search a region
|
||||||
while cur_chunk_max_mem < self.max_memory:
|
while cur_chunk_max_mem < self.max_memory:
|
||||||
chunk_size *= 2
|
chunk_size *= 2
|
||||||
chunk_info["chunk_size"] = chunk_size
|
reorder_chunk_info["chunk_size"] = chunk_size
|
||||||
cur_chunk_info = chunk_info.copy()
|
cur_chunk_infos = chunk_infos + [reorder_chunk_info]
|
||||||
cur_node_list, cur_chunk_info = self.index_tracer.tmp_reorder(self.index_tracer.node_list, cur_chunk_info)
|
|
||||||
cur_chunk_infos = chunk_infos + [cur_chunk_info]
|
|
||||||
cur_mem_peak = self.memory_estimator.estimate_chunk_inference_mem(
|
cur_mem_peak = self.memory_estimator.estimate_chunk_inference_mem(
|
||||||
cur_node_list, cur_chunk_infos
|
chunk_region_dict['reorder_node_list'], cur_chunk_infos
|
||||||
)[0]
|
)[0]
|
||||||
cur_chunk_max_mem = max(
|
cur_chunk_max_mem = max(
|
||||||
cur_mem_peak[chunk_info["region"][0] : chunk_info["region"][1] + 1]
|
cur_mem_peak[reorder_chunk_info["region"][0] : reorder_chunk_info["region"][1] + 1]
|
||||||
)
|
)
|
||||||
# search exact size
|
# search exact size
|
||||||
|
chunk_info = chunk_region_dict["chunk_info"]
|
||||||
chunk_info["chunk_size"] = self._chunk_size_binary_search(
|
chunk_info["chunk_size"] = self._chunk_size_binary_search(
|
||||||
chunk_size // 2, chunk_size, chunk_info, chunk_infos
|
chunk_size // 2, chunk_size, chunk_region_dict, chunk_infos
|
||||||
)
|
)
|
||||||
return chunk_info
|
return chunk_info
|
||||||
|
|
||||||
def _chunk_size_binary_search(self, l, r, chunk_info, chunk_infos):
|
def _chunk_size_binary_search(self, l, r, chunk_region_dict, chunk_infos):
|
||||||
if l >= 16:
|
if l >= 16:
|
||||||
gap = 4
|
gap = 4
|
||||||
else:
|
else:
|
||||||
gap = 1
|
gap = 1
|
||||||
|
chunk_info = chunk_region_dict['reorder_chunk_info']
|
||||||
while r >= l + gap:
|
while r >= l + gap:
|
||||||
mid = int((l + r) / 2 + 0.5)
|
mid = int((l + r) / 2 + 0.5)
|
||||||
chunk_info["chunk_size"] = mid
|
chunk_info["chunk_size"] = mid
|
||||||
cur_chunk_info = chunk_info.copy()
|
cur_chunk_infos = chunk_infos + [chunk_info]
|
||||||
cur_node_list, cur_chunk_info = self.index_tracer.tmp_reorder(self.index_tracer.node_list, cur_chunk_info)
|
|
||||||
cur_chunk_infos = chunk_infos + [cur_chunk_info]
|
|
||||||
cur_mem_peak = self.memory_estimator.estimate_chunk_inference_mem(
|
cur_mem_peak = self.memory_estimator.estimate_chunk_inference_mem(
|
||||||
cur_node_list, cur_chunk_infos
|
chunk_region_dict['reorder_node_list'], cur_chunk_infos
|
||||||
)[0]
|
)[0]
|
||||||
cur_chunk_max_mem = max(
|
cur_chunk_max_mem = max(
|
||||||
cur_mem_peak[chunk_info["region"][0] : chunk_info["region"][1] + 1]
|
cur_mem_peak[chunk_info["region"][0] : chunk_info["region"][1] + 1]
|
||||||
|
@ -1904,7 +1905,7 @@ def _find_idx_by_name(name, nodes_list):
|
||||||
|
|
||||||
|
|
||||||
def _replace_name(context, name_from, name_to):
|
def _replace_name(context, name_from, name_to):
|
||||||
patterns = [(" ", " "), (" ", "."), (" ", ","), ("(", ")"), ("(", ",")]
|
patterns = [(" ", " "), (" ", "."), (" ", ","), ("(", ")"), ("(", ","), (" ", ")")]
|
||||||
for p in patterns:
|
for p in patterns:
|
||||||
source = p[0] + name_from + p[1]
|
source = p[0] + name_from + p[1]
|
||||||
target = p[0] + name_to + p[1]
|
target = p[0] + name_to + p[1]
|
||||||
|
|
Loading…
Reference in New Issue