# Mirror of https://github.com/hpcaitech/ColossalAI (222 lines, 8.3 KiB, Python)
from .index_tracer import IndexTracer
|
|
from .memory_estiamtor import MemoryEstimator
|
|
from .utils import is_non_compute_node
|
|
|
|
|
|
class ChunkSelector(object):
    """Choose which candidate region to chunk next, and with what chunk size.

    The strategy is fixed at construction from ``max_memory``:

    * ``"min_memory"`` (``max_memory is None``): pick the legal candidate whose
      estimated peak memory over ``max_chunk_region`` is lowest, with
      ``chunk_size = 1``.
    * ``"fit_memory"`` (``max_memory`` given): return ``None`` once the current
      peak is already below ``max_memory``; otherwise pick the legal candidate
      with the fewest compute nodes whose estimated peak is below
      ``max_memory``, then search for a chunk size that keeps it there.
    """

    def __init__(
        self,
        index_tracer: "IndexTracer",
        memory_estimator: "MemoryEstimator",
        max_memory=None,
    ):
        self.index_tracer = index_tracer
        self.memory_estimator = memory_estimator
        # NOTE: the attribute keeps its original (misspelled) name ``stratge``
        # because code outside this class may read it.
        if max_memory is not None:
            self.stratge = "fit_memory"
            self.max_memory = max_memory  # MB
        else:
            self.stratge = "min_memory"

    def _select_best_chunk_region(
        self, possible_chunk_regions, chunk_infos, peak_node, max_chunk_region, mem_peak
    ):
        """Dispatch to the region selector for the configured strategy.

        Returns:
            The chosen chunk-info dict (with ``"chunk_size"`` set), or ``None``
            when nothing should be chunked.

        Raises:
            RuntimeError: if ``self.stratge`` is not a known strategy, or (in
                ``"fit_memory"`` mode) if no legal candidate fits the limit.
        """
        if self.stratge == "min_memory":
            selector = self._select_min_memory_chunk_region
        elif self.stratge == "fit_memory":
            selector = self._select_fit_memory_chunk_region
        else:
            raise RuntimeError(f"Unknown chunk selection strategy: {self.stratge!r}")
        return selector(
            possible_chunk_regions, chunk_infos, peak_node, max_chunk_region, mem_peak
        )

    def _select_fit_memory_chunk_region(
        self, possible_chunk_regions, chunk_infos, peak_node, max_chunk_region, mem_peak
    ):
        """Pick the shortest legal region whose chunked peak fits ``max_memory``.

        Args:
            possible_chunk_regions: candidate chunk-info dicts, each with a
                ``"region"`` entry of inclusive ``(start, end)`` node indices.
                Illegal candidates are removed from this list in place.
            chunk_infos: chunk-info dicts that have already been selected.
            peak_node: unused here; kept for a uniform selector signature.
            max_chunk_region: inclusive ``(start, end)`` node range whose
                estimated peak is compared against ``max_memory``.
            mem_peak: memory estimates for the current graph (presumably one
                per node); only its maximum is used here.

        Returns:
            The chosen candidate dict with ``"chunk_size"`` set, or ``None`` if
            the current peak already fits or no legal candidate remains.

        Raises:
            RuntimeError: if no legal candidate fits under ``max_memory``.
        """
        # Stop chunking once the current peak satisfies the memory limit.
        if max(mem_peak) < self.max_memory:
            return None

        self._remove_illegal_regions(possible_chunk_regions, chunk_infos)
        if not possible_chunk_regions:
            return None

        candidates = self._get_region_candidates(
            possible_chunk_regions,
            chunk_infos,
            max_chunk_region,
            max_mem=self.max_memory,
        )
        if not candidates:
            raise RuntimeError("Search failed. Try a larger memory threshold.")

        # Prefer the region with the fewest compute nodes; min() keeps the
        # earliest candidate on ties.
        best = min(candidates, key=lambda cand: cand["chunk_len"])
        return self._get_fit_chunk_size(best, chunk_infos)

    def _get_fit_chunk_size(self, chunk_region_dict, chunk_infos):
        """Find a chunk size for ``chunk_region_dict`` that fits ``max_memory``.

        Doubles the chunk size (starting at 2) until the estimated peak over
        the reordered region reaches ``max_memory``, then searches between the
        last two sizes. The result is stored in ``chunk_region_dict["chunk_info"]``,
        which is also returned; ``reorder_chunk_info["chunk_size"]`` is used as
        a scratch value during the search.

        NOTE(review): chunk size 1 is assumed to fit and is not re-estimated
        here. The doubling loop also assumes the estimated peak eventually
        reaches ``max_memory``; if it plateaus below the limit, the loop does
        not terminate.
        """
        reorder_chunk_info = chunk_region_dict["reorder_chunk_info"]
        node_list = chunk_region_dict["reorder_node_list"]
        chunk_size = 1
        reorder_chunk_info["chunk_size"] = chunk_size
        cur_chunk_max_mem = 0
        # Exponential search: stop at the first power of two that does not fit.
        while cur_chunk_max_mem < self.max_memory:
            chunk_size *= 2
            reorder_chunk_info["chunk_size"] = chunk_size
            cur_chunk_max_mem = self._get_region_max_mem(
                node_list,
                chunk_infos + [reorder_chunk_info],
                reorder_chunk_info["region"],
            )
        # Refine between the last fitting size and the first failing one.
        chunk_info = chunk_region_dict["chunk_info"]
        chunk_info["chunk_size"] = self._chunk_size_binary_search(
            chunk_size // 2, chunk_size, chunk_region_dict, chunk_infos
        )
        return chunk_info

    def _chunk_size_binary_search(self, l, r, chunk_region_dict, chunk_infos):
        """Return a chunk size in ``[l, r)`` whose region peak is below ``max_memory``.

        ``l`` is assumed to fit and ``r`` to exceed the limit. The search keeps
        ``lo`` always fitting and ``hi`` always failing, and stops once they are
        at most ``gap`` apart (``gap`` is 1 below 16 and 4 otherwise, trading
        precision for fewer memory estimates). Assuming the peak grows with the
        chunk size, the result is the largest fitting size when ``gap == 1``
        and at most ``gap - 1`` below it otherwise. A size that was observed to
        exceed the limit is never returned.
        """
        gap = 4 if l >= 16 else 1
        chunk_info = chunk_region_dict["reorder_chunk_info"]
        node_list = chunk_region_dict["reorder_node_list"]
        lo, hi = l, r
        while hi - lo > gap:
            mid = (lo + hi) // 2
            chunk_info["chunk_size"] = mid
            cur_chunk_max_mem = self._get_region_max_mem(
                node_list, chunk_infos + [chunk_info], chunk_info["region"]
            )
            if cur_chunk_max_mem >= self.max_memory:
                hi = mid
            else:
                lo = mid
        return lo

    def _get_region_max_mem(self, node_list, chunk_infos, region):
        """Return the estimated peak memory over the inclusive node range ``region``."""
        mem_peak = self.memory_estimator.estimate_chunk_inference_mem(
            node_list, chunk_infos
        )[0]
        return max(mem_peak[region[0] : region[1] + 1])

    def _get_compute_node_num(self, start, end):
        """Count nodes in ``index_tracer.node_list[start:end + 1]`` that are compute nodes."""
        count = 0
        for node in self.index_tracer.node_list[start : end + 1]:
            if not is_non_compute_node(node):
                count += 1
        return count

    def _remove_illegal_regions(self, possible_chunk_regions, chunk_infos):
        """Drop, in place, every candidate rejected by ``_is_legal_region``."""
        # Slice assignment keeps the caller's list object (the original code
        # mutated it with ``remove``) while filtering in a single pass.
        possible_chunk_regions[:] = [
            region
            for region in possible_chunk_regions
            if self._is_legal_region(region, chunk_infos)
        ]

    def _get_region_candidates(
        self, possible_chunk_regions, chunk_infos, max_chunk_region, max_mem=None
    ):
        """Estimate the memory of chunking each candidate region.

        Each candidate is copied, tentatively reordered with
        ``index_tracer.tmp_reorder``, and appended to ``chunk_infos`` for the
        estimate. When ``max_mem`` is given, candidates whose peak over
        ``max_chunk_region`` is not below it are skipped.

        Returns:
            A list of dicts with keys ``"chunk_info"`` (the original candidate),
            ``"chunk_max_mem"``, ``"chunk_len"`` (compute-node count),
            ``"reorder_chunk_info"`` and ``"reorder_node_list"``.
        """
        candidates = []
        for region in possible_chunk_regions:
            reorder_node_list, reorder_region = self.index_tracer.tmp_reorder(
                self.index_tracer.node_list, region.copy()
            )
            region_max_mem = self._get_region_max_mem(
                reorder_node_list, chunk_infos + [reorder_region], max_chunk_region
            )
            if max_mem is not None and region_max_mem >= max_mem:
                continue
            candidates.append(
                {
                    "chunk_info": region,
                    "chunk_max_mem": region_max_mem,
                    "chunk_len": self._get_compute_node_num(
                        region["region"][0], region["region"][1]
                    ),
                    "reorder_chunk_info": reorder_region,
                    "reorder_node_list": reorder_node_list,
                }
            )
        return candidates

    def _select_min_memory_chunk_region(
        self, possible_chunk_regions, chunk_infos, peak_node, max_chunk_region, mem_peak
    ):
        """Pick the legal region with the lowest estimated peak memory.

        Illegal candidates are removed from ``possible_chunk_regions`` in place.
        ``peak_node`` and ``mem_peak`` are unused; they keep the selector
        signatures uniform.

        Returns:
            The chosen candidate dict with ``"chunk_size"`` set to 1, or
            ``None`` if no legal candidate remains.
        """
        self._remove_illegal_regions(possible_chunk_regions, chunk_infos)
        if not possible_chunk_regions:
            return None

        candidates = self._get_region_candidates(
            possible_chunk_regions, chunk_infos, max_chunk_region
        )
        # min() keeps the earliest candidate on ties.
        best_region = min(candidates, key=lambda cand: cand["chunk_max_mem"])[
            "chunk_info"
        ]
        best_region["chunk_size"] = 1
        return best_region

    def _is_legal_region(self, cur_chunk_info, chunk_infos):
        """Return whether ``cur_chunk_info`` can be chunked alongside ``chunk_infos``.

        A candidate is illegal if it is already selected, if its end precedes
        its start, or if its inclusive node range shares any node with an
        already-selected region.
        """
        chunk_region_start, chunk_region_end = cur_chunk_info["region"]
        if cur_chunk_info in chunk_infos:
            return False
        if chunk_region_end < chunk_region_start:
            return False
        # With start <= end, the ranges are disjoint exactly when this region
        # lies entirely after or entirely before the selected one.
        return all(
            chunk_region_start > info["region"][1]
            or chunk_region_end < info["region"][0]
            for info in chunk_infos
        )