From 7fd3b45af21345cff9334682e277d7669c730814 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Mon, 2 Jan 2023 00:04:47 +0800 Subject: [PATCH] fix a bug in ones like, dont gen chunk if dim size is 1 --- autochunk_benchmark.py | 4 ++-- chunk_codegen.py | 41 +++++++++++++++++++++++++++-------------- 2 files changed, 29 insertions(+), 16 deletions(-) diff --git a/autochunk_benchmark.py b/autochunk_benchmark.py index 679016438..3b48d7e46 100644 --- a/autochunk_benchmark.py +++ b/autochunk_benchmark.py @@ -16,9 +16,9 @@ def _benchmark_evoformer(model: torch.nn.Module, node, pair, title, chunk_size=N torch.cuda.reset_peak_memory_stats() now_mem = torch.cuda.memory_allocated() / 1024**2 - loop = 16 + loop = 3 with torch.no_grad(): - for _ in range(loop // 4): + for _ in range(loop // 2 + 1): if chunk_size: model(node, pair, chunk_size) else: diff --git a/chunk_codegen.py b/chunk_codegen.py index 6f8ff2b23..6f21f26f3 100644 --- a/chunk_codegen.py +++ b/chunk_codegen.py @@ -144,9 +144,7 @@ class IndexTracer(object): node_to_trace_source[node_to_dim][node_from_idx] = [node_from_dim] else: if node_from_dim not in node_to_trace_source[node_to_dim][node_from_idx]: - node_to_trace_source[node_to_dim][node_from_idx].append( - node_from_dim - ) + node_to_trace_source[node_to_dim][node_from_idx].append(node_from_dim) # update inputs source for node_idx, node_dim in node_from_trace_source[node_from_dim].items(): if node_idx not in node_to_trace_source[node_to_dim]: @@ -1097,17 +1095,17 @@ class IndexTracer(object): for old_idx, new_idx in self.all_reorder_map.items(): new_node_list[new_idx] = node_list[old_idx] return new_node_list - + def tmp_reorder(self, node_list, chunk_info): if len(chunk_info["args"]["prepose_nodes"]) == 0: return node_list, chunk_info reorder_map = self._get_reorder_map(chunk_info) - + # new tmp node list new_node_list = [None for _ in range(len(node_list))] for old_idx, new_idx in reorder_map.items(): new_node_list[new_idx] = node_list[old_idx] - + chunk_info = self._reorder_chunk_info(chunk_info, reorder_map) return new_node_list, chunk_info @@ -1472,7 +1470,9 @@ class ChunkSelector(object): regions_dict = [] for region in possible_chunk_regions: cur_region = region.copy() - cur_node_list, cur_region = self.index_tracer.tmp_reorder(self.index_tracer.node_list, cur_region) + cur_node_list, cur_region = self.index_tracer.tmp_reorder( + self.index_tracer.node_list, cur_region + ) cur_chunk_infos = chunk_infos + [cur_region] cur_mem_peak = self.memory_estimator.estimate_chunk_inference_mem( cur_node_list, cur_chunk_infos @@ -1490,7 +1490,7 @@ class ChunkSelector(object): region["region"][0], region["region"][1] ), "reorder_chunk_info": cur_region, - "reorder_node_list": cur_node_list + "reorder_node_list": cur_node_list, } ) # no region found @@ -1508,7 +1508,7 @@ class ChunkSelector(object): def _get_fit_chunk_size(self, chunk_region_dict, chunk_infos): chunk_size = 1 - reorder_chunk_info = chunk_region_dict['reorder_chunk_info'] + reorder_chunk_info = chunk_region_dict["reorder_chunk_info"] reorder_chunk_info["chunk_size"] = chunk_size cur_chunk_max_mem = 0 # search a region @@ -1517,10 +1517,13 @@ class ChunkSelector(object): reorder_chunk_info["chunk_size"] = chunk_size cur_chunk_infos = chunk_infos + [reorder_chunk_info] cur_mem_peak = self.memory_estimator.estimate_chunk_inference_mem( - chunk_region_dict['reorder_node_list'], cur_chunk_infos + chunk_region_dict["reorder_node_list"], cur_chunk_infos )[0] cur_chunk_max_mem = max( - cur_mem_peak[reorder_chunk_info["region"][0] : reorder_chunk_info["region"][1] + 1] + cur_mem_peak[ + reorder_chunk_info["region"][0] : reorder_chunk_info["region"][1] + + 1 + ] ) # search exact size chunk_info = chunk_region_dict["chunk_info"] @@ -1534,13 +1537,13 @@ class ChunkSelector(object): gap = 4 else: gap = 1 - chunk_info = chunk_region_dict['reorder_chunk_info'] + chunk_info = chunk_region_dict["reorder_chunk_info"] while r >= l + gap: mid = int((l + r) / 2 + 0.5) chunk_info["chunk_size"] = mid cur_chunk_infos = chunk_infos + [chunk_info] cur_mem_peak = self.memory_estimator.estimate_chunk_inference_mem( - chunk_region_dict['reorder_node_list'], cur_chunk_infos + chunk_region_dict["reorder_node_list"], cur_chunk_infos )[0] cur_chunk_max_mem = max( cur_mem_peak[chunk_info["region"][0] : chunk_info["region"][1] + 1] @@ -2000,8 +2003,18 @@ def emit_code_with_chunk( ) # ones like if "ones_like" in node.name: + chunk_dim = chunk_search[region_idx]["node_chunk_dim"][ + chunk_region_search.index_tracer.node_list[node_idx] + ]["chunk_dim"] + if ( + _get_node_shape( + chunk_region_search.index_tracer.node_list[node_idx] + )[chunk_dim] + == 1 + ): + continue chunk_slice = _gen_chunk_slice_dim( - chunk_search[region_idx]["node_chunk_dim"][chunk_region_search.index_tracer.node_list[node_idx]]["chunk_dim"], "chunk_idx", _get_node_shape(node) + chunk_dim, "chunk_idx", _get_node_shape(node) ) body[-1] = _replace_name( body[-1], node.args[0].name, node.args[0].name + chunk_slice