mirror of https://github.com/hpcaitech/ColossalAI
add reorder in mem estimator
parent
e5a5fbb8a9
commit
966e4ea0cb
|
@ -1040,11 +1040,13 @@ class IndexTracer(object):
|
||||||
chunk_info["region"][0] + len(chunk_info["args"]["prepose_nodes"]),
|
chunk_info["region"][0] + len(chunk_info["args"]["prepose_nodes"]),
|
||||||
chunk_info["region"][1],
|
chunk_info["region"][1],
|
||||||
)
|
)
|
||||||
|
new_inputs_dim = []
|
||||||
for idx, input_dim in enumerate(chunk_info["inputs_dim"]):
|
for idx, input_dim in enumerate(chunk_info["inputs_dim"]):
|
||||||
new_input_dim = {}
|
new_input_dim = {}
|
||||||
for k, v in input_dim.items():
|
for k, v in input_dim.items():
|
||||||
new_input_dim[reorder_map[k]] = v
|
new_input_dim[reorder_map[k]] = v
|
||||||
chunk_info["inputs_dim"][idx] = new_input_dim
|
new_inputs_dim.append(new_input_dim)
|
||||||
|
chunk_info["inputs_dim"] = new_inputs_dim
|
||||||
return chunk_info
|
return chunk_info
|
||||||
|
|
||||||
def _update_all_reorder_map(self, reorder_map):
|
def _update_all_reorder_map(self, reorder_map):
|
||||||
|
@ -1095,11 +1097,24 @@ class IndexTracer(object):
|
||||||
for old_idx, new_idx in self.all_reorder_map.items():
|
for old_idx, new_idx in self.all_reorder_map.items():
|
||||||
new_node_list[new_idx] = node_list[old_idx]
|
new_node_list[new_idx] = node_list[old_idx]
|
||||||
return new_node_list
|
return new_node_list
|
||||||
|
|
||||||
|
def tmp_reorder(self, node_list, chunk_info):
|
||||||
|
if len(chunk_info["args"]["prepose_nodes"]) == 0:
|
||||||
|
return node_list, chunk_info
|
||||||
|
reorder_map = self._get_reorder_map(chunk_info)
|
||||||
|
|
||||||
|
# new tmp node list
|
||||||
|
new_node_list = [None for _ in range(len(node_list))]
|
||||||
|
for old_idx, new_idx in reorder_map.items():
|
||||||
|
new_node_list[new_idx] = node_list[old_idx]
|
||||||
|
|
||||||
|
chunk_info = self._reorder_chunk_info(chunk_info, reorder_map)
|
||||||
|
return new_node_list, chunk_info
|
||||||
|
|
||||||
|
|
||||||
class MemoryEstimator(object):
|
class MemoryEstimator(object):
|
||||||
def __init__(self, index_tracer: IndexTracer) -> None:
|
def __init__(self, index_tracer: IndexTracer) -> None:
|
||||||
self.index_tracer = index_tracer
|
pass
|
||||||
|
|
||||||
def _get_meta_node_size(self, x):
|
def _get_meta_node_size(self, x):
|
||||||
x = x.meta["tensor_meta"]
|
x = x.meta["tensor_meta"]
|
||||||
|
@ -1453,9 +1468,11 @@ class ChunkSelector(object):
|
||||||
# get mem for chunk region
|
# get mem for chunk region
|
||||||
regions_dict = []
|
regions_dict = []
|
||||||
for region in possible_chunk_regions:
|
for region in possible_chunk_regions:
|
||||||
cur_chunk_infos = chunk_infos + [region]
|
cur_region = region.copy()
|
||||||
|
cur_node_list, cur_region = self.index_tracer.tmp_reorder(self.index_tracer.node_list, cur_region)
|
||||||
|
cur_chunk_infos = chunk_infos + [cur_region]
|
||||||
cur_mem_peak = self.memory_estimator.estimate_chunk_inference_mem(
|
cur_mem_peak = self.memory_estimator.estimate_chunk_inference_mem(
|
||||||
self.index_tracer.node_list, cur_chunk_infos
|
cur_node_list, cur_chunk_infos
|
||||||
)[0]
|
)[0]
|
||||||
cur_chunk_region_peak = cur_mem_peak[
|
cur_chunk_region_peak = cur_mem_peak[
|
||||||
max_chunk_region[0] : max_chunk_region[1] + 1
|
max_chunk_region[0] : max_chunk_region[1] + 1
|
||||||
|
@ -1492,9 +1509,11 @@ class ChunkSelector(object):
|
||||||
while cur_chunk_max_mem < self.max_memory:
|
while cur_chunk_max_mem < self.max_memory:
|
||||||
chunk_size *= 2
|
chunk_size *= 2
|
||||||
chunk_info["chunk_size"] = chunk_size
|
chunk_info["chunk_size"] = chunk_size
|
||||||
cur_chunk_infos = chunk_infos + [chunk_info]
|
cur_chunk_info = chunk_info.copy()
|
||||||
|
cur_node_list, cur_chunk_info = self.index_tracer.tmp_reorder(self.index_tracer.node_list, cur_chunk_info)
|
||||||
|
cur_chunk_infos = chunk_infos + [cur_chunk_info]
|
||||||
cur_mem_peak = self.memory_estimator.estimate_chunk_inference_mem(
|
cur_mem_peak = self.memory_estimator.estimate_chunk_inference_mem(
|
||||||
self.index_tracer.node_list, cur_chunk_infos
|
cur_node_list, cur_chunk_infos
|
||||||
)[0]
|
)[0]
|
||||||
cur_chunk_max_mem = max(
|
cur_chunk_max_mem = max(
|
||||||
cur_mem_peak[chunk_info["region"][0] : chunk_info["region"][1] + 1]
|
cur_mem_peak[chunk_info["region"][0] : chunk_info["region"][1] + 1]
|
||||||
|
@ -1511,11 +1530,13 @@ class ChunkSelector(object):
|
||||||
else:
|
else:
|
||||||
gap = 1
|
gap = 1
|
||||||
while r >= l + gap:
|
while r >= l + gap:
|
||||||
mid = int(l + (r - l) / 2)
|
mid = int((l + r) / 2 + 0.5)
|
||||||
chunk_info["chunk_size"] = mid
|
chunk_info["chunk_size"] = mid
|
||||||
cur_chunk_infos = chunk_infos + [chunk_info]
|
cur_chunk_info = chunk_info.copy()
|
||||||
|
cur_node_list, cur_chunk_info = self.index_tracer.tmp_reorder(self.index_tracer.node_list, cur_chunk_info)
|
||||||
|
cur_chunk_infos = chunk_infos + [cur_chunk_info]
|
||||||
cur_mem_peak = self.memory_estimator.estimate_chunk_inference_mem(
|
cur_mem_peak = self.memory_estimator.estimate_chunk_inference_mem(
|
||||||
self.index_tracer.node_list, cur_chunk_infos
|
cur_node_list, cur_chunk_infos
|
||||||
)[0]
|
)[0]
|
||||||
cur_chunk_max_mem = max(
|
cur_chunk_max_mem = max(
|
||||||
cur_mem_peak[chunk_info["region"][0] : chunk_info["region"][1] + 1]
|
cur_mem_peak[chunk_info["region"][0] : chunk_info["region"][1] + 1]
|
||||||
|
@ -1529,7 +1550,7 @@ class ChunkSelector(object):
|
||||||
def _get_compute_node_num(self, start, end):
|
def _get_compute_node_num(self, start, end):
|
||||||
count = 0
|
count = 0
|
||||||
for i in self.index_tracer.node_list[start : end + 1]:
|
for i in self.index_tracer.node_list[start : end + 1]:
|
||||||
if _is_non_compute_node(i):
|
if not _is_non_compute_node(i):
|
||||||
count += 1
|
count += 1
|
||||||
return count
|
return count
|
||||||
|
|
||||||
|
@ -1547,7 +1568,7 @@ class ChunkSelector(object):
|
||||||
max_region_range = 0
|
max_region_range = 0
|
||||||
best_region = None
|
best_region = None
|
||||||
if best_region is not None:
|
if best_region is not None:
|
||||||
best_region["chunk_size"] = 2
|
best_region["chunk_size"] = 1
|
||||||
return best_region
|
return best_region
|
||||||
|
|
||||||
def _is_legal_region(self, cur_chunk_info, chunk_infos):
|
def _is_legal_region(self, cur_chunk_info, chunk_infos):
|
||||||
|
|
Loading…
Reference in New Issue