mirror of https://github.com/hpcaitech/ColossalAI
code style
parent
b7b67c32ad
commit
5cdfcfe1d1
|
@ -92,24 +92,10 @@ class FlowTracer(object):
|
||||||
self._add_trace(i.name)
|
self._add_trace(i.name)
|
||||||
self._add_node(i.name, i)
|
self._add_node(i.name, i)
|
||||||
|
|
||||||
def _is_non_compute_node(self, node):
|
|
||||||
if any(i in node.op for i in ["placeholder", "get_attr", "output"]) or any(
|
|
||||||
i in node.name for i in ["getitem", "getattr"]
|
|
||||||
):
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
def _is_non_compute_node_except_placeholder(self, node):
|
|
||||||
if any(i in node.op for i in ["get_attr", "output"]) or any(
|
|
||||||
i in node.name for i in ["getitem", "getattr"]
|
|
||||||
):
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
def _find_flow_for_node(self, node):
|
def _find_flow_for_node(self, node):
|
||||||
if type(self.node_list[0]) != type(node):
|
if type(self.node_list[0]) != type(node):
|
||||||
return None
|
return None
|
||||||
if self._is_non_compute_node_except_placeholder(node):
|
if _is_non_compute_node_except_placeholder(node):
|
||||||
return None
|
return None
|
||||||
for name, trace in self.flow_trace.items():
|
for name, trace in self.flow_trace.items():
|
||||||
for i in trace:
|
for i in trace:
|
||||||
|
@ -135,7 +121,7 @@ class FlowTracer(object):
|
||||||
raise RuntimeError("invalid node")
|
raise RuntimeError("invalid node")
|
||||||
|
|
||||||
def _get_flow_mix_node(self, node):
|
def _get_flow_mix_node(self, node):
|
||||||
if self._is_non_compute_node(node):
|
if _is_non_compute_node(node):
|
||||||
return None
|
return None
|
||||||
_, node_trace = self.find_node_flow(node)
|
_, node_trace = self.find_node_flow(node)
|
||||||
if len(node_trace["outside_depend"]) == 0:
|
if len(node_trace["outside_depend"]) == 0:
|
||||||
|
@ -160,10 +146,9 @@ class FlowTracer(object):
|
||||||
for node in self.node_list:
|
for node in self.node_list:
|
||||||
# skip if non compute node
|
# skip if non compute node
|
||||||
if all(
|
if all(
|
||||||
type(arg) != type(node)
|
type(arg) != type(node) or _is_non_compute_node_except_placeholder(arg)
|
||||||
or self._is_non_compute_node_except_placeholder(arg)
|
|
||||||
for arg in node.args
|
for arg in node.args
|
||||||
) or self._is_non_compute_node(node):
|
) or _is_non_compute_node(node):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
node_input_flows = [self._find_flow_for_node(arg) for arg in node.args]
|
node_input_flows = [self._find_flow_for_node(arg) for arg in node.args]
|
||||||
|
@ -1411,32 +1396,6 @@ def _gen_loop_end(
|
||||||
return context
|
return context
|
||||||
|
|
||||||
|
|
||||||
def _find_input_and_output_nodes(nodes: List[Node]):
|
|
||||||
"""
|
|
||||||
Find the input and output node names which are not found in the given list of nodes.
|
|
||||||
"""
|
|
||||||
input_nodes = []
|
|
||||||
output_nodes = []
|
|
||||||
|
|
||||||
# if a node has an input node which is not in the node list
|
|
||||||
# we treat that input node as the input of the checkpoint function
|
|
||||||
for node in nodes:
|
|
||||||
for input_node in node._input_nodes.keys():
|
|
||||||
node_repr = repr(input_node)
|
|
||||||
if input_node not in nodes and input_node not in input_nodes:
|
|
||||||
input_nodes.append(input_node)
|
|
||||||
|
|
||||||
# if a node has a user node which is not in the node list
|
|
||||||
# we treat that user node as the node receiving the current node output
|
|
||||||
for node in nodes:
|
|
||||||
for output_node in node.users.keys():
|
|
||||||
node_repr = repr(node)
|
|
||||||
if output_node not in nodes and output_node not in output_nodes:
|
|
||||||
output_nodes.append(output_node)
|
|
||||||
|
|
||||||
return input_nodes, output_nodes
|
|
||||||
|
|
||||||
|
|
||||||
def _find_chunk_all_input_nodes(nodes: List[Node]):
|
def _find_chunk_all_input_nodes(nodes: List[Node]):
|
||||||
"""
|
"""
|
||||||
Find non-compute input and output node names.
|
Find non-compute input and output node names.
|
||||||
|
|
Loading…
Reference in New Issue