mirror of https://github.com/hpcaitech/ColossalAI
[autockpt] considering parameter and optimizer weights. (#2279)
* [autockpt] make it work. * [autockpt] linearize / merge shape-consistency nodes. * [autockpt] considering parameter and optimizer weights. pull/2281/head^2
parent
b0d21d0c4f
commit
8e8900ff3f
|
@ -35,10 +35,11 @@ class CheckpointSolverBase(ABC):
|
|||
free_memory: float = -1.0,
|
||||
requires_linearize: bool = False,
|
||||
cnode: List[str] = None,
|
||||
optim_multiplier: float = 1.0,
|
||||
):
|
||||
"""CheckpointSolver class will integrate information provided by the components
|
||||
and use an existing solver to find a possible optimal strategies combination for
|
||||
target computing graph.
|
||||
"""``CheckpointSolverBase`` class will integrate information provided by the components
|
||||
and use an existing solver to find a possibly optimal combination of strategies for the target
|
||||
computing graph.
|
||||
|
||||
Existing Solvers:
|
||||
Chen's Greedy solver: https://arxiv.org/abs/1604.06174 (CheckpointSolverChen)
|
||||
|
@ -49,9 +50,11 @@ class CheckpointSolverBase(ABC):
|
|||
free_memory (float): Memory constraint for the solution.
|
||||
requires_linearize (bool): Whether the graph needs to be linearized.
|
||||
cnode (List[str], optional): Common node List, should be the subset of input. Defaults to None.
|
||||
optim_multiplier (float, optional): The multiplier of extra weight storage for the
|
||||
``torch.optim.Optimizer``. Defaults to 1.0.
|
||||
|
||||
Warnings:
|
||||
`MetaInfoProp` should be done before constructing the solver. Meta information of the graph is required.
|
||||
Meta information of the graph is required for any ``CheckpointSolver``.
|
||||
"""
|
||||
# super-dainiu: this graph is a temporary graph which can refer to
|
||||
# the owning module, but we will return another deepcopy of it after
|
||||
|
@ -61,13 +64,14 @@ class CheckpointSolverBase(ABC):
|
|||
_copy_output(graph, self.graph)
|
||||
self.graph.set_codegen(ActivationCheckpointCodeGen())
|
||||
|
||||
# check if `MetaInfoProp` is done
|
||||
# check if has meta information
|
||||
if any(len(node.meta) == 0 for node in self.graph.nodes):
|
||||
raise RuntimeError(
|
||||
"Nodes meta information hasn't been prepared! Please run MetaInfoProp before constructing the solver!")
|
||||
"Nodes meta information hasn't been prepared! Please extract from graph before constructing the solver!"
|
||||
)
|
||||
|
||||
self.free_memory = free_memory
|
||||
self.parameter_size = _get_param_size(self.graph.owning_module)
|
||||
# parameter memory = parameter size + optimizer extra weight storage
|
||||
self.free_memory = free_memory - _get_param_size(self.graph.owning_module) * (optim_multiplier + 1)
|
||||
self.cnode = cnode
|
||||
self.requires_linearize = requires_linearize
|
||||
if self.requires_linearize:
|
||||
|
@ -97,7 +101,7 @@ class CheckpointSolverBase(ABC):
|
|||
the actual 'node' in linearized manner.
|
||||
|
||||
Remarks:
|
||||
Do merge the inplace ops into the previous node.
|
||||
Do merge the inplace ops and shape-consistency ops into the previous node.
|
||||
"""
|
||||
|
||||
# Common nodes are type of nodes that could be seen as attributes and remain
|
||||
|
@ -136,7 +140,7 @@ class CheckpointSolverBase(ABC):
|
|||
"""
|
||||
|
||||
def _is_inplace(n: Node):
|
||||
"""Get the inplace argument from torch.fx.Node
|
||||
"""Get the inplace argument from ``torch.fx.Node``
|
||||
"""
|
||||
inplace = False
|
||||
if n.op == "call_function":
|
||||
|
|
|
@ -19,9 +19,9 @@ class CheckpointSolverChen(CheckpointSolverBase):
|
|||
Note that this algorithm targets memory optimization only, using techniques in appendix A.
|
||||
|
||||
Usage:
|
||||
Assume that we have a `GraphModule`, and we already applied the `MetaInfoProp`
|
||||
Assume that we have a ``GraphModule``, and we have already done the extractions
|
||||
to the graph to retrieve all information needed, then we could use the following
|
||||
code to find a solution using `CheckpointSolverChen`:
|
||||
code to find a solution using ``CheckpointSolverChen``:
|
||||
>>> solver = CheckpointSolverChen(gm.graph)
|
||||
>>> chen_graph = solver.solve()
|
||||
>>> gm.graph = chen_graph # set the graph to a new graph
|
||||
|
@ -74,7 +74,7 @@ class CheckpointSolverChen(CheckpointSolverBase):
|
|||
def grid_search(self) -> Set:
|
||||
"""
|
||||
Search ckpt strategy with b = 0, then run the allocation algorithm again with b = √xy.
|
||||
Grid search over [√2/2 b, √2 b] for ckpt_opt over num_grids as in appendix A.
|
||||
Grid search over [√2/2 b, √2 b] for ``ckpt_opt`` over ``num_grids`` as in appendix A.
|
||||
"""
|
||||
_, b_approx = self.run_chen_greedy(0)
|
||||
b_min, b_max = math.floor(b_approx / math.sqrt(2)), math.ceil(b_approx * math.sqrt(2))
|
||||
|
|
|
@ -23,15 +23,20 @@ __all__ = ['CheckpointSolverRotor']
|
|||
|
||||
class CheckpointSolverRotor(CheckpointSolverBase):
|
||||
|
||||
def __init__(self, graph: Graph, free_memory: float = -1, cnode: List[str] = None, memory_slots: int = 500):
|
||||
def __init__(self,
|
||||
graph: Graph,
|
||||
free_memory: float = -1,
|
||||
cnode: List[str] = None,
|
||||
memory_slots: int = 500,
|
||||
optim_multiplier: float = 1.0):
|
||||
"""This is the simple implementation of dynamic programming algorithm rotor
|
||||
in https://hal.inria.fr/hal-02352969. Some code is adapted from
|
||||
https://gitlab.inria.fr/hiepacs/rotor.
|
||||
|
||||
Usage:
|
||||
Assume that we have a `GraphModule`, and we already applied the `MetaInfoProp`
|
||||
Assume that we have a ``GraphModule``, and we have already done the extractions
|
||||
to the graph to retrieve all information needed, then we could use the following
|
||||
code to find a solution using `CheckpointSolverRotor`:
|
||||
code to find a solution using ``CheckpointSolverRotor``:
|
||||
>>> solver = CheckpointSolverRotor(gm.graph, free_memory=torch.cuda.mem_get_info(device=0)[0])
|
||||
>>> rotor_graph = solver.solve(force_python=True) # otherwise use C solver
|
||||
>>> gm.graph = rotor_graph # set the graph to a new graph
|
||||
|
@ -42,6 +47,8 @@ class CheckpointSolverRotor(CheckpointSolverBase):
|
|||
Use ``torch.cuda.mem_get_info(device=0)[0]`` to estimate the free_memory. Defaults to -1.
|
||||
cnode (List[str], optional): Common node List, should be the subset of input. Defaults to None.
|
||||
memory_slots (int, optional): Number of slots for discretizing memory budget. Defaults to 500.
|
||||
optim_multiplier (float, optional): The multiplier of extra weight storage for the
|
||||
``torch.optim.Optimizer``. Defaults to 1.0.
|
||||
"""
|
||||
super().__init__(graph, free_memory, True, cnode)
|
||||
self.memory_slots = memory_slots
|
||||
|
@ -298,8 +305,8 @@ class CheckpointSolverRotor(CheckpointSolverBase):
|
|||
lhs (int): The left index of the interval to backtrack.
|
||||
rhs (int): The right index of the interval to backtrack.
|
||||
budget (int): The memory budget for processing this interval.
|
||||
cost_table (List[Any]): See `._compute_table()` for definitions
|
||||
back_ptr (List[Any]): See `._compute_table()` for definitions
|
||||
cost_table (List[Any]): See ``._compute_table()`` for definitions
|
||||
back_ptr (List[Any]): See ``._compute_table()`` for definitions
|
||||
|
||||
Raises:
|
||||
ValueError: Can not process the chain.
|
||||
|
@ -340,7 +347,7 @@ class CheckpointSolverRotor(CheckpointSolverBase):
|
|||
|
||||
@staticmethod
|
||||
def _annotate_from_sequence(sequence: Sequence, node_list: List[List[Node]]):
|
||||
"""Annotate the nodes in the node_list with activation checkpoint from the sequence.
|
||||
"""Annotate the nodes in the ``node_list`` with activation checkpoint from the sequence.
|
||||
|
||||
Args:
|
||||
sequence (Sequence): The sequence of executing nodes with activation checkpoint annotations.
|
||||
|
|
Loading…
Reference in New Issue