mirror of https://github.com/hpcaitech/ColossalAI
ai big-model data-parallelism deep-learning distributed-computing foundation-models heterogeneous-training hpc inference large-scale model-parallelism pipeline-parallelism
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
191 lines
7.9 KiB
191 lines
7.9 KiB
from .estimate_memory import EstimateMemory
from .reorder_graph import ReorderGraph
from .trace_indice import TraceIndice
from .utils import NodeMgr, is_non_compute_node


class SelectChunk(object): |
|
def __init__( |
|
self, |
|
trace_indice: TraceIndice, |
|
estimate_memory: EstimateMemory, |
|
reorder_graph: ReorderGraph, |
|
node_mgr: NodeMgr, |
|
max_memory=None, |
|
): |
|
self.trace_indice = trace_indice |
|
self.estimate_memory = estimate_memory |
|
self.reorder_graph = reorder_graph |
|
self.node_mgr = node_mgr |
|
if max_memory is not None: |
|
self.stratge = "fit_memory" |
|
self.max_memory = max_memory # MB |
|
else: |
|
self.stratge = "min_memory" |
|
|
|
def _select_best_chunk_region(self, possible_chunk_regions, chunk_infos, mem_peak): |
|
if self.stratge == "min_memory": |
|
best_region = self._select_min_memory_chunk_region(possible_chunk_regions, chunk_infos) |
|
elif self.stratge == "fit_memory": |
|
best_region = self._select_fit_memory_chunk_region(possible_chunk_regions, chunk_infos, mem_peak) |
|
else: |
|
raise RuntimeError() |
|
return best_region |
|
|
|
def _select_fit_memory_chunk_region(self, possible_chunk_regions, chunk_infos, mem_peak): |
|
# stop chunk if max memory satisfy memory limit |
|
if max(mem_peak) < self.max_memory: |
|
return None |
|
|
|
# remove illegal regions |
|
illegal_regions = [] |
|
for i in possible_chunk_regions: |
|
if not self._is_legal_region(i, chunk_infos): |
|
illegal_regions.append(i) |
|
for i in illegal_regions: |
|
if i in possible_chunk_regions: |
|
possible_chunk_regions.remove(i) |
|
|
|
if len(possible_chunk_regions) == 0: |
|
return None |
|
|
|
# get mem for chunk region |
|
regions_dict = [] |
|
for region in possible_chunk_regions: |
|
cur_region = region.copy() |
|
cur_node_list, cur_region = self.reorder_graph.tmp_reorder(self.node_mgr.get_node_list(), cur_region) |
|
cur_chunk_infos = chunk_infos + [cur_region] |
|
cur_mem = self.estimate_memory.estimate_chunk_inference_mem(cur_node_list, cur_chunk_infos)[0] |
|
cur_chunk_region_peak = cur_mem[cur_region["region"][0] : cur_region["region"][1] + 1] |
|
cur_chunk_region_max_peak = max(cur_chunk_region_peak) |
|
if cur_chunk_region_max_peak < self.max_memory: |
|
regions_dict.append( |
|
{ |
|
"chunk_info": region, |
|
"chunk_max_mem": cur_chunk_region_max_peak, |
|
"chunk_len": self._get_compute_node_num(region["region"][0], region["region"][1]), |
|
"reorder_chunk_info": cur_region, |
|
"reorder_node_list": cur_node_list, |
|
} |
|
) |
|
# no region found |
|
if len(regions_dict) == 0: |
|
raise RuntimeError("Search failed. Try a larger memory threshold.") |
|
|
|
# select the min chunk len |
|
chunk_len = [i["chunk_len"] for i in regions_dict] |
|
best_region_idx = chunk_len.index(min(chunk_len)) |
|
best_region = regions_dict[best_region_idx] |
|
|
|
# get max chunk size |
|
best_region = self._get_fit_chunk_size(best_region, chunk_infos) |
|
return best_region |
|
|
|
def _get_fit_chunk_size(self, chunk_region_dict, chunk_infos): |
|
chunk_size = 1 |
|
reorder_chunk_info = chunk_region_dict["reorder_chunk_info"] |
|
reorder_chunk_info["chunk_size"] = chunk_size |
|
cur_chunk_max_mem = 0 |
|
# search a region |
|
while cur_chunk_max_mem < self.max_memory: |
|
chunk_size *= 2 |
|
reorder_chunk_info["chunk_size"] = chunk_size |
|
cur_chunk_infos = chunk_infos + [reorder_chunk_info] |
|
cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem( |
|
chunk_region_dict["reorder_node_list"], cur_chunk_infos |
|
)[0] |
|
cur_chunk_max_mem = max(cur_mem_peak[reorder_chunk_info["region"][0] : reorder_chunk_info["region"][1] + 1]) |
|
# search exact size |
|
chunk_info = chunk_region_dict["chunk_info"] |
|
chunk_info["chunk_size"] = self._chunk_size_binary_search( |
|
chunk_size // 2, chunk_size, chunk_region_dict, chunk_infos |
|
) |
|
return chunk_info |
|
|
|
def _chunk_size_binary_search(self, left, right, chunk_region_dict, chunk_infos): |
|
if left >= 16: |
|
gap = 4 |
|
else: |
|
gap = 1 |
|
chunk_info = chunk_region_dict["reorder_chunk_info"] |
|
while right >= left + gap: |
|
mid = int((left + right) / 2 + 0.5) |
|
chunk_info["chunk_size"] = mid |
|
cur_chunk_infos = chunk_infos + [chunk_info] |
|
cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem( |
|
chunk_region_dict["reorder_node_list"], cur_chunk_infos |
|
)[0] |
|
cur_chunk_max_mem = max(cur_mem_peak[chunk_info["region"][0] : chunk_info["region"][1] + 1]) |
|
if cur_chunk_max_mem >= self.max_memory: |
|
right = mid - gap |
|
else: |
|
left = mid + gap |
|
return left |
|
|
|
def _get_compute_node_num(self, start, end): |
|
count = 0 |
|
for i in self.node_mgr.get_node_slice_by_idx(start, end + 1): |
|
if not is_non_compute_node(i): |
|
count += 1 |
|
return count |
|
|
|
def _select_min_memory_chunk_region(self, possible_chunk_regions, chunk_infos): |
|
# remove illegal regions |
|
illegal_regions = [] |
|
for i in possible_chunk_regions: |
|
if not self._is_legal_region(i, chunk_infos): |
|
illegal_regions.append(i) |
|
for i in illegal_regions: |
|
if i in possible_chunk_regions: |
|
possible_chunk_regions.remove(i) |
|
|
|
if len(possible_chunk_regions) == 0: |
|
return None |
|
|
|
# get max possible chunk region |
|
max_possible_chunk_region = ( |
|
min([i["region"][0] for i in possible_chunk_regions]), |
|
max([i["region"][1] for i in possible_chunk_regions]), |
|
) |
|
|
|
# get mem for chunk region |
|
regions_dict_list = [] |
|
for region in possible_chunk_regions: |
|
cur_region = region.copy() |
|
cur_node_list, cur_region = self.reorder_graph.tmp_reorder(self.node_mgr.get_node_list(), cur_region) |
|
cur_chunk_infos = chunk_infos + [cur_region] |
|
cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(cur_node_list, cur_chunk_infos)[0] |
|
cur_chunk_region_peak = cur_mem_peak[max_possible_chunk_region[0] : max_possible_chunk_region[1] + 1] |
|
cur_chunk_region_max_peak = max(cur_chunk_region_peak) |
|
regions_dict_list.append( |
|
{ |
|
"chunk_info": region, |
|
"chunk_max_mem": cur_chunk_region_max_peak, |
|
"chunk_len": self._get_compute_node_num(region["region"][0], region["region"][1]), |
|
"reorder_chunk_info": cur_region, |
|
"reorder_node_list": cur_node_list, |
|
} |
|
) |
|
|
|
# select the min mem |
|
chunk_max_mem = [i["chunk_max_mem"] for i in regions_dict_list] |
|
best_region_idx = chunk_max_mem.index(min(chunk_max_mem)) |
|
best_region = regions_dict_list[best_region_idx]["chunk_info"] |
|
if best_region is not None: |
|
best_region["chunk_size"] = 1 |
|
return best_region |
|
|
|
def _is_legal_region(self, cur_chunk_info, chunk_infos): |
|
(chunk_region_start, chunk_region_end) = cur_chunk_info["region"] |
|
if cur_chunk_info in chunk_infos: |
|
return False |
|
if chunk_region_end < chunk_region_start: |
|
return False |
|
for i in chunk_infos: |
|
region = i["region"] |
|
if not ( |
|
(chunk_region_start > region[1] and chunk_region_end > region[1]) |
|
or (chunk_region_start < region[0] and chunk_region_end < region[0]) |
|
): |
|
return False |
|
return True
|
|
|