code style

pull/2364/head
oahzxl 2023-01-06 17:55:22 +08:00
parent c3a2bf48b4
commit 8a989a0d89
1 changed files with 40 additions and 29 deletions

View File

@ -98,6 +98,39 @@ def _replace_reshape_size(context, node_name, reshape_size_dict):
return context
def _replace_ones_like(search_chunk, chunk_infos, region_idx, node_idx, node, body):
if "ones_like" in node.name:
meta_node = search_chunk.trace_index.node_list[node_idx]
chunk_dim = chunk_infos[region_idx]["node_chunk_dim"][meta_node]["chunk_dim"]
if get_node_shape(meta_node)[chunk_dim] != 1:
source_node = meta_node.args[0].args[0]
if (
source_node not in chunk_infos[region_idx]["node_chunk_dim"]
or chunk_infos[region_idx]["node_chunk_dim"][source_node]["chunk_dim"]
is None
):
chunk_slice = _gen_chunk_slice_dim(
chunk_dim, "chunk_idx", get_node_shape(node)
)
body[-1] = _replace_name(
body[-1], node.args[0].name, node.args[0].name + chunk_slice
)
return body
def _replace_input_var(chunk_inputs, region_idx, chunk_inputs_dim, node_idx, body):
for input_node_idx, input_node in enumerate(chunk_inputs[region_idx]):
for idx, dim in chunk_inputs_dim[region_idx][input_node_idx].items():
if idx == node_idx:
chunk_slice = _gen_chunk_slice_dim(
dim[0], "chunk_idx", get_node_shape(input_node)
)
body[-1] = _replace_name(
body[-1], input_node.name, input_node.name + chunk_slice
)
return body
def emit_code_with_chunk(
body,
nodes,
@ -156,36 +189,14 @@ def emit_code_with_chunk(
if within_chunk_region:
emit_node_func(node, body)
# replace input var with chunk var
for input_node_idx, input_node in enumerate(chunk_inputs[region_idx]):
for idx, dim in chunk_inputs_dim[region_idx][input_node_idx].items():
if idx == node_idx:
chunk_slice = _gen_chunk_slice_dim(
dim[0], "chunk_idx", get_node_shape(input_node)
)
body[-1] = _replace_name(
body[-1], input_node.name, input_node.name + chunk_slice
)
body = _replace_input_var(
chunk_inputs, region_idx, chunk_inputs_dim, node_idx, body
)
# ones like
if "ones_like" in node.name:
meta_node = search_chunk.trace_index.node_list[node_idx]
chunk_dim = chunk_infos[region_idx]["node_chunk_dim"][meta_node][
"chunk_dim"
]
if get_node_shape(meta_node)[chunk_dim] != 1:
source_node = meta_node.args[0].args[0]
if (
source_node not in chunk_infos[region_idx]["node_chunk_dim"]
or chunk_infos[region_idx]["node_chunk_dim"][source_node][
"chunk_dim"
]
is None
):
chunk_slice = _gen_chunk_slice_dim(
chunk_dim, "chunk_idx", get_node_shape(node)
)
body[-1] = _replace_name(
body[-1], node.args[0].name, node.args[0].name + chunk_slice
)
body = _replace_ones_like(
search_chunk, chunk_infos, region_idx, node_idx, node, body
)
# reassgin reshape size
body[-1] = _replace_reshape_size(
body[-1], node.name, chunk_infos[region_idx]["reshape_size"]
)