mirror of https://github.com/hpcaitech/ColossalAI
add chunk select class
parent
786a398a6b
commit
1b8a066592
|
@ -1368,12 +1368,60 @@ class MemoryEstimator(object):
|
|||
return act_memory_peak_log, act_memory_after_node_log, active_node_list_log
|
||||
|
||||
|
||||
class ChunkSelector(object):
|
||||
def __init__(self, index_tracer: IndexTracer, stratge) -> None:
|
||||
self.index_tracer = index_tracer
|
||||
assert stratge in ['min_memory', 'fit_memory']
|
||||
self.stratge = stratge
|
||||
self.max_memory = 800 # MB
|
||||
|
||||
def _select_best_chunk_region(self, possible_chunk_regions, chunk_infos):
|
||||
if self.stratge == 'min_memory':
|
||||
best_region = self._select_min_memory_chunk_region(possible_chunk_regions, chunk_infos)
|
||||
elif self.stratge == 'fit_memory':
|
||||
pass
|
||||
else:
|
||||
raise RuntimeError()
|
||||
return best_region
|
||||
|
||||
def _select_min_memory_chunk_region(self, possible_chunk_regions, chunk_infos):
|
||||
max_region_range = 0
|
||||
best_region = None
|
||||
while len(possible_chunk_regions) > 0:
|
||||
for i in possible_chunk_regions:
|
||||
if i["region"][1] - i["region"][0] > max_region_range:
|
||||
best_region = i
|
||||
max_region_range = i["region"][1] - i["region"][0]
|
||||
if self._is_legal_region(best_region, chunk_infos):
|
||||
break
|
||||
possible_chunk_regions.remove(i)
|
||||
max_region_range = 0
|
||||
best_region = None
|
||||
return best_region
|
||||
|
||||
def _is_legal_region(self, cur_chunk_info, chunk_infos):
|
||||
(chunk_region_start, chunk_region_end) = cur_chunk_info["region"]
|
||||
if cur_chunk_info in chunk_infos:
|
||||
return False
|
||||
if chunk_region_end < chunk_region_start:
|
||||
return False
|
||||
for i in chunk_infos:
|
||||
region = i["region"]
|
||||
if not (
|
||||
(chunk_region_start > region[1] and chunk_region_end > region[1])
|
||||
or (chunk_region_start < region[0] and chunk_region_end < region[0])
|
||||
):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
class ChunkRegionSearch(object):
|
||||
def __init__(self, gm) -> None:
|
||||
self.gm = gm
|
||||
self.index_tracer = IndexTracer(list(gm.graph.nodes))
|
||||
self.index_tracer.trace_index()
|
||||
self.memory_estimator = MemoryEstimator(self.index_tracer)
|
||||
self.chunk_selector = ChunkSelector(self.index_tracer, stratge="min_memory")
|
||||
|
||||
def _find_peak_node(self, mem_peak):
|
||||
max_value = max(mem_peak)
|
||||
|
@ -1516,36 +1564,6 @@ class ChunkRegionSearch(object):
|
|||
possible_chunk_region.extend(chunk_info)
|
||||
return possible_chunk_region
|
||||
|
||||
def _search_best_chunk_region(self, possible_chunk_regions, chunk_infos):
|
||||
max_region_range = 0
|
||||
best_region = None
|
||||
while len(possible_chunk_regions) > 0:
|
||||
for i in possible_chunk_regions:
|
||||
if i["region"][1] - i["region"][0] > max_region_range:
|
||||
best_region = i
|
||||
max_region_range = i["region"][1] - i["region"][0]
|
||||
if self._is_legal_region(best_region, chunk_infos):
|
||||
break
|
||||
possible_chunk_regions.remove(i)
|
||||
max_region_range = 0
|
||||
best_region = None
|
||||
return best_region
|
||||
|
||||
def _is_legal_region(self, cur_chunk_info, chunk_infos):
|
||||
(chunk_region_start, chunk_region_end) = cur_chunk_info["region"]
|
||||
if cur_chunk_info in chunk_infos:
|
||||
return False
|
||||
if chunk_region_end < chunk_region_start:
|
||||
return False
|
||||
for i in chunk_infos:
|
||||
region = i["region"]
|
||||
if not (
|
||||
(chunk_region_start > region[1] and chunk_region_end > region[1])
|
||||
or (chunk_region_start < region[0] and chunk_region_end < region[0])
|
||||
):
|
||||
return False
|
||||
return True
|
||||
|
||||
def _step_search(self, mem_peak, active_node, chunk_regions):
|
||||
peak_node = self._find_peak_node(mem_peak)
|
||||
max_chunk_region = self._search_max_chunk_region(
|
||||
|
@ -1556,7 +1574,7 @@ class ChunkRegionSearch(object):
|
|||
possible_chunk_regions = self._search_possible_chunk_regions(
|
||||
max_chunk_region, peak_node
|
||||
)
|
||||
best_chunk_region = self._search_best_chunk_region(
|
||||
best_chunk_region = self.chunk_selector._select_best_chunk_region(
|
||||
possible_chunk_regions, chunk_regions
|
||||
)
|
||||
best_chunk_region = self.index_tracer.reorder_all(best_chunk_region)
|
||||
|
|
Loading…
Reference in New Issue