mirror of https://github.com/hpcaitech/ColossalAI
add chunk select class
parent
786a398a6b
commit
1b8a066592
|
@ -1368,12 +1368,60 @@ class MemoryEstimator(object):
|
||||||
return act_memory_peak_log, act_memory_after_node_log, active_node_list_log
|
return act_memory_peak_log, act_memory_after_node_log, active_node_list_log
|
||||||
|
|
||||||
|
|
||||||
|
class ChunkSelector(object):
|
||||||
|
def __init__(self, index_tracer: IndexTracer, stratge) -> None:
|
||||||
|
self.index_tracer = index_tracer
|
||||||
|
assert stratge in ['min_memory', 'fit_memory']
|
||||||
|
self.stratge = stratge
|
||||||
|
self.max_memory = 800 # MB
|
||||||
|
|
||||||
|
def _select_best_chunk_region(self, possible_chunk_regions, chunk_infos):
|
||||||
|
if self.stratge == 'min_memory':
|
||||||
|
best_region = self._select_min_memory_chunk_region(possible_chunk_regions, chunk_infos)
|
||||||
|
elif self.stratge == 'fit_memory':
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
raise RuntimeError()
|
||||||
|
return best_region
|
||||||
|
|
||||||
|
def _select_min_memory_chunk_region(self, possible_chunk_regions, chunk_infos):
|
||||||
|
max_region_range = 0
|
||||||
|
best_region = None
|
||||||
|
while len(possible_chunk_regions) > 0:
|
||||||
|
for i in possible_chunk_regions:
|
||||||
|
if i["region"][1] - i["region"][0] > max_region_range:
|
||||||
|
best_region = i
|
||||||
|
max_region_range = i["region"][1] - i["region"][0]
|
||||||
|
if self._is_legal_region(best_region, chunk_infos):
|
||||||
|
break
|
||||||
|
possible_chunk_regions.remove(i)
|
||||||
|
max_region_range = 0
|
||||||
|
best_region = None
|
||||||
|
return best_region
|
||||||
|
|
||||||
|
def _is_legal_region(self, cur_chunk_info, chunk_infos):
|
||||||
|
(chunk_region_start, chunk_region_end) = cur_chunk_info["region"]
|
||||||
|
if cur_chunk_info in chunk_infos:
|
||||||
|
return False
|
||||||
|
if chunk_region_end < chunk_region_start:
|
||||||
|
return False
|
||||||
|
for i in chunk_infos:
|
||||||
|
region = i["region"]
|
||||||
|
if not (
|
||||||
|
(chunk_region_start > region[1] and chunk_region_end > region[1])
|
||||||
|
or (chunk_region_start < region[0] and chunk_region_end < region[0])
|
||||||
|
):
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
class ChunkRegionSearch(object):
|
class ChunkRegionSearch(object):
|
||||||
def __init__(self, gm) -> None:
|
def __init__(self, gm) -> None:
|
||||||
self.gm = gm
|
self.gm = gm
|
||||||
self.index_tracer = IndexTracer(list(gm.graph.nodes))
|
self.index_tracer = IndexTracer(list(gm.graph.nodes))
|
||||||
self.index_tracer.trace_index()
|
self.index_tracer.trace_index()
|
||||||
self.memory_estimator = MemoryEstimator(self.index_tracer)
|
self.memory_estimator = MemoryEstimator(self.index_tracer)
|
||||||
|
self.chunk_selector = ChunkSelector(self.index_tracer, stratge="min_memory")
|
||||||
|
|
||||||
def _find_peak_node(self, mem_peak):
|
def _find_peak_node(self, mem_peak):
|
||||||
max_value = max(mem_peak)
|
max_value = max(mem_peak)
|
||||||
|
@ -1516,36 +1564,6 @@ class ChunkRegionSearch(object):
|
||||||
possible_chunk_region.extend(chunk_info)
|
possible_chunk_region.extend(chunk_info)
|
||||||
return possible_chunk_region
|
return possible_chunk_region
|
||||||
|
|
||||||
def _search_best_chunk_region(self, possible_chunk_regions, chunk_infos):
|
|
||||||
max_region_range = 0
|
|
||||||
best_region = None
|
|
||||||
while len(possible_chunk_regions) > 0:
|
|
||||||
for i in possible_chunk_regions:
|
|
||||||
if i["region"][1] - i["region"][0] > max_region_range:
|
|
||||||
best_region = i
|
|
||||||
max_region_range = i["region"][1] - i["region"][0]
|
|
||||||
if self._is_legal_region(best_region, chunk_infos):
|
|
||||||
break
|
|
||||||
possible_chunk_regions.remove(i)
|
|
||||||
max_region_range = 0
|
|
||||||
best_region = None
|
|
||||||
return best_region
|
|
||||||
|
|
||||||
def _is_legal_region(self, cur_chunk_info, chunk_infos):
|
|
||||||
(chunk_region_start, chunk_region_end) = cur_chunk_info["region"]
|
|
||||||
if cur_chunk_info in chunk_infos:
|
|
||||||
return False
|
|
||||||
if chunk_region_end < chunk_region_start:
|
|
||||||
return False
|
|
||||||
for i in chunk_infos:
|
|
||||||
region = i["region"]
|
|
||||||
if not (
|
|
||||||
(chunk_region_start > region[1] and chunk_region_end > region[1])
|
|
||||||
or (chunk_region_start < region[0] and chunk_region_end < region[0])
|
|
||||||
):
|
|
||||||
return False
|
|
||||||
return True
|
|
||||||
|
|
||||||
def _step_search(self, mem_peak, active_node, chunk_regions):
|
def _step_search(self, mem_peak, active_node, chunk_regions):
|
||||||
peak_node = self._find_peak_node(mem_peak)
|
peak_node = self._find_peak_node(mem_peak)
|
||||||
max_chunk_region = self._search_max_chunk_region(
|
max_chunk_region = self._search_max_chunk_region(
|
||||||
|
@ -1556,7 +1574,7 @@ class ChunkRegionSearch(object):
|
||||||
possible_chunk_regions = self._search_possible_chunk_regions(
|
possible_chunk_regions = self._search_possible_chunk_regions(
|
||||||
max_chunk_region, peak_node
|
max_chunk_region, peak_node
|
||||||
)
|
)
|
||||||
best_chunk_region = self._search_best_chunk_region(
|
best_chunk_region = self.chunk_selector._select_best_chunk_region(
|
||||||
possible_chunk_regions, chunk_regions
|
possible_chunk_regions, chunk_regions
|
||||||
)
|
)
|
||||||
best_chunk_region = self.index_tracer.reorder_all(best_chunk_region)
|
best_chunk_region = self.index_tracer.reorder_all(best_chunk_region)
|
||||||
|
|
Loading…
Reference in New Issue