adapt codegen to prepose node

pull/2364/head
oahzxl 2022-12-23 14:26:12 +08:00
parent 522f017418
commit d309e9338b
1 changed files with 25 additions and 18 deletions

View File

@ -1198,7 +1198,7 @@ class FlowTracer(object):
chunk_node_list.remove(n)
non_chunk_inputs = _find_chunk_all_input_nodes(chunk_node_list)
for i in non_chunk_inputs:
if i not in chunk_info["inputs"] and i not in prepose_nodes:
if i not in chunk_info["inputs"]:
chunk_info["inputs_non_chunk"].append(i)
return chunk_info
@ -1425,6 +1425,7 @@ class MemoryEstimator(object):
) / (1024**2)
# determine chunk ratio for current node
# TODO: adapt to prepose node memory
if chunk_within:
chunk_ratio = self._get_chunk_ratio(
node,
@ -1602,7 +1603,6 @@ class ChunkRegionSearch(object):
chunk_infos = []
for end_dim, end_trace_idx in enumerate(end_trace["idx"]):
if len(start_traces) > 1:
# TODO: implement multi input chunk
continue
for start_node, start_trace in start_traces.items():
for start_dim, start_trace_idx in enumerate(start_trace["idx"]):
@ -1831,7 +1831,6 @@ def _find_chunk_compute_input_and_output_nodes(nodes: List[Node]):
# if a node has a user node which is not in the node list
# we treat that user node as the node receiving the current node output
# TODO: it is unsafe to remove non compute node here
for node in nodes:
for output_node in node.users.keys():
if (
@ -1900,6 +1899,8 @@ def emit_code_with_chunk(
chunk_outputs = [i["outputs"][0] for i in chunk_search]
chunk_outputs_dim = [i["outputs_dim"] for i in chunk_search]
chunk_prepose_nodes = [i["args"]["prepose_nodes"] for i in chunk_search]
node_idx = 0
region_idx = 0
@ -1911,7 +1912,11 @@ def emit_code_with_chunk(
if node_idx in chunk_starts:
within_chunk_region = True
region_idx = chunk_starts.index(node_idx)
# add prepose nodes
for i in chunk_prepose_nodes[region_idx]:
prepose_node = node_list[_find_idx_by_name(i.name, node_list)]
emit_node_func(prepose_node, body)
delete_unused_value_func(prepose_node, body, chunk_inputs_names)
# add for loop
body.append(
_gen_loop_start(
@ -1922,20 +1927,22 @@ def emit_code_with_chunk(
)
if within_chunk_region:
emit_node_func(node, body)
# replace input var with chunk var
for input_node_idx, input_node in enumerate(chunk_inputs[region_idx]):
for idx, dim in chunk_inputs_dim[region_idx][input_node_idx].items():
if idx == node_idx:
chunk_slice = _gen_chunk_slice_dim(
dim, "chunk_idx", _get_node_shape(input_node)
)
body[-1] = _replace_name(
body[-1], input_node.name, input_node.name + chunk_slice
)
body[-1] = " " + body[-1]
delete_unused_value_func(node, body, chunk_inputs_names)
if any(node.name == i.name for i in chunk_prepose_nodes[region_idx]):
pass
else:
emit_node_func(node, body)
# replace input var with chunk var
for input_node_idx, input_node in enumerate(chunk_inputs[region_idx]):
for idx, dim in chunk_inputs_dim[region_idx][input_node_idx].items():
if idx == node_idx:
chunk_slice = _gen_chunk_slice_dim(
dim, "chunk_idx", _get_node_shape(input_node)
)
body[-1] = _replace_name(
body[-1], input_node.name, input_node.name + chunk_slice
)
body[-1] = " " + body[-1]
delete_unused_value_func(node, body, chunk_inputs_names)
else:
emit_node_func(node, body)
if node_idx not in chunk_inputs: