from typing import Any, Dict, List, Tuple

from torch.fx.node import Node

from colossalai.logging import get_dist_logger

# fx op kinds that never perform computation
NON_COMPUTE_OP = ["placeholder", "get_attr", "output"]
# node names (as normalized by get_node_name) that never perform computation
NON_COMPUTE_NAME = ["getattr", "eq", "_assert_is_none", "_assert", "finfo", "size"]
logger = get_dist_logger()


class NodeMgr:
    def __init__(self, nodes_list: List[Node]) -> None:
        self._node_list = nodes_list
        self._node_dict = {}
        self._set_node_dict()

    def _set_node_dict(self) -> None:
        """
        build a dict {node_name: node_idx} for fast lookup
        """
        self._node_dict.clear()
        for idx, node in enumerate(self._node_list):
            self._node_dict[node.name] = idx

    def find_node_idx(self, node: Node) -> int:
        """
        find a node's index
        """
        return self._node_dict[node.name]

    def find_node_idx_by_name(self, node_name: str) -> int:
        """
        find a node's index by its name
        """
        return self._node_dict[node_name]

    def get_node_by_idx(self, idx: int) -> Node:
        """
        get a node by its index
        """
        return self._node_list[idx]

    def get_node_slice_by_idx(self, start: int, end: int) -> List[Node]:
        """
        get the nodes with indices in [start, end)
        """
        return self._node_list[start:end]

    def get_node_list(self) -> List[Node]:
        """
        get the full node list
        """
        return self._node_list

    def update_node_list(self, node_list: List[Node]) -> None:
        """
        replace the node list and rebuild the node dict
        """
        self._node_list = node_list
        self._set_node_dict()


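# Example usage (a minimal sketch, not part of the original module): NodeMgr is
# typically built from the node list of a traced fx graph. The toy model below
# is hypothetical and only illustrates the lookup API.
#
#     import torch
#     from torch.fx import symbolic_trace
#
#     traced = symbolic_trace(torch.nn.Linear(4, 4))
#     mgr = NodeMgr(list(traced.graph.nodes))
#     first = mgr.get_node_by_idx(0)                 # the placeholder node
#     assert mgr.find_node_idx(first) == 0
#     assert mgr.find_node_idx_by_name(first.name) == 0

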
def get_logger() -> Any:
    return logger


def flat_list(inputs: Any) -> List:
    """
    flatten nested lists, sets and tuples recursively; dicts contribute their keys
    """
    if not isinstance(inputs, (list, set, tuple)):
        return [inputs]
    res = []
    for i in inputs:
        if isinstance(i, (list, set, tuple)):
            res.extend(flat_list(i))
        elif isinstance(i, dict):
            res.extend(flat_list(list(i.keys())))
        else:
            res.append(i)
    return res


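# Example (illustrative only): nested containers are flattened and dict
# values are dropped, keeping just the keys.
#
#     flat_list([1, (2, 3), {"a": 4}])   # -> [1, 2, 3, "a"]
#     flat_list(5)                       # -> [5]

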
def find_first_tensor_arg(node: Node) -> Node:
    """
    find the first arg of a node that is itself a Node (i.e., a tensor input)
    """
    for arg in node.args:
        if isinstance(arg, Node):
            return arg
    raise RuntimeError(f"no tensor arg found for node {node.name}")


def is_non_compute_node(node: Node) -> bool:
    """
    check whether a node performs no real computation
    """
    if node.op in NON_COMPUTE_OP or get_node_name(node) in NON_COMPUTE_NAME:
        return True
    if "getitem" in node.name:
        # a getitem node that outputs a tensor is a compute node
        if get_node_shape(node) is not None:
            return False
        # a getitem node that indexes or slices is a compute node
        node_args = flat_list(node.args[1:])
        for node_arg in node_args:
            if any(i == str(node_arg) for i in ["None", "Ellipsis"]):
                return False
            if "slice" in str(node_arg):
                return False
        return True
    return False


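# Example (illustrative only): a getitem that extracts one element of a
# shape, e.g. `b = x.size()[0]`, returns an int and is non-compute, while a
# getitem that slices a tensor, e.g. `y = x[:, 0:16]`, traces with slice
# arguments and is treated as compute.

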
def get_node_shape(node: Node) -> Any:
    """
    return the output shape recorded in node.meta["tensor_meta"], or None if
    the node does not output a tensor; for multi-output ops such as split
    and unbind, return the shape of the first output
    """
    if get_node_name(node) in ["split", "unbind"]:
        return node.meta["tensor_meta"][0].shape
    if hasattr(node.meta["tensor_meta"], "shape"):
        return node.meta["tensor_meta"].shape
    return None


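# Note: this assumes shape metadata has already been recorded in
# node.meta["tensor_meta"], e.g. by a shape-propagation pass such as
# torch.fx.passes.shape_prop.ShapeProp. A minimal sketch, reusing the
# hypothetical `traced` module from the NodeMgr example above:
#
#     from torch.fx.passes.shape_prop import ShapeProp
#
#     ShapeProp(traced).propagate(torch.randn(2, 4))  # populates tensor_meta

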
def is_non_memory_node(node: Node) -> bool:
    """
    check whether a node allocates no extra memory for its output
    """
    if "getitem" in node.name:
        return True
    if "output" in node.op:
        return True
    return is_non_compute_node(node)


def is_non_compute_node_except_placeholder(node: Node) -> bool:
    """
    same as is_non_compute_node, but treat placeholder nodes as compute nodes
    """
    if "placeholder" in node.op:
        return False
    return is_non_compute_node(node)


def is_non_compute_node_except_placeholder_output(node: Node) -> bool:
    """
    same as is_non_compute_node_except_placeholder, but also treat output
    nodes as compute nodes
    """
    if "output" in node.op:
        return False
    return is_non_compute_node_except_placeholder(node)


def delete_free_var_from_last_use(user_to_last_uses: Dict) -> None:
    """
    remove placeholder (free variable) nodes from every last-use list,
    so that graph inputs are never deallocated
    """
    for key, value in user_to_last_uses.items():
        # iterate over a copy: removing from a list while iterating over it
        # skips elements
        for n in list(value):
            if n.op == "placeholder":
                user_to_last_uses[key].remove(n)


def find_chunk_all_input_nodes(nodes: List[Node]) -> List:
    """
    find all input nodes of a node list: nodes outside the list whose
    outputs are consumed by nodes inside the list
    """
    input_nodes = []
    for node in nodes:
        for input_node in node._input_nodes.keys():
            if input_node not in nodes and input_node not in input_nodes:
                input_nodes.append(input_node)
    return input_nodes


def find_chunk_compute_input_and_output_nodes(nodes: List[Node]) -> Tuple[List, List]:
    """
    find the compute input and output nodes of a node list:
    input nodes are nodes outside the list used by nodes in the list,
    output nodes are nodes in the list used by nodes outside the list
    """
    input_nodes = []
    output_nodes = []

    # if a node has an input node which is not in the node list,
    # we treat that input node as an input of the chunk region
    for node in nodes:
        for input_node in node._input_nodes.keys():
            if (
                input_node not in nodes
                and input_node not in input_nodes
                and not is_non_compute_node_except_placeholder(input_node)
            ):
                input_nodes.append(input_node)

    # if a node has a user node which is not in the node list,
    # we treat the current node as an output of the chunk region
    for node in nodes:
        for output_node in node.users.keys():
            if (
                output_node not in nodes
                and node not in output_nodes
                and not is_non_compute_node_except_placeholder_output(output_node)
            ):
                output_nodes.append(node)

    return input_nodes, output_nodes


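# Example usage (a minimal sketch): given a candidate chunk region, i.e. a
# contiguous slice of the graph's node list, these helpers recover the
# tensors flowing into and out of the region. `mgr` is the hypothetical
# NodeMgr instance from the example above.
#
#     region = mgr.get_node_slice_by_idx(2, 5)
#     chunk_inputs, chunk_outputs = find_chunk_compute_input_and_output_nodes(region)

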
def get_module_node_name(node: Node) -> str:
    """
    get the lower-cased class name of the module referenced by a
    call_module node, e.g. "linear" for a torch.nn.Linear
    """
    node_targets = node.target.split(".")
    module = node.graph.owning_module
    for i in node_targets:
        module = getattr(module, i)
    # str(module.__class__) looks like "<class 'torch.nn.modules.linear.Linear'>"
    module_name = str(module.__class__).split(".")[-1][:-2]
    module_name = module_name.lower()
    return module_name


def get_node_name(node: Node) -> str:
    """
    get a node's base name, stripping the trailing "_<number>" suffix that
    torch.fx appends to de-duplicate names, e.g. "add_1" -> "add"
    """
    node_name = node.name
    if "_" in node_name:
        for i in range(len(node_name) - 1, -1, -1):
            if node_name[i] == "_":
                node_name = node_name[:i]
                break
            elif node_name[i] in ["1", "2", "3", "4", "5", "6", "7", "8", "9", "0"]:
                continue
            else:
                break
    return node_name


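# Example (illustrative only): only a purely numeric suffix is stripped,
# so names without a trailing counter are left untouched.
#
#     "add_1"      -> "add"
#     "view_12"    -> "view"
#     "layer_norm" -> "layer_norm"   # no trailing digits, unchanged

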
def find_tensor_node(node_list: List[Node]) -> List[Node]:
    """
    collect the nodes in a node list whose output is a tensor
    """
    out = []
    for node in node_list:
        if get_node_shape(node) is not None:
            out.append(node)
    return out


def find_tensor_shape_node(node_list: List[Node]) -> List[Node]:
    """
    collect the nodes in a node list whose output is a tensor or a shape
    (a list of ints, e.g. the result of Tensor.size())
    """
    out = []
    for node in node_list:
        if get_node_shape(node) is not None:
            out.append(node)
        elif (
            # check isinstance before len, so non-sized fwd_out values
            # cannot raise a TypeError
            isinstance(node.meta["fwd_out"], list)
            and len(node.meta["fwd_out"]) > 0
            and isinstance(node.meta["fwd_out"][0], int)
        ):
            out.append(node)
    return out
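

# Example usage (a minimal sketch, reusing the hypothetical `mgr` from above):
# restrict a chunk search to nodes that actually carry tensor or shape data.
#
#     candidates = find_tensor_shape_node(mgr.get_node_list())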