From 9c5e028a62b003136d2402b99b728eaefcc528cd Mon Sep 17 00:00:00 2001 From: oahzxl Date: Mon, 2 Jan 2023 00:27:11 +0800 Subject: [PATCH] fix bug again --- chunk_codegen.py | 35 +++++++++++++++++++---------------- 1 file changed, 19 insertions(+), 16 deletions(-) diff --git a/chunk_codegen.py b/chunk_codegen.py index 6f21f26f3..21ecc343a 100644 --- a/chunk_codegen.py +++ b/chunk_codegen.py @@ -2003,22 +2003,25 @@ def emit_code_with_chunk( ) # ones like if "ones_like" in node.name: - chunk_dim = chunk_search[region_idx]["node_chunk_dim"][ - chunk_region_search.index_tracer.node_list[node_idx] - ]["chunk_dim"] - if ( - _get_node_shape( - chunk_region_search.index_tracer.node_list[node_idx] - )[chunk_dim] - == 1 - ): - continue - chunk_slice = _gen_chunk_slice_dim( - chunk_dim, "chunk_idx", _get_node_shape(node) - ) - body[-1] = _replace_name( - body[-1], node.args[0].name, node.args[0].name + chunk_slice - ) + meta_node = chunk_region_search.index_tracer.node_list[node_idx] + chunk_dim = chunk_search[region_idx]["node_chunk_dim"][meta_node][ + "chunk_dim" + ] + if _get_node_shape(meta_node)[chunk_dim] != 1: + source_node = meta_node.args[0].args[0] + if ( + source_node not in chunk_search[region_idx]["node_chunk_dim"] + or chunk_search[region_idx]["node_chunk_dim"][source_node][ + "chunk_dim" + ] + is None + ): + chunk_slice = _gen_chunk_slice_dim( + chunk_dim, "chunk_idx", _get_node_shape(node) + ) + body[-1] = _replace_name( + body[-1], node.args[0].name, node.args[0].name + chunk_slice + ) body[-1] = _replace_reshape_size( body[-1], node.name, chunk_search[region_idx]["reshape_size"] )