[autockpt] considering parameter and optimizer weights. (#2279)

* [autockpt] make it work.

* [autockpt] linearize / merge shape-consistency nodes.

* [autockpt] considering parameter and optimizer weights.
Super Daniel 2023-01-03 16:55:49 +08:00 committed by GitHub
parent b0d21d0c4f
commit 8e8900ff3f
3 changed files with 30 additions and 19 deletions


@@ -35,10 +35,11 @@ class CheckpointSolverBase(ABC):
         free_memory: float = -1.0,
         requires_linearize: bool = False,
         cnode: List[str] = None,
+        optim_multiplier: float = 1.0,
     ):
-        """CheckpointSolver class will integrate information provided by the components
-        and use an existing solver to find a possible optimal strategies combination for
-        target computing graph.
+        """``CheckpointSolverBase`` class will integrate information provided by the components
+        and use an existing solver to find a possible optimal strategies combination for target
+        computing graph.
 
         Existing Solvers:
             Chen's Greedy solver: https://arxiv.org/abs/1604.06174 (CheckpointSolverChen)
@@ -49,9 +50,11 @@ class CheckpointSolverBase(ABC):
             free_memory (float): Memory constraint for the solution.
             requires_linearize (bool): Whether the graph needs to be linearized.
             cnode (List[str], optional): Common node List, should be the subset of input. Default to None.
+            optim_multiplier (float, optional): The multiplier of extra weight storage for the
+                ``torch.optim.Optimizer``. Default to 1.0.
 
         Warnings:
-            `MetaInfoProp` should be done before constructing the solver. Meta information of the graph is required.
+            Meta information of the graph is required for any ``CheckpointSolver``.
         """
         # super-dainiu: this graph is a temporary graph which can refer to
         # the owning module, but we will return another deepcopy of it after
@@ -61,13 +64,14 @@ class CheckpointSolverBase(ABC):
         _copy_output(graph, self.graph)
         self.graph.set_codegen(ActivationCheckpointCodeGen())
 
-        # check if `MetaInfoProp` is done
+        # check if has meta information
         if any(len(node.meta) == 0 for node in self.graph.nodes):
             raise RuntimeError(
-                "Nodes meta information hasn't been prepared! Please run MetaInfoProp before constructing the solver!")
+                "Nodes meta information hasn't been prepared! Please extract from graph before constructing the solver!"
+            )
 
-        self.free_memory = free_memory
-        self.parameter_size = _get_param_size(self.graph.owning_module)
+        # parameter memory = parameter size + optimizer extra weight storage
+        self.free_memory = free_memory - _get_param_size(self.graph.owning_module) * (optim_multiplier + 1)
         self.cnode = cnode
         self.requires_linearize = requires_linearize
         if self.requires_linearize:
@@ -97,7 +101,7 @@ class CheckpointSolverBase(ABC):
         the actual 'node' in linearized manner.
 
         Remarks:
-            Do merge the inplace ops into the previous node.
+            Do merge the inplace ops and shape-consistency ops into the previous node.
         """
 
         # Common nodes are type of nodes that could be seen as attributes and remain
@@ -136,7 +140,7 @@ class CheckpointSolverBase(ABC):
         """
 
         def _is_inplace(n: Node):
-            """Get the inplace argument from torch.fx.Node
+            """Get the inplace argument from ``torch.fx.Node``
             """
             inplace = False
             if n.op == "call_function":
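The `_is_inplace` helper shown above reads the inplace flag off an fx node. A self-contained sketch of that pattern, using a stand-in node class since the real `torch.fx.Node` is not constructed by hand (assumption: `call_function` nodes expose an `inplace` kwarg, as e.g. `F.relu(x, inplace=True)` does):

```python
class FakeNode:
    """Stand-in for torch.fx.Node, illustrative only."""

    def __init__(self, op: str, kwargs: dict = None):
        self.op = op
        self.kwargs = kwargs or {}

def is_inplace(n: "FakeNode") -> bool:
    # For function-call nodes, the inplace flag travels as a keyword
    # argument; absent means out-of-place.
    inplace = False
    if n.op == "call_function":
        inplace = n.kwargs.get("inplace", False)
    return inplace

print(is_inplace(FakeNode("call_function", {"inplace": True})))    # True
print(is_inplace(FakeNode("call_function")))                       # False
```

Merging such nodes into their predecessor matters because an inplace op shares storage with its input, so it cannot sit on a checkpoint boundary by itself.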


@@ -19,9 +19,9 @@ class CheckpointSolverChen(CheckpointSolverBase):
     Note that this algorithm targets at memory optimization only, using techniques in appendix A.
 
     Usage:
-        Assume that we have a `GraphModule`, and we already applied the `MetaInfoProp`
+        Assume that we have a ``GraphModule``, and we have already done the extractions
         to the graph to retrieve all information needed, then we could use the following
-        code to find a solution using `CheckpointSolverChen`:
+        code to find a solution using ``CheckpointSolverChen``:
         >>> solver = CheckpointSolverChen(gm.graph)
         >>> chen_graph = solver.solve()
         >>> gm.graph = chen_graph    # set the graph to a new graph
@@ -74,7 +74,7 @@ class CheckpointSolverChen(CheckpointSolverBase):
     def grid_search(self) -> Set:
         """
         Search ckpt strategy with b = 0, then run the allocation algorithm again with b = √xy.
-        Grid search over [√2/2 b, √2 b] for ckpt_opt over num_grids as in appendix A.
+        Grid search over [√2/2 b, √2 b] for ``ckpt_opt`` over ``num_grids`` as in appendix A.
         """
         _, b_approx = self.run_chen_greedy(0)
         b_min, b_max = math.floor(b_approx / math.sqrt(2)), math.ceil(b_approx * math.sqrt(2))
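The last two lines of the hunk show how the search interval is derived: one greedy pass yields an approximate bucket size `b`, and the grid search then scans `[b/√2, √2·b]`. A minimal sketch of that interval construction (the `grid_points` helper and its `num_grids` default are illustrative, not the library's API):

```python
import math

def search_interval(b_approx: float):
    # Same expression as in the diff: floor(b/sqrt(2)) .. ceil(b*sqrt(2)).
    b_min = math.floor(b_approx / math.sqrt(2))
    b_max = math.ceil(b_approx * math.sqrt(2))
    return b_min, b_max

def grid_points(b_min: int, b_max: int, num_grids: int = 6):
    # Evenly spaced budgets to retry the greedy allocator on.
    step = (b_max - b_min) / num_grids
    return [b_min + i * step for i in range(num_grids + 1)]

print(search_interval(100))    # (70, 142)
```

Each candidate budget is fed back into the greedy allocator, and the cheapest feasible strategy wins, following appendix A of the Chen et al. paper linked above.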


@@ -23,15 +23,20 @@ __all__ = ['CheckpointSolverRotor']
 
 class CheckpointSolverRotor(CheckpointSolverBase):
 
-    def __init__(self, graph: Graph, free_memory: float = -1, cnode: List[str] = None, memory_slots: int = 500):
+    def __init__(self,
+                 graph: Graph,
+                 free_memory: float = -1,
+                 cnode: List[str] = None,
+                 memory_slots: int = 500,
+                 optim_multiplier: float = 1.0):
         """This is the simple implementation of dynamic programming algorithm rotor
         in https://hal.inria.fr/hal-02352969. Some code are adapted from
         https://gitlab.inria.fr/hiepacs/rotor.
 
         Usage:
-            Assume that we have a `GraphModule`, and we already applied the `MetaInfoProp`
+            Assume that we have a ``GraphModule``, and we have already done the extractions
             to the graph to retrieve all information needed, then we could use the following
-            code to find a solution using `CheckpointSolverRotor`:
+            code to find a solution using ``CheckpointSolverRotor``:
         >>> solver = CheckpointSolverRotor(gm.graph, free_memory=torch.cuda.mem_get_info(device=0)[0])
         >>> rotor_graph = solver.solve(force_python=True)    # otherwise use C solver
         >>> gm.graph = rotor_graph    # set the graph to a new graph
@@ -42,6 +47,8 @@ class CheckpointSolverRotor(CheckpointSolverBase):
                 Use ``torch.cuda.mem_get_info(device=0)[0]`` to estimate the free_memory. Defaults to -1.
             cnode (List[str], optional): Common node List, should be the subset of input. Defaults to None.
             memory_slots (int, optional): Number of slots for discretizing memory budget. Defaults to 500.
+            optim_multiplier (float, optional): The multiplier of extra weight storage for the
+                ``torch.optim.Optimizer``. Default to 1.0.
         """
         super().__init__(graph, free_memory, True, cnode)
         self.memory_slots = memory_slots
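The `memory_slots` parameter documented above controls how finely the rotor dynamic program discretizes the continuous memory budget: its cost table is indexed by whole slots rather than bytes. A sketch of that discretization (illustrative names, not the library's internals):

```python
def slot_size(budget: float, memory_slots: int = 500) -> float:
    """Size in bytes of one discrete memory unit used by the DP table."""
    return budget / memory_slots

def to_slots(size: float, unit: float) -> int:
    """Round a tensor size up to whole slots, so the DP never
    underestimates memory use."""
    return -(-int(size) // int(unit)) if unit >= 1 else int(size)

unit = slot_size(1000.0, memory_slots=500)
print(unit)            # 2.0
print(to_slots(5, unit))    # 3 slots (5 bytes rounded up at 2 bytes/slot)
```

More slots make the solution tighter but grow the DP table, which is why the default of 500 trades precision for solver time.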
@@ -298,8 +305,8 @@ class CheckpointSolverRotor(CheckpointSolverBase):
             lhs (int): The left index of the interval to backtrack.
             rhs (int): The right index of the interval to backtrack.
             budget (int): The memory budget for processing this interval.
-            cost_table (List[Any]): See `._compute_table()` for definitions
-            back_ptr (List[Any]): See `._compute_table()` for definitions
+            cost_table (List[Any]): See ``._compute_table()`` for definitions
+            back_ptr (List[Any]): See ``._compute_table()`` for definitions
 
         Raises:
             ValueError: Can not process the chain.
@@ -340,7 +347,7 @@ class CheckpointSolverRotor(CheckpointSolverBase):
 
     @staticmethod
     def _annotate_from_sequence(sequence: Sequence, node_list: List[List[Node]]):
-        """Annotate the nodes in the node_list with activation checkpoint from the sequence.
+        """Annotate the nodes in the ``node_list`` with activation checkpoint from the sequence.
 
         Args:
             sequence (Sequence): The sequence of executing nodes with activation checkpoint annotations.