From 6be89a3b82d370be152c93dd7277e234e68eaea6 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Tue, 27 Dec 2022 14:48:25 +0800 Subject: [PATCH] add chunksize in emit, fix bug in reassgin shape --- chunk_codegen.py | 56 ++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 52 insertions(+), 4 deletions(-) diff --git a/chunk_codegen.py b/chunk_codegen.py index 1255852d7..470768855 100644 --- a/chunk_codegen.py +++ b/chunk_codegen.py @@ -988,6 +988,7 @@ class IndexTracer(object): def _reassgin_reshape_size(self, chunk_info): chunk_region = chunk_info["region"] reshape_size = {} + chunk_shape = _get_node_shape(chunk_info["outputs"][0])[chunk_info["outputs_dim"]] for node in self.node_list[chunk_region[0] : chunk_region[1] + 1]: if any(i in node.name for i in ["reshape", "view"]): reshape_args = node.args[1:] @@ -998,7 +999,7 @@ class IndexTracer(object): if reshape_arg_dim in reshape_log["dim_to"]: continue if reshape_arg_dim == chunk_dim: - reshape_size[node.name][reshape_arg.name] = "chunk_size" + reshape_size[node.name][reshape_arg.name] = "min(chunk_size, %d - chunk_idx)" % chunk_shape chunk_info["reshape_size"] = reshape_size return chunk_info @@ -1276,7 +1277,6 @@ class MemoryEstimator(object): chunk_within = False chunk_region_idx = None chunk_ratio = 1 # use it to estimate chunk mem - chunk_size = 1 chunk_inputs_names = [] if use_chunk: @@ -1285,12 +1285,14 @@ class MemoryEstimator(object): chunk_ends = [i[1] for i in chunk_regions] chunk_inputs = [i["inputs"] for i in chunk_infos] chunk_inputs_non_chunk = [i["inputs_non_chunk"] for i in chunk_infos] - chunk_inputs_dim = [i["inputs_dim"] for i in chunk_infos] chunk_inputs_names = [j.name for i in chunk_inputs for j in i] + [ j.name for i in chunk_inputs_non_chunk for j in i ] chunk_outputs = [i["outputs"][0] for i in chunk_infos] chunk_node_dim = [i["node_chunk_dim"] for i in chunk_infos] + chunk_sizes = [ + i["chunk_size"] if "chunk_size" in i else 1 for i in chunk_infos + ] for idx, node in enumerate(node_list): # if node in chunk start nodes, change chunk ratio and add chunk_tensor @@ -1306,7 +1308,7 @@ class MemoryEstimator(object): chunk_ratio = self._get_chunk_ratio( node, chunk_node_dim[chunk_region_idx], - chunk_size, + chunk_sizes[chunk_region_idx], ) # if node is placeholder, just add the size of the node @@ -1464,8 +1466,53 @@ class ChunkSelector(object): chunk_len = [i["chunk_len"] for i in regions_dict] best_region_idx = chunk_len.index(min(chunk_len)) best_region = regions_dict[best_region_idx]["chunk_info"] + + # get max chunk size + best_region = self._get_fit_chunk_size(best_region, chunk_infos) return best_region + def _get_fit_chunk_size(self, chunk_info, chunk_infos): + chunk_size = 1 + chunk_info["chunk_size"] = chunk_size + cur_chunk_max_mem = 0 + # search a region + while cur_chunk_max_mem < self.max_memory: + chunk_size *= 2 + chunk_info["chunk_size"] = chunk_size + cur_chunk_infos = chunk_infos + [chunk_info] + cur_mem_peak = self.memory_estimator.estimate_chunk_inference_mem( + self.index_tracer.node_list, cur_chunk_infos + )[0] + cur_chunk_max_mem = max( + cur_mem_peak[chunk_info["region"][0] : chunk_info["region"][1] + 1] + ) + # search exact size + chunk_info["chunk_size"] = self._chunk_size_binary_search( + chunk_size // 2, chunk_size, chunk_info, chunk_infos + ) + return chunk_info + + def _chunk_size_binary_search(self, l, r, chunk_info, chunk_infos): + if l >= 16: + gap = 4 + else: + gap = 1 + while r >= l + gap: + mid = int(l + (r - l)/2) + chunk_info["chunk_size"] = mid + cur_chunk_infos = chunk_infos + [chunk_info] + cur_mem_peak = self.memory_estimator.estimate_chunk_inference_mem( + self.index_tracer.node_list, cur_chunk_infos + )[0] + cur_chunk_max_mem = max( + cur_mem_peak[chunk_info["region"][0] : chunk_info["region"][1] + 1] + ) + if cur_chunk_max_mem >= self.max_memory: + r = mid - gap + else: + l = mid + gap + return l + def _get_compute_node_num(self, start, end): count = 0 for i in self.index_tracer.node_list[start : end + 1]: @@ -1891,6 +1938,7 @@ def emit_code_with_chunk( chunk_inputs[region_idx], chunk_outputs[region_idx], chunk_outputs_dim[region_idx], + chunk_size=chunk_search[region_idx]["chunk_size"] ) )