finish chunk define

pull/2364/head
oahzxl 2022-12-06 16:29:07 +08:00
parent 3b7d671206
commit f24c418bb0
1 changed files with 13 additions and 4 deletions

View File

@ -827,7 +827,7 @@ def _find_input_and_output_nodes(nodes: List[Node]):
for node in nodes:
for input_node in node._input_nodes.keys():
node_repr = repr(input_node)
if input_node not in nodes and node_repr not in input_nodes:
if input_node not in nodes and input_node not in input_nodes:
input_nodes.append(input_node)
# if a node has a user node which is not in the node list
@ -835,7 +835,7 @@ def _find_input_and_output_nodes(nodes: List[Node]):
for node in nodes:
for output_node in node.users.keys():
node_repr = repr(node)
if output_node not in nodes and node_repr not in output_nodes:
if output_node not in nodes and output_node not in output_nodes:
output_nodes.append(output_node)
return input_nodes, output_nodes
@ -848,6 +848,16 @@ def _find_idx_by_name(name, nodes_list):
raise RuntimeError("name %s not found in node list" % name)
def _replace_name(context, name_from, name_to):
patterns = [(" ", " "), (" ", "."), (" ", ","), ("(", ")"), ("(", ",")]
for p in patterns:
source = p[0] + name_from + p[1]
target = p[0] + name_to + p[1]
if source in context:
context = context.replace(source, target)
return context
def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_value_func, meta_nodes, meta_graph):
"""Emit code with nested activation checkpoint
When we detect some of the node.activation_checkpoint is a List, we will use
@ -905,8 +915,7 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v
if within_chunk_region:
emit_node_func(node, body)
# replace input var with chunk var
if node_idx in chunk_starts:
body[-1] = body[-1].replace(chunk_inputs[region_idx][0].name, 'chunk_tensor')
body[-1] = _replace_name(body[-1], chunk_inputs[region_idx][0].name, 'chunk_tensor')
body[-1] = ' ' + body[-1]
delete_unused_value_func(node, body, chunk_inputs_names)