mirror of https://github.com/hpcaitech/ColossalAI
finish chunk define
parent
3b7d671206
commit
f24c418bb0
|
@ -827,7 +827,7 @@ def _find_input_and_output_nodes(nodes: List[Node]):
|
||||||
for node in nodes:
|
for node in nodes:
|
||||||
for input_node in node._input_nodes.keys():
|
for input_node in node._input_nodes.keys():
|
||||||
node_repr = repr(input_node)
|
node_repr = repr(input_node)
|
||||||
if input_node not in nodes and node_repr not in input_nodes:
|
if input_node not in nodes and input_node not in input_nodes:
|
||||||
input_nodes.append(input_node)
|
input_nodes.append(input_node)
|
||||||
|
|
||||||
# if a node has a user node which is not in the node list
|
# if a node has a user node which is not in the node list
|
||||||
|
@ -835,7 +835,7 @@ def _find_input_and_output_nodes(nodes: List[Node]):
|
||||||
for node in nodes:
|
for node in nodes:
|
||||||
for output_node in node.users.keys():
|
for output_node in node.users.keys():
|
||||||
node_repr = repr(node)
|
node_repr = repr(node)
|
||||||
if output_node not in nodes and node_repr not in output_nodes:
|
if output_node not in nodes and output_node not in output_nodes:
|
||||||
output_nodes.append(output_node)
|
output_nodes.append(output_node)
|
||||||
|
|
||||||
return input_nodes, output_nodes
|
return input_nodes, output_nodes
|
||||||
|
@ -848,6 +848,16 @@ def _find_idx_by_name(name, nodes_list):
|
||||||
raise RuntimeError("name %s not found in node list" % name)
|
raise RuntimeError("name %s not found in node list" % name)
|
||||||
|
|
||||||
|
|
||||||
|
def _replace_name(context, name_from, name_to):
|
||||||
|
patterns = [(" ", " "), (" ", "."), (" ", ","), ("(", ")"), ("(", ",")]
|
||||||
|
for p in patterns:
|
||||||
|
source = p[0] + name_from + p[1]
|
||||||
|
target = p[0] + name_to + p[1]
|
||||||
|
if source in context:
|
||||||
|
context = context.replace(source, target)
|
||||||
|
return context
|
||||||
|
|
||||||
|
|
||||||
def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_value_func, meta_nodes, meta_graph):
|
def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_value_func, meta_nodes, meta_graph):
|
||||||
"""Emit code with nested activation checkpoint
|
"""Emit code with nested activation checkpoint
|
||||||
When we detect some of the node.activation_checkpoint is a List, we will use
|
When we detect some of the node.activation_checkpoint is a List, we will use
|
||||||
|
@ -905,8 +915,7 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v
|
||||||
if within_chunk_region:
|
if within_chunk_region:
|
||||||
emit_node_func(node, body)
|
emit_node_func(node, body)
|
||||||
# replace input var with chunk var
|
# replace input var with chunk var
|
||||||
if node_idx in chunk_starts:
|
body[-1] = _replace_name(body[-1], chunk_inputs[region_idx][0].name, 'chunk_tensor')
|
||||||
body[-1] = body[-1].replace(chunk_inputs[region_idx][0].name, 'chunk_tensor')
|
|
||||||
body[-1] = ' ' + body[-1]
|
body[-1] = ' ' + body[-1]
|
||||||
delete_unused_value_func(node, body, chunk_inputs_names)
|
delete_unused_value_func(node, body, chunk_inputs_names)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue