|
|
|
from typing import Any, Callable, Dict, Iterable, List, Tuple, Union
|
|
|
|
|
|
|
|
from torch.fx.node import Node
|
|
|
|
|
|
|
|
from colossalai.logging import get_dist_logger
|
|
|
|
|
|
|
|
# fx ``node.op`` values that is_non_compute_node always treats as non-compute
NON_COMPUTE_OP = ["placeholder", "get_attr", "output"]
# base node names (as returned by get_node_name, i.e. with any trailing "_<digits>"
# suffix stripped) that is_non_compute_node always treats as non-compute
NON_COMPUTE_NAME = ["getattr", "eq", "_assert_is_none", "_assert", "finfo", "size"]
# logger created once at import time; get_logger() hands out this instance
logger = get_dist_logger()
|
|
|
|
|
|
|
|
|
|
|
|
class NodeMgr(object):
    """
    Wrap an ordered list of fx nodes together with a name -> position table.

    The table lets callers look up a node's index by node (via ``node.name``)
    or by name without scanning the list. It is rebuilt whenever the list is
    replaced through ``update_node_list``.
    """

    def __init__(self, nodes_list: List[Node]) -> None:
        self._node_list = nodes_list
        self._node_dict = {}
        self._set_node_dict()

    def _set_node_dict(self) -> None:
        """
        rebuild the {node_name: node_idx} table in place from the current list;
        when two nodes share a name, the later position wins
        """
        name_to_pos = self._node_dict
        name_to_pos.clear()
        name_to_pos.update({entry.name: pos for pos, entry in enumerate(self._node_list)})

    def find_node_idx(self, node: Node) -> int:
        """
        return the position of ``node`` (looked up by its name); raises KeyError if unknown
        """
        return self._node_dict[node.name]

    def find_node_idx_by_name(self, node_name: str) -> int:
        """
        return the position of the node called ``node_name``; raises KeyError if unknown
        """
        return self._node_dict[node_name]

    def get_node_by_idx(self, idx: int) -> Node:
        """
        return the node stored at position ``idx``
        """
        return self._node_list[idx]

    def get_node_slice_by_idx(self, start: int, end: int) -> List[Node]:
        """
        return the nodes in positions [start, end) as a new list
        """
        window = slice(start, end)
        return self._node_list[window]

    def get_node_list(self) -> List:
        """
        return the wrapped node list itself (not a copy)
        """
        return self._node_list

    def update_node_list(self, node_list: List) -> None:
        """
        replace the wrapped node list and rebuild the name -> position table
        """
        self._node_list = node_list
        self._set_node_dict()
|
|
|
|
|
|
|
|
|
|
|
|
def get_logger() -> Any:
    """
    Return the module-level logger (created once at import time via ``get_dist_logger()``).
    """
    return logger
|
|
|
|
|
|
|
|
|
|
|
|
def flat_list(inputs: Any) -> List:
    """
    Recursively flatten nested lists, tuples and sets into one flat list.

    A non-container ``inputs`` is wrapped as ``[inputs]``. A dict found *inside*
    a container contributes its keys (themselves flattened); its values are
    dropped. Note that a dict passed directly as ``inputs`` is not expanded and
    comes back as ``[inputs]``.

    Args:
        inputs (Any): any object, possibly a (nested) list / tuple / set.

    Returns:
        List: the leaf elements, in iteration order (set order is whatever the
        set yields).
    """
    container_types = (list, set, tuple)
    if not isinstance(inputs, container_types):
        return [inputs]
    res = []
    for item in inputs:
        if isinstance(item, container_types):
            res.extend(flat_list(item))
        elif isinstance(item, dict):
            res.extend(flat_list(list(item.keys())))
        else:
            res.append(item)
    return res
|
|
|
|
|
|
|
|
|
|
|
|
def find_first_tensor_arg(node: Node) -> Node:
    """
    Find the first positional argument of ``node`` that is itself an fx ``Node``.

    Only ``node.args`` is scanned; keyword arguments are ignored. The name
    suggests such an argument is a tensor producer — NOTE(review): that is
    assumed by callers, not checked here.

    Args:
        node (Node): the node whose positional arguments are inspected.

    Returns:
        Node: the first ``Node`` found in ``node.args``.

    Raises:
        RuntimeError: if no positional argument is a ``Node``.
    """
    for arg in node.args:
        if isinstance(arg, Node):
            return arg
    raise RuntimeError(f"node '{node.name}' has no Node-typed positional argument")
|
|
|
|
|
|
|
|
|
|
|
|
def is_non_compute_node(node: Node) -> bool:
    """
    Decide whether ``node`` should be treated as performing no computation.

    Returns True when ``node.op`` is in ``NON_COMPUTE_OP`` or the node's base
    name (``get_node_name``) is in ``NON_COMPUTE_NAME``. A node whose name
    contains ``getitem`` is non-compute only when it has no tensor shape and
    none of its index arguments (``node.args[1:]``, flattened) prints as
    ``None`` / ``Ellipsis`` or contains ``slice``. Every other node is compute.
    """
    if node.op in NON_COMPUTE_OP or get_node_name(node) in NON_COMPUTE_NAME:
        return True
    if "getitem" not in node.name:
        return False
    # a getitem that yields a tensor is real work
    if get_node_shape(node) is not None:
        return False
    for index_arg in flat_list(node.args[1:]):
        text = str(index_arg)
        # None / Ellipsis / slice indexing counts as compute
        if text in ("None", "Ellipsis") or "slice" in text:
            return False
    return True
|
|
|
|
|
|
|
|
|
|
|
|
def get_node_shape(node: Node) -> Any:
    """
    Return the shape recorded in ``node.meta["tensor_meta"]``, or None.

    For nodes whose base name is ``split`` or ``unbind`` the metadata is
    indexed and the shape of its first entry is returned. Otherwise the
    metadata's ``shape`` attribute is returned when present, else None.
    Raises KeyError if ``node.meta`` has no ``"tensor_meta"`` entry.
    """
    multi_output = get_node_name(node) in ("split", "unbind")
    tensor_meta = node.meta["tensor_meta"]
    if multi_output:
        return tensor_meta[0].shape
    return getattr(tensor_meta, "shape", None)
|
|
|
|
|
|
|
|
|
|
|
|
def is_non_memory_node(node: Node) -> bool:
    """
    True for nodes whose name contains ``getitem`` or whose op contains
    ``output``; every other node defers to ``is_non_compute_node``.
    """
    return "getitem" in node.name or "output" in node.op or is_non_compute_node(node)
|
|
|
|
|
|
|
|
|
|
|
|
def is_non_compute_node_except_placeholder(node: Node) -> bool:
    """
    Same as ``is_non_compute_node`` except that placeholder nodes are
    reported as compute (False).
    """
    return "placeholder" not in node.op and is_non_compute_node(node)
|
|
|
|
|
|
|
|
|
|
|
|
def is_non_compute_node_except_placeholder_output(node: Node) -> bool:
    """
    Same as ``is_non_compute_node_except_placeholder`` except that output
    nodes are also reported as compute (False).
    """
    return "output" not in node.op and is_non_compute_node_except_placeholder(node)
|
|
|
|
|
|
|
|
|
|
|
|
def delete_free_var_from_last_use(user_to_last_uses: Dict) -> None:
    """
    Drop placeholder nodes from every last-use list, in place.

    Args:
        user_to_last_uses (Dict): mapping whose values are lists of nodes
            (assumed to be lists, as the original ``.remove`` usage implied).
            Each list keeps its identity and loses every node with
            ``op == "placeholder"``.
    """
    for last_uses in user_to_last_uses.values():
        # Filter via slice assignment: calling .remove() on the list being
        # iterated skips the element after each removal, so consecutive
        # placeholders used to survive.
        last_uses[:] = [n for n in last_uses if n.op != "placeholder"]
|
|
|
|
|
|
|
|
|
|
|
|
def find_chunk_all_input_nodes(nodes: List[Node]) -> List:
    """
    Find every external input node of a region of nodes.

    An input node is any node listed in ``_input_nodes`` of some node in
    ``nodes`` that is not itself part of ``nodes``. No compute / non-compute
    filtering is applied.

    Args:
        nodes (List[Node]): the nodes forming the region.

    Returns:
        List: the input nodes, de-duplicated, in first-seen order.
    """
    # sets give O(1) membership; the list keeps first-seen order
    region = set(nodes)
    seen = set()
    input_nodes = []
    for node in nodes:
        for input_node in node._input_nodes:
            if input_node in region or input_node in seen:
                continue
            seen.add(input_node)
            input_nodes.append(input_node)
    return input_nodes
|
|
|
|
|
|
|
|
|
|
|
|
def find_chunk_compute_input_and_output_nodes(nodes: List[Node]) -> Tuple[List, List]:
    """
    Find the compute input nodes and the output nodes of a region of nodes.

    Input nodes: nodes outside ``nodes`` that some node inside reads (via
    ``_input_nodes``), skipping those for which
    ``is_non_compute_node_except_placeholder`` is True.

    Output nodes: nodes inside ``nodes`` that have at least one user outside
    ``nodes`` for which ``is_non_compute_node_except_placeholder_output`` is
    False.

    Args:
        nodes (List[Node]): the nodes forming the region.

    Returns:
        Tuple[List, List]: ``(input_nodes, output_nodes)``, each de-duplicated
        and in first-seen order.
    """
    # sets give O(1) membership; the lists keep first-seen order
    region = set(nodes)

    # if a node has an input node which is not in the node list
    # we treat that input node as the input of the checkpoint function
    input_nodes, input_seen = [], set()
    for node in nodes:
        for input_node in node._input_nodes:
            if (input_node not in region and input_node not in input_seen
                    and not is_non_compute_node_except_placeholder(input_node)):
                input_seen.add(input_node)
                input_nodes.append(input_node)

    # if a node has a user node which is not in the node list
    # we treat that user node as the node receiving the current node output
    output_nodes, output_seen = [], set()
    for node in nodes:
        if node in output_seen:
            continue
        for output_node in node.users:
            if output_node not in region and not is_non_compute_node_except_placeholder_output(output_node):
                output_seen.add(node)
                output_nodes.append(node)
                # once recorded, further users of this node cannot change the result
                break

    return input_nodes, output_nodes
|
|
|
|
|
|
|
|
|
|
|
|
def get_module_node_name(node: Node) -> str:
    """
    Get the lower-cased class name of the submodule a node's target refers to.

    ``node.target`` is treated as a dotted attribute path (e.g. ``"0.0"`` or
    ``"encoder.linear"``) and resolved attribute by attribute starting from
    ``node.graph.owning_module``.

    Args:
        node (Node): a node whose ``target`` is a dotted submodule path
            (typically a ``call_module`` node).

    Returns:
        str: the resolved object's class name, lower-cased (``"linear"`` for a
        ``torch.nn.Linear``).
    """
    node_targets = node.target.split(".")
    module = node.graph.owning_module
    for attr in node_targets:
        module = getattr(module, attr)
    # read the class name directly instead of slicing it out of str(cls),
    # which silently depends on the "<class 'pkg.mod.Name'>" repr format
    return module.__class__.__name__.lower()
|
|
|
|
|
|
|
|
|
|
|
|
def get_node_name(node: Node) -> str:
    """
    Return ``node.name`` with a trailing ``_<digits>`` suffix removed.

    The cut happens at the last underscore only when every character after it
    is an ASCII digit (zero digits also count), e.g. ``"linear_1"`` ->
    ``"linear"``, ``"add_"`` -> ``"add"``, ``"x_1_2"`` -> ``"x_1"``; names
    such as ``"a_b"`` or ``"abc1"`` are returned unchanged.
    """
    head, sep, tail = node.name.rpartition("_")
    # restrict to ASCII digits: str.isdigit() would also accept e.g. "²"
    if sep and all(ch in "0123456789" for ch in tail):
        return head
    return node.name
|
|
|
|
|
|
|
|
|
|
|
|
def find_tensor_node(node_list: List[Node]) -> List[Node]:
    """
    Keep, in order, the nodes of ``node_list`` for which ``get_node_shape``
    returns something other than None.
    """
    return [candidate for candidate in node_list if get_node_shape(candidate) is not None]
|
|
|
|
|
|
|
|
|
|
|
|
def find_tensor_shape_node(node_list: List[Node]) -> List[Node]:
    """
    Find tensor nodes and shape-value nodes from a node list.

    A node is kept (in order) when ``get_node_shape`` returns something other
    than None, or when its ``meta['fwd_out']`` is a non-empty list whose first
    element is an ``int`` (presumably a size/shape value — TODO confirm).

    Args:
        node_list (List[Node]): nodes to filter.

    Returns:
        List[Node]: the kept nodes.

    Raises:
        KeyError: if a node without a tensor shape has no ``'fwd_out'`` entry
            in ``node.meta``.
    """
    out = []
    for node in node_list:
        if get_node_shape(node) is not None:
            out.append(node)
            continue
        fwd_out = node.meta['fwd_out']
        # Check the type before calling len(): a non-list value (e.g. None)
        # previously made len() raise TypeError instead of being skipped.
        if isinstance(fwd_out, list) and len(fwd_out) > 0 and isinstance(fwd_out[0], int):
            out.append(node)
    return out
|