code style

pull/2364/head
oahzxl 2022-12-12 17:29:07 +08:00
parent b7b67c32ad
commit 5cdfcfe1d1
1 changed files with 4 additions and 45 deletions

View File

@ -92,24 +92,10 @@ class FlowTracer(object):
self._add_trace(i.name)
self._add_node(i.name, i)
def _is_non_compute_node(self, node):
if any(i in node.op for i in ["placeholder", "get_attr", "output"]) or any(
i in node.name for i in ["getitem", "getattr"]
):
return True
return False
def _is_non_compute_node_except_placeholder(self, node):
if any(i in node.op for i in ["get_attr", "output"]) or any(
i in node.name for i in ["getitem", "getattr"]
):
return True
return False
def _find_flow_for_node(self, node):
if type(self.node_list[0]) != type(node):
return None
if self._is_non_compute_node_except_placeholder(node):
if _is_non_compute_node_except_placeholder(node):
return None
for name, trace in self.flow_trace.items():
for i in trace:
@ -135,7 +121,7 @@ class FlowTracer(object):
raise RuntimeError("invalid node")
def _get_flow_mix_node(self, node):
if self._is_non_compute_node(node):
if _is_non_compute_node(node):
return None
_, node_trace = self.find_node_flow(node)
if len(node_trace["outside_depend"]) == 0:
@ -160,10 +146,9 @@ class FlowTracer(object):
for node in self.node_list:
# skip if non compute node
if all(
type(arg) != type(node)
or self._is_non_compute_node_except_placeholder(arg)
type(arg) != type(node) or _is_non_compute_node_except_placeholder(arg)
for arg in node.args
) or self._is_non_compute_node(node):
) or _is_non_compute_node(node):
continue
node_input_flows = [self._find_flow_for_node(arg) for arg in node.args]
@ -1411,32 +1396,6 @@ def _gen_loop_end(
return context
def _find_input_and_output_nodes(nodes: List[Node]):
"""
Find the input and output node names which are not found in the given list of nodes.
"""
input_nodes = []
output_nodes = []
# if a node has an input node which is not in the node list
# we treat that input node as the input of the checkpoint function
for node in nodes:
for input_node in node._input_nodes.keys():
node_repr = repr(input_node)
if input_node not in nodes and input_node not in input_nodes:
input_nodes.append(input_node)
# if a node has a user node which is not in the node list
# we treat that user node as the node receiving the current node output
for node in nodes:
for output_node in node.users.keys():
node_repr = repr(node)
if output_node not in nodes and output_node not in output_nodes:
output_nodes.append(output_node)
return input_nodes, output_nodes
def _find_chunk_all_input_nodes(nodes: List[Node]):
"""
Find non-compute input and output node names.