mirror of https://github.com/hpcaitech/ColossalAI
finish chunk define
parent
3b7d671206
commit
f24c418bb0
|
@ -827,7 +827,7 @@ def _find_input_and_output_nodes(nodes: List[Node]):
|
|||
for node in nodes:
|
||||
for input_node in node._input_nodes.keys():
|
||||
node_repr = repr(input_node)
|
||||
if input_node not in nodes and node_repr not in input_nodes:
|
||||
if input_node not in nodes and input_node not in input_nodes:
|
||||
input_nodes.append(input_node)
|
||||
|
||||
# if a node has a user node which is not in the node list
|
||||
|
@ -835,7 +835,7 @@ def _find_input_and_output_nodes(nodes: List[Node]):
|
|||
for node in nodes:
|
||||
for output_node in node.users.keys():
|
||||
node_repr = repr(node)
|
||||
if output_node not in nodes and node_repr not in output_nodes:
|
||||
if output_node not in nodes and output_node not in output_nodes:
|
||||
output_nodes.append(output_node)
|
||||
|
||||
return input_nodes, output_nodes
|
||||
|
@ -848,6 +848,16 @@ def _find_idx_by_name(name, nodes_list):
|
|||
raise RuntimeError("name %s not found in node list" % name)
|
||||
|
||||
|
||||
def _replace_name(context, name_from, name_to):
|
||||
patterns = [(" ", " "), (" ", "."), (" ", ","), ("(", ")"), ("(", ",")]
|
||||
for p in patterns:
|
||||
source = p[0] + name_from + p[1]
|
||||
target = p[0] + name_to + p[1]
|
||||
if source in context:
|
||||
context = context.replace(source, target)
|
||||
return context
|
||||
|
||||
|
||||
def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_value_func, meta_nodes, meta_graph):
|
||||
"""Emit code with nested activation checkpoint
|
||||
When we detect some of the node.activation_checkpoint is a List, we will use
|
||||
|
@ -905,8 +915,7 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v
|
|||
if within_chunk_region:
|
||||
emit_node_func(node, body)
|
||||
# replace input var with chunk var
|
||||
if node_idx in chunk_starts:
|
||||
body[-1] = body[-1].replace(chunk_inputs[region_idx][0].name, 'chunk_tensor')
|
||||
body[-1] = _replace_name(body[-1], chunk_inputs[region_idx][0].name, 'chunk_tensor')
|
||||
body[-1] = ' ' + body[-1]
|
||||
delete_unused_value_func(node, body, chunk_inputs_names)
|
||||
|
||||
|
|
Loading…
Reference in New Issue