mirror of https://github.com/hpcaitech/ColossalAI
add chunksize in emit, fix bug in reassgin shape
parent
378a49dc6c
commit
6be89a3b82
|
@ -988,6 +988,7 @@ class IndexTracer(object):
|
|||
def _reassgin_reshape_size(self, chunk_info):
|
||||
chunk_region = chunk_info["region"]
|
||||
reshape_size = {}
|
||||
chunk_shape = _get_node_shape(chunk_info["outputs"][0])[chunk_info["outputs_dim"]]
|
||||
for node in self.node_list[chunk_region[0] : chunk_region[1] + 1]:
|
||||
if any(i in node.name for i in ["reshape", "view"]):
|
||||
reshape_args = node.args[1:]
|
||||
|
@ -998,7 +999,7 @@ class IndexTracer(object):
|
|||
if reshape_arg_dim in reshape_log["dim_to"]:
|
||||
continue
|
||||
if reshape_arg_dim == chunk_dim:
|
||||
reshape_size[node.name][reshape_arg.name] = "chunk_size"
|
||||
reshape_size[node.name][reshape_arg.name] = "min(chunk_size, %d - chunk_idx)" % chunk_shape
|
||||
chunk_info["reshape_size"] = reshape_size
|
||||
return chunk_info
|
||||
|
||||
|
@ -1276,7 +1277,6 @@ class MemoryEstimator(object):
|
|||
chunk_within = False
|
||||
chunk_region_idx = None
|
||||
chunk_ratio = 1 # use it to estimate chunk mem
|
||||
chunk_size = 1
|
||||
chunk_inputs_names = []
|
||||
|
||||
if use_chunk:
|
||||
|
@ -1285,12 +1285,14 @@ class MemoryEstimator(object):
|
|||
chunk_ends = [i[1] for i in chunk_regions]
|
||||
chunk_inputs = [i["inputs"] for i in chunk_infos]
|
||||
chunk_inputs_non_chunk = [i["inputs_non_chunk"] for i in chunk_infos]
|
||||
chunk_inputs_dim = [i["inputs_dim"] for i in chunk_infos]
|
||||
chunk_inputs_names = [j.name for i in chunk_inputs for j in i] + [
|
||||
j.name for i in chunk_inputs_non_chunk for j in i
|
||||
]
|
||||
chunk_outputs = [i["outputs"][0] for i in chunk_infos]
|
||||
chunk_node_dim = [i["node_chunk_dim"] for i in chunk_infos]
|
||||
chunk_sizes = [
|
||||
i["chunk_size"] if "chunk_size" in i else 1 for i in chunk_infos
|
||||
]
|
||||
|
||||
for idx, node in enumerate(node_list):
|
||||
# if node in chunk start nodes, change chunk ratio and add chunk_tensor
|
||||
|
@ -1306,7 +1308,7 @@ class MemoryEstimator(object):
|
|||
chunk_ratio = self._get_chunk_ratio(
|
||||
node,
|
||||
chunk_node_dim[chunk_region_idx],
|
||||
chunk_size,
|
||||
chunk_sizes[chunk_region_idx],
|
||||
)
|
||||
|
||||
# if node is placeholder, just add the size of the node
|
||||
|
@ -1464,8 +1466,53 @@ class ChunkSelector(object):
|
|||
chunk_len = [i["chunk_len"] for i in regions_dict]
|
||||
best_region_idx = chunk_len.index(min(chunk_len))
|
||||
best_region = regions_dict[best_region_idx]["chunk_info"]
|
||||
|
||||
# get max chunk size
|
||||
best_region = self._get_fit_chunk_size(best_region, chunk_infos)
|
||||
return best_region
|
||||
|
||||
def _get_fit_chunk_size(self, chunk_info, chunk_infos):
|
||||
chunk_size = 1
|
||||
chunk_info["chunk_size"] = chunk_size
|
||||
cur_chunk_max_mem = 0
|
||||
# search a region
|
||||
while cur_chunk_max_mem < self.max_memory:
|
||||
chunk_size *= 2
|
||||
chunk_info["chunk_size"] = chunk_size
|
||||
cur_chunk_infos = chunk_infos + [chunk_info]
|
||||
cur_mem_peak = self.memory_estimator.estimate_chunk_inference_mem(
|
||||
self.index_tracer.node_list, cur_chunk_infos
|
||||
)[0]
|
||||
cur_chunk_max_mem = max(
|
||||
cur_mem_peak[chunk_info["region"][0] : chunk_info["region"][1] + 1]
|
||||
)
|
||||
# search exact size
|
||||
chunk_info["chunk_size"] = self._chunk_size_binary_search(
|
||||
chunk_size // 2, chunk_size, chunk_info, chunk_infos
|
||||
)
|
||||
return chunk_info
|
||||
|
||||
def _chunk_size_binary_search(self, l, r, chunk_info, chunk_infos):
|
||||
if l >= 16:
|
||||
gap = 4
|
||||
else:
|
||||
gap = 1
|
||||
while r >= l + gap:
|
||||
mid = int(l + (r - l)/2)
|
||||
chunk_info["chunk_size"] = mid
|
||||
cur_chunk_infos = chunk_infos + [chunk_info]
|
||||
cur_mem_peak = self.memory_estimator.estimate_chunk_inference_mem(
|
||||
self.index_tracer.node_list, cur_chunk_infos
|
||||
)[0]
|
||||
cur_chunk_max_mem = max(
|
||||
cur_mem_peak[chunk_info["region"][0] : chunk_info["region"][1] + 1]
|
||||
)
|
||||
if cur_chunk_max_mem >= self.max_memory:
|
||||
r = mid - gap
|
||||
else:
|
||||
l = mid + gap
|
||||
return l
|
||||
|
||||
def _get_compute_node_num(self, start, end):
|
||||
count = 0
|
||||
for i in self.index_tracer.node_list[start : end + 1]:
|
||||
|
@ -1891,6 +1938,7 @@ def emit_code_with_chunk(
|
|||
chunk_inputs[region_idx],
|
||||
chunk_outputs[region_idx],
|
||||
chunk_outputs_dim[region_idx],
|
||||
chunk_size=chunk_search[region_idx]["chunk_size"]
|
||||
)
|
||||
)
|
||||
|
||||
|
|
Loading…
Reference in New Issue