You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
ColossalAI/colossalai/autochunk/utils.py

253 lines
7.0 KiB

from typing import Any, Dict, List, Tuple, Union

from torch.fx.node import Node

from colossalai.logging import get_dist_logger
# Node ops (``node.op``) that is_non_compute_node always reports as non-compute.
NON_COMPUTE_OP = [
    "placeholder",
    "get_attr",
    "output",
]
# Node names (as returned by get_node_name, i.e. without a trailing "_<digits>")
# that is_non_compute_node always reports as non-compute.
NON_COMPUTE_NAME = [
    "getattr",
    "eq",
    "_assert_is_none",
    "_assert",
    "finfo",
    "size",
]
# Shared module-level logger, handed out by get_logger().
logger = get_dist_logger()
class NodeMgr(object):
    """
    Hold an ordered list of graph nodes together with a name -> position index,
    so nodes can be looked up by position or by name.
    """

    def __init__(self, nodes_list: List[Node]) -> None:
        self._node_list = nodes_list
        self._node_dict = {}
        self._set_node_dict()

    def _set_node_dict(self) -> None:
        """
        Rebuild the {node_name: node_idx} index in place from the current list.
        If two nodes share a name, the later position wins.
        """
        self._node_dict.clear()
        self._node_dict.update((entry.name, position) for position, entry in enumerate(self._node_list))

    def find_node_idx(self, node: Node) -> int:
        """
        Return the position of ``node`` in the list, looked up by its name.
        Raises KeyError if no node with that name is indexed.
        """
        return self._node_dict[node.name]

    def find_node_idx_by_name(self, node_name: str) -> int:
        """
        Return the position of the node called ``node_name``.
        Raises KeyError if no node with that name is indexed.
        """
        return self._node_dict[node_name]

    def get_node_by_idx(self, idx: int) -> Node:
        """
        Return the node stored at position ``idx``.
        """
        return self._node_list[idx]

    def get_node_slice_by_idx(self, start: int, end: int) -> List[Node]:
        """
        Return the nodes at positions ``start`` (inclusive) to ``end`` (exclusive).
        """
        return self._node_list[start:end]

    def get_node_list(self) -> List:
        """
        Return the underlying node list itself (not a copy).
        """
        return self._node_list

    def update_node_list(self, node_list: List) -> None:
        """
        Replace the managed node list and rebuild the name index for it.
        """
        self._node_list = node_list
        self._set_node_dict()
def get_logger() -> Any:
    """
    Return the module-level logger that this module creates at import time
    with ``get_dist_logger()``.
    """
    return logger
def flat_list(inputs: Any) -> List:
    """
    Recursively flatten nested lists, sets and tuples into one flat list.

    A non-container ``inputs`` (including a top-level dict) is wrapped as
    ``[inputs]``. A dict nested inside a container contributes its keys,
    which are flattened in turn.
    """
    sequence_types = (list, set, tuple)
    if not isinstance(inputs, sequence_types):
        return [inputs]
    flattened = []
    for item in inputs:
        if isinstance(item, sequence_types):
            flattened += flat_list(item)
        elif isinstance(item, dict):
            flattened += flat_list(list(item.keys()))
        else:
            flattened.append(item)
    return flattened
def find_first_tensor_arg(node: Node) -> Node:
    """
    Find the first positional argument of ``node`` that is itself a graph node.

    Only ``node.args`` is scanned (keyword arguments are ignored). An argument
    matches when its type is exactly ``type(node)``.

    Args:
        node: the node whose positional arguments are searched.

    Returns:
        The first positional argument whose type is ``type(node)``.

    Raises:
        RuntimeError: if no positional argument has that type.
    """
    node_type = type(node)
    for arg in node.args:
        if type(arg) is node_type:
            return arg
    # name the offending node so the failure is diagnosable from the traceback
    raise RuntimeError(f"no node-typed positional argument found for node '{node.name}'")
def is_non_compute_node(node: Node) -> bool:
    """
    Decide whether ``node`` is treated as a non-compute node.

    True when ``node.op`` is in NON_COMPUTE_OP or its suffix-stripped name is
    in NON_COMPUTE_NAME. A node whose name contains "getitem" is non-compute
    only if it has no tensor shape and none of its index arguments (after
    the first positional arg) prints as "None"/"Ellipsis" or contains "slice".
    Every other node is a compute node.
    """
    if node.op in NON_COMPUTE_OP or get_node_name(node) in NON_COMPUTE_NAME:
        return True
    if "getitem" not in node.name:
        return False
    # a getitem that yields a tensor does real work
    if get_node_shape(node) is not None:
        return False
    for index_arg in flat_list(node.args[1:]):
        arg_text = str(index_arg)
        # None / Ellipsis / slice indices reshape or view data
        if arg_text in ("None", "Ellipsis") or "slice" in arg_text:
            return False
    return True
def get_node_shape(node: Node) -> Any:
    """
    Return the data shape recorded in ``node.meta["tensor_meta"]``.

    For "split"/"unbind" nodes the metadata is indexed and the first entry's
    shape is returned. Otherwise the metadata's ``shape`` attribute is
    returned, or None when it has none. A missing "tensor_meta" key raises
    KeyError.
    """
    if get_node_name(node) in ("split", "unbind"):
        return node.meta["tensor_meta"][0].shape
    return getattr(node.meta["tensor_meta"], "shape", None)
def is_non_memory_node(node: Node) -> bool:
    """
    Return True for nodes whose name contains "getitem", whose op contains
    "output", or which is_non_compute_node reports as non-compute.
    """
    return "getitem" in node.name or "output" in node.op or is_non_compute_node(node)
def is_non_compute_node_except_placeholder(node: Node) -> bool:
    """
    Same as is_non_compute_node, but nodes whose op contains "placeholder"
    are always reported as compute nodes (False).
    """
    return "placeholder" not in node.op and is_non_compute_node(node)
def is_non_compute_node_except_placeholder_output(node: Node) -> bool:
    """
    Same as is_non_compute_node_except_placeholder, but nodes whose op
    contains "output" are also always reported as compute nodes (False).
    """
    return "output" not in node.op and is_non_compute_node_except_placeholder(node)
def delete_free_var_from_last_use(user_to_last_uses: Dict) -> None:
    """
    Remove placeholder nodes from every last-use collection, in place.

    Args:
        user_to_last_uses: maps each key to a mutable collection of nodes
            (objects with an ``op`` attribute). Every entry whose ``op`` is
            ``"placeholder"`` is removed from its collection. The collections
            themselves are mutated, not replaced.
    """
    for last_uses in user_to_last_uses.values():
        # Iterate over a snapshot: removing from the container being iterated
        # skips the element after each removal, so back-to-back placeholders
        # would survive.
        for n in list(last_uses):
            if n.op == "placeholder":
                last_uses.remove(n)
def find_chunk_all_input_nodes(nodes: List[Node]) -> List:
    """
    Find every node outside ``nodes`` that some node in ``nodes`` consumes.

    Unlike find_chunk_compute_input_and_output_nodes, no input is filtered
    out as non-compute.

    Args:
        nodes: the nodes of a chunk region.

    Returns:
        The distinct input nodes that are not themselves in ``nodes``, in the
        order they are first encountered.
    """
    # set lookups keep this linear; membership tests on lists made it quadratic
    region = set(nodes)
    seen = set()
    input_nodes = []
    for node in nodes:
        for input_node in node._input_nodes.keys():
            if input_node not in region and input_node not in seen:
                seen.add(input_node)
                input_nodes.append(input_node)
    return input_nodes
def find_chunk_compute_input_and_output_nodes(nodes: List[Node]) -> Tuple[List, List]:
    """
    Find the compute inputs and the externally used outputs of a chunk region.

    Args:
        nodes: the nodes of a chunk region.

    Returns:
        A tuple ``(input_nodes, output_nodes)``:

        * ``input_nodes``: distinct nodes outside ``nodes`` that a node in
          ``nodes`` consumes. Inputs for which
          ``is_non_compute_node_except_placeholder`` is true are skipped, so
          placeholders are kept.
        * ``output_nodes``: distinct nodes *inside* ``nodes`` that have at
          least one user outside ``nodes`` for which
          ``is_non_compute_node_except_placeholder_output`` is false.

        Both lists keep the order in which nodes are first found.
    """
    # set lookups replace O(n) list membership tests; predicate call order is unchanged
    region = set(nodes)

    input_nodes = []
    seen_inputs = set()
    # an input produced outside the region becomes an input of the chunk region
    for node in nodes:
        for input_node in node._input_nodes.keys():
            if (
                input_node not in region
                and input_node not in seen_inputs
                and not is_non_compute_node_except_placeholder(input_node)
            ):
                seen_inputs.add(input_node)
                input_nodes.append(input_node)

    output_nodes = []
    seen_outputs = set()
    # a region node with a compute user outside the region is an output of the chunk region
    for node in nodes:
        if node in seen_outputs:
            continue
        for output_node in node.users.keys():
            if output_node not in region and not is_non_compute_node_except_placeholder_output(output_node):
                seen_outputs.add(node)
                output_nodes.append(node)
                # node is recorded; its remaining users cannot change the result
                break
    return input_nodes, output_nodes
def get_module_node_name(node: Node) -> str:
    """
    Get the lower-cased class name of the module that ``node.target`` refers to.

    ``node.target`` is treated as a dotted attribute path (e.g. ``"seq.0"``),
    resolved starting from ``node.graph.owning_module``.

    Args:
        node: a node whose ``target`` is a dotted submodule path.

    Returns:
        The resolved module's class name in lower case, e.g. ``"linear"``.
    """
    module = node.graph.owning_module
    for attr_name in node.target.split("."):
        module = getattr(module, attr_name)
    # use the class's __name__ directly rather than slicing the
    # "<class 'pkg.Name'>" repr, which breaks for undotted or custom reprs
    return module.__class__.__name__.lower()
def get_node_name(node: Node) -> str:
    """
    Return ``node.name`` with a trailing ``_<digits>`` suffix removed.

    Only the final underscore is considered, and only when it is followed
    solely by ASCII digits (possibly none). Examples: ``"add_1"`` -> ``"add"``,
    ``"conv_"`` -> ``"conv"``, ``"linear_weight"`` and ``"add1"`` are unchanged.
    """
    full_name = node.name
    without_digits = full_name.rstrip("0123456789")
    if without_digits.endswith("_"):
        return without_digits[:-1]
    return full_name
def find_tensor_node(node_list: List[Node]) -> List[Node]:
    """
    Keep only the nodes of ``node_list`` for which get_node_shape returns a
    non-None value, preserving their order.
    """
    return [candidate for candidate in node_list if get_node_shape(candidate) is not None]
def find_tensor_shape_node(node_list: List[Node]) -> List[Node]:
    """
    Find tensor nodes and shape-like nodes in a node list.

    A node is kept when get_node_shape returns a non-None value, or when
    ``node.meta["fwd_out"]`` is a non-empty list whose first element is an
    ``int`` (presumably a size/shape value; TODO confirm).

    Args:
        node_list: nodes to filter. Each must carry "tensor_meta" and
            "fwd_out" entries in ``meta`` (a missing key raises KeyError).

    Returns:
        The matching nodes, in input order.
    """
    out = []
    for node in node_list:
        if get_node_shape(node) is not None:
            out.append(node)
            continue
        fwd_out = node.meta["fwd_out"]
        # check the type before len(): len() raises TypeError on unsized values such as None
        if isinstance(fwd_out, list) and len(fwd_out) > 0 and isinstance(fwd_out[0], int):
            out.append(node)
    return out