from abc import ABC, abstractmethod
from copy import deepcopy
from typing import Any, Dict, List, Optional

import torch
from torch.fx import Graph, Node

from colossalai.auto_parallel.passes.runtime_apply_pass import (
    runtime_apply,
    runtime_apply_for_iterable_object,
    runtime_comm_spec_apply,
)
from colossalai.fx.codegen.activation_checkpoint_codegen import ActivationCheckpointCodeGen

__all__ = ['CheckpointSolverBase']


def _copy_output(src: Graph, dst: Graph):
    """Copy the meta information of the output node from ``src`` to ``dst``.

    Used right after deep-copying a graph so that the copy's output node carries
    the same ``meta`` as the original. Nodes are paired positionally, so ``dst``
    is assumed to have the same node order as ``src``.
    """
    for n_src, n_dst in zip(src.nodes, dst.nodes):
        if n_src.op == 'output':
            n_dst.meta = n_src.meta


def _get_param_size(module: torch.nn.Module) -> int:
    """Return the total size, in bytes, of all parameters of ``module``."""
    # ``Tensor.element_size()`` depends only on the dtype, so there is no need to
    # allocate a temporary tensor per parameter to look it up.
    return sum(p.numel() * p.element_size() for p in module.parameters())


class CheckpointSolverBase(ABC):
    """Abstract base class for activation-checkpoint solvers working on a ``torch.fx`` graph."""

    def __init__(
        self,
        graph: Graph,
        free_memory: float = -1.0,
        requires_linearize: bool = False,
        cnode: Optional[List[str]] = None,
    ):
        """The CheckpointSolver class integrates the information provided by the components
        and uses an existing solver to find a possibly optimal combination of strategies for
        the target computing graph.

        Existing Solvers:
            Chen's Greedy solver: https://arxiv.org/abs/1604.06174  (CheckpointSolverChen)
            Rotor solver: https://hal.inria.fr/hal-02352969  (CheckpointSolverRotor)

        Args:
            graph (Graph): The computing graph to be optimized.
            free_memory (float): Memory constraint for the solution.
            requires_linearize (bool): Whether the graph needs to be linearized.
            cnode (List[str], optional): Names of common nodes; must be a subset of the
                graph's inputs (placeholders). The list is copied, so the caller's list is
                never modified. Defaults to None.

        Raises:
            RuntimeError: If any node of the graph has empty ``meta``.

        Warnings:
            `MetaInfoProp` should be done before constructing the solver. Meta information of the graph is required.
        """
        # super-dainiu: this graph is a temporary graph which can refer to
        # the owning module, but we will return another deepcopy of it after
        # the solver is executed.
        self.graph = deepcopy(graph)
        self.graph.owning_module = graph.owning_module
        _copy_output(graph, self.graph)
        self.graph.set_codegen(ActivationCheckpointCodeGen())

        # check if `MetaInfoProp` is done
        if any(len(node.meta) == 0 for node in self.graph.nodes):
            raise RuntimeError(
                "Nodes meta information hasn't been prepared! Please run MetaInfoProp before constructing the solver!")

        self.free_memory = free_memory
        # total bytes of the owning module's parameters
        self.parameter_size = _get_param_size(self.graph.owning_module)
        # Copy the caller's list: `_linearize_graph` appends propagated common-node
        # names to `self.cnode`. Without the copy, reusing the same list for another
        # solver would fail the "is an input of the model" check.
        self.cnode = list(cnode) if cnode is not None else None
        self.requires_linearize = requires_linearize
        if self.requires_linearize:
            self.node_list = self._linearize_graph()
        else:
            self.node_list = self.get_node_list()

    @abstractmethod
    def solve(self):
        """Solve the checkpointing problem and return the solution."""

    def get_node_list(self) -> List[List[Node]]:
        """Return the node list without linearization: each graph node forms its own region."""
        return [[node] for node in self.graph.nodes]

    def _linearize_graph(self) -> List[List[Node]]:
        """Linearize ``self.graph`` into a sequence of regions.

        Returns:
            List[List[Node]]: List of lists; each inner list of Nodes represents one
            'node' of the linearized graph.

        Raises:
            ValueError: If a name in ``self.cnode`` does not belong to any graph node.
            AssertionError: If a name in ``self.cnode`` is not a placeholder (input).

        Remarks:
            In-place ops are merged into the region of the preceding node.
        """
        # Common nodes are nodes that can be seen as attributes: they remain unchanged
        # throughout the whole model and are used several times by different blocks of
        # the model, which makes it hard to linearize the graph when we encounter them.
        # We let users annotate some of the inputs as common nodes (e.g. the attention
        # mask), and the ops listed below can also be seen as common nodes. With common
        # node propagation we can find some of the "real" common nodes (e.g. the real
        # attention mask used in BERT and GPT). The rule is simple: a node whose parents
        # are all common nodes, or whose op belongs to the operations below, becomes a
        # newly born common node.

        # List of target names that could be seen as common nodes
        common_ops = ["getattr", "getitem", "size"]

        def _is_cop(target: Any) -> bool:
            """Check whether an op target could be seen as a common node."""
            if isinstance(target, str):
                return target in common_ops
            # Not every callable target is guaranteed to expose ``__name__``; treat
            # such targets as non-common instead of raising AttributeError.
            return getattr(target, "__name__", None) in common_ops

        def _is_inplace(n: Node):
            """Get the ``inplace`` flag of a call_function kwarg or call_module submodule (False otherwise)."""
            inplace = False
            if n.op == "call_function":
                inplace = n.kwargs.get("inplace", False)
            elif n.op == "call_module":
                inplace = getattr(n.graph.owning_module.get_submodule(n.target), "inplace", False)
            return inplace

        def _is_shape_consistency(n: Node) -> bool:
            """Check if this node is a shape-consistency node
            (i.e. ``runtime_apply``, ``runtime_apply_for_iterable_object`` or ``runtime_comm_spec_apply``).
            """
            return n.target in [runtime_apply, runtime_apply_for_iterable_object, runtime_comm_spec_apply]

        def _is_sink(n: Node, deps: Dict[Node, int]) -> bool:
            """Check whether the current region can be closed after ``n``.

            True when every tracked dependency count has dropped to zero and none of
            ``n``'s users is in-place or shape-consistency; otherwise the region keeps
            growing so those users end up merged with ``n``.
            """
            return (not sum(deps.values()) and not any(map(_is_inplace, n.users))
                    and not any(map(_is_shape_consistency, n.users)))

        # make sure that every item in cnode names an input (placeholder) of the graph
        if self.cnode:
            nodes_by_name = {node.name: node for node in self.graph.nodes}
            for name in self.cnode:
                match = nodes_by_name.get(name)
                if match is None:
                    raise ValueError(f"Common node name {name} not in graph.")
                assert match.op == "placeholder", f"Common node {name} is not an input of the model."
        else:
            self.cnode = []

        # deps[node] = number of not-yet-processed non-output users of `node`
        deps: Dict[Node, int] = {}
        node_list = []
        region = []

        for n in self.graph.nodes:
            if n.op == "placeholder" or n.op == "output":
                continue

            for n_par in n.all_input_nodes:
                if n_par.op != "placeholder" and n_par.name not in self.cnode:
                    deps[n_par] -= 1
            region.append(n)

            # if the node frees all dependencies in the graph, we can begin a new region
            if _is_sink(n, deps):
                node_list.append(region)
                region = []

            # propagate the common node attribute if possible
            if all(node.name in self.cnode for node in n.all_input_nodes) or _is_cop(n.target):
                self.cnode.append(n.name)
            else:
                deps[n] = sum(1 for user in n.users if user.op != "output")

        return node_list