From 8e8900ff3f7894ecdc6c6aa5d705ebe2eb983c5c Mon Sep 17 00:00:00 2001
From: Super Daniel <78588128+super-dainiu@users.noreply.github.com>
Date: Tue, 3 Jan 2023 16:55:49 +0800
Subject: [PATCH] [autockpt] considering parameter and optimizer weights. (#2279)

* [autockpt] make it work.

* [autockpt] linearize / merge shape-consistency nodes.

* [autockpt] considering parameter and optimizer weights.
---
 .../checkpoint/ckpt_solver_base.py  | 24 +++++++++++--------
 .../checkpoint/ckpt_solver_chen.py  |  6 ++---
 .../checkpoint/ckpt_solver_rotor.py | 19 ++++++++++-----
 3 files changed, 30 insertions(+), 19 deletions(-)

diff --git a/colossalai/auto_parallel/checkpoint/ckpt_solver_base.py b/colossalai/auto_parallel/checkpoint/ckpt_solver_base.py
index ecccef8d7..b388d00ac 100644
--- a/colossalai/auto_parallel/checkpoint/ckpt_solver_base.py
+++ b/colossalai/auto_parallel/checkpoint/ckpt_solver_base.py
@@ -35,10 +35,11 @@ class CheckpointSolverBase(ABC):
         free_memory: float = -1.0,
         requires_linearize: bool = False,
         cnode: List[str] = None,
+        optim_multiplier: float = 1.0,
     ):
-        """CheckpointSolver class will integrate information provided by the components
-        and use an existing solver to find a possible optimal strategies combination for
-        target computing graph.
+        """``CheckpointSolverBase`` class will integrate information provided by the components
+        and use an existing solver to find a possible optimal strategies combination for target
+        computing graph.
 
         Existing Solvers:
         Chen's Greedy solver: https://arxiv.org/abs/1604.06174  (CheckpointSolverChen)
@@ -49,9 +50,11 @@ class CheckpointSolverBase(ABC):
             free_memory (float): Memory constraint for the solution.
             requires_linearize (bool): Whether the graph needs to be linearized.
             cnode (List[str], optional): Common node List, should be the subset of input. Default to None.
+            optim_multiplier (float, optional): The multiplier of extra weight storage for the
+            ``torch.optim.Optimizer``. Default to 1.0.
 
         Warnings:
-            `MetaInfoProp` should be done before constructing the solver. Meta information of the graph is required.
+            Meta information of the graph is required for any ``CheckpointSolver``.
         """
         # super-dainiu: this graph is a temporary graph which can refer to
         # the owning module, but we will return another deepcopy of it after
@@ -61,13 +64,14 @@ class CheckpointSolverBase(ABC):
         _copy_output(graph, self.graph)
         self.graph.set_codegen(ActivationCheckpointCodeGen())
 
-        # check if `MetaInfoProp` is done
+        # check if has meta information
         if any(len(node.meta) == 0 for node in self.graph.nodes):
             raise RuntimeError(
-                "Nodes meta information hasn't been prepared! Please run MetaInfoProp before constructing the solver!")
+                "Nodes meta information hasn't been prepared! Please extract from graph before constructing the solver!"
+            )
 
-        self.free_memory = free_memory
-        self.parameter_size = _get_param_size(self.graph.owning_module)
+        # parameter memory = parameter size + optimizer extra weight storage
+        self.free_memory = free_memory - _get_param_size(self.graph.owning_module) * (optim_multiplier + 1)
         self.cnode = cnode
         self.requires_linearize = requires_linearize
         if self.requires_linearize:
@@ -97,7 +101,7 @@ class CheckpointSolverBase(ABC):
             the actual 'node' in linearized manner.
 
         Remarks:
-            Do merge the inplace ops into the previous node.
+            Do merge the inplace ops and shape-consistency ops into the previous node.
         """
 
         # Common nodes are type of nodes that could be seen as attributes and remain
@@ -136,7 +140,7 @@ class CheckpointSolverBase(ABC):
         """
 
         def _is_inplace(n: Node):
-            """Get the inplace argument from torch.fx.Node
+            """Get the inplace argument from ``torch.fx.Node``
             """
             inplace = False
             if n.op == "call_function":
diff --git a/colossalai/auto_parallel/checkpoint/ckpt_solver_chen.py b/colossalai/auto_parallel/checkpoint/ckpt_solver_chen.py
index 58878253e..19b2ef598 100644
--- a/colossalai/auto_parallel/checkpoint/ckpt_solver_chen.py
+++ b/colossalai/auto_parallel/checkpoint/ckpt_solver_chen.py
@@ -19,9 +19,9 @@ class CheckpointSolverChen(CheckpointSolverBase):
     Note that this algorithm targets at memory optimization only, using techniques in appendix A.
 
     Usage:
-        Assume that we have a `GraphModule`, and we already applied the `MetaInfoProp`
+        Assume that we have a ``GraphModule``, and we have already done the extractions
         to the graph to retrieve all information needed, then we could use the following
-        code to find a solution using `CheckpointSolverChen`:
+        code to find a solution using ``CheckpointSolverChen``:
         >>> solver = CheckpointSolverChen(gm.graph)
         >>> chen_graph = solver.solve()
         >>> gm.graph = chen_graph    # set the graph to a new graph
@@ -74,7 +74,7 @@ class CheckpointSolverChen(CheckpointSolverBase):
     def grid_search(self) -> Set:
         """
         Search ckpt strategy with b = 0, then run the allocation algorithm again with b = √xy.
-        Grid search over [√2/2 b, √2 b] for ckpt_opt over num_grids as in appendix A.
+        Grid search over [√2/2 b, √2 b] for ``ckpt_opt`` over ``num_grids`` as in appendix A.
         """
         _, b_approx = self.run_chen_greedy(0)
         b_min, b_max = math.floor(b_approx / math.sqrt(2)), math.ceil(b_approx * math.sqrt(2))
diff --git a/colossalai/auto_parallel/checkpoint/ckpt_solver_rotor.py b/colossalai/auto_parallel/checkpoint/ckpt_solver_rotor.py
index cd5b70d11..5cc57fca0 100644
--- a/colossalai/auto_parallel/checkpoint/ckpt_solver_rotor.py
+++ b/colossalai/auto_parallel/checkpoint/ckpt_solver_rotor.py
@@ -23,15 +23,20 @@ __all__ = ['CheckpointSolverRotor']
 
 class CheckpointSolverRotor(CheckpointSolverBase):
 
-    def __init__(self, graph: Graph, free_memory: float = -1, cnode: List[str] = None, memory_slots: int = 500):
+    def __init__(self,
+                 graph: Graph,
+                 free_memory: float = -1,
+                 cnode: List[str] = None,
+                 memory_slots: int = 500,
+                 optim_multiplier: float = 1.0):
         """This is the simple implementation of dynamic programming algorithm rotor
         in https://hal.inria.fr/hal-02352969. Some code are adapted from
         https://gitlab.inria.fr/hiepacs/rotor.
 
         Usage:
-            Assume that we have a `GraphModule`, and we already applied the `MetaInfoProp`
+            Assume that we have a ``GraphModule``, and we have already done the extractions
             to the graph to retrieve all information needed, then we could use the following
-            code to find a solution using `CheckpointSolverRotor`:
+            code to find a solution using ``CheckpointSolverRotor``:
             >>> solver = CheckpointSolverRotor(gm.graph, free_memory=torch.cuda.mem_get_info(device=0)[0])
             >>> rotor_graph = solver.solve(force_python=True)   # otherwise use C solver
             >>> gm.graph = rotor_graph    # set the graph to a new graph
@@ -42,6 +47,8 @@ class CheckpointSolverRotor(CheckpointSolverBase):
                 Use ``torch.cuda.mem_get_info(device=0)[0]`` to estimate the free_memory. Defaults to -1.
             cnode (List[str], optional): Common node List, should be the subset of input. Defaults to None.
             memory_slots (int, optional): Number of slots for discretizing memory budget. Defaults to 500.
+            optim_multiplier (float, optional): The multiplier of extra weight storage for the
+            ``torch.optim.Optimizer``. Default to 1.0.
         """
         super().__init__(graph, free_memory, True, cnode)
         self.memory_slots = memory_slots
@@ -298,8 +305,8 @@ class CheckpointSolverRotor(CheckpointSolverBase):
             lhs (int): The left index of the interval to backtrack.
             rhs (int): The right index of the interval to backtrack.
             budget (int): The memory budget for processing this interval.
-            cost_table (List[Any]): See `._compute_table()` for definitions
-            back_ptr (List[Any]): See `._compute_table()` for definitions
+            cost_table (List[Any]): See ``._compute_table()`` for definitions
+            back_ptr (List[Any]): See ``._compute_table()`` for definitions
 
         Raises:
             ValueError: Can not process the chain.
@@ -340,7 +347,7 @@ class CheckpointSolverRotor(CheckpointSolverBase):
 
     @staticmethod
     def _annotate_from_sequence(sequence: Sequence, node_list: List[List[Node]]):
-        """Annotate the nodes in the node_list with activation checkpoint from the sequence.
+        """Annotate the nodes in the ``node_list`` with activation checkpoint from the sequence.
 
         Args:
             sequence (Sequence): The sequence of executing nodes with activation checkpoint annotations.
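
Illustration (not part of the patch): a minimal sketch of the memory accounting this change introduces. The helper name ``solver_budget`` below is hypothetical, not a Colossal-AI API; its arithmetic simply mirrors the new line ``self.free_memory = free_memory - _get_param_size(self.graph.owning_module) * (optim_multiplier + 1)``, and ``optim_multiplier=2.0`` is an assumed value approximating an Adam-style optimizer that keeps two extra buffers per parameter.

    # Minimal sketch: reserve parameter storage plus `optim_multiplier` copies of
    # optimizer state before solving for activation checkpoints.
    from torch import nn

    def solver_budget(free_memory: float, model: nn.Module, optim_multiplier: float = 1.0) -> float:
        """Hypothetical helper: memory left for activations after reserving weights."""
        param_size = sum(p.numel() * p.element_size() for p in model.parameters())
        return free_memory - param_size * (optim_multiplier + 1)

    model = nn.Linear(1024, 1024)    # ~4 MiB of fp32 parameters
    # Adam keeps exp_avg and exp_avg_sq, i.e. roughly two extra copies per parameter.
    budget = solver_budget(2 * 1024**3, model, optim_multiplier=2.0)
    print(f"memory left for activations: {budget / 1024**2:.1f} MiB")

Under the same reading of the patch, ``CheckpointSolverRotor`` now also accepts ``optim_multiplier``, so a caller training with Adam would presumably construct it as ``CheckpointSolverRotor(gm.graph, free_memory=..., optim_multiplier=2.0)``.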