|
|
|
@ -35,10 +35,11 @@ class CheckpointSolverBase(ABC):
|
|
|
|
|
free_memory: float = -1.0, |
|
|
|
|
requires_linearize: bool = False, |
|
|
|
|
cnode: List[str] = None, |
|
|
|
|
optim_multiplier: float = 1.0, |
|
|
|
|
): |
|
|
|
|
"""CheckpointSolver class will integrate information provided by the components |
|
|
|
|
and use an existing solver to find a possible optimal strategies combination for |
|
|
|
|
target computing graph. |
|
|
|
|
"""``CheckpointSolverBase`` class will integrate information provided by the components |
|
|
|
|
and use an existing solver to find a possibly optimal combination of strategies for the target |
|
|
|
|
computing graph. |
|
|
|
|
|
|
|
|
|
Existing Solvers: |
|
|
|
|
Chen's Greedy solver: https://arxiv.org/abs/1604.06174 (CheckpointSolverChen) |
|
|
|
@ -49,9 +50,11 @@ class CheckpointSolverBase(ABC):
|
|
|
|
|
free_memory (float): Memory constraint for the solution. |
|
|
|
|
requires_linearize (bool): Whether the graph needs to be linearized. |
|
|
|
|
cnode (List[str], optional): List of common nodes; should be a subset of the input. Defaults to None. |
|
|
|
|
optim_multiplier (float, optional): The multiplier of extra weight storage for the |
|
|
|
|
``torch.optim.Optimizer``. Defaults to 1.0. |
|
|
|
|
|
|
|
|
|
Warnings: |
|
|
|
|
`MetaInfoProp` should be done before constructing the solver. Meta information of the graph is required. |
|
|
|
|
Meta information of the graph is required for any ``CheckpointSolver``. |
|
|
|
|
""" |
|
|
|
|
# super-dainiu: this graph is a temporary graph which can refer to |
|
|
|
|
# the owning module, but we will return another deepcopy of it after |
|
|
|
@ -61,13 +64,14 @@ class CheckpointSolverBase(ABC):
|
|
|
|
|
_copy_output(graph, self.graph) |
|
|
|
|
self.graph.set_codegen(ActivationCheckpointCodeGen()) |
|
|
|
|
|
|
|
|
|
# check if `MetaInfoProp` is done |
|
|
|
|
# check whether every node has meta information |
|
|
|
|
if any(len(node.meta) == 0 for node in self.graph.nodes): |
|
|
|
|
raise RuntimeError( |
|
|
|
|
"Nodes meta information hasn't been prepared! Please run MetaInfoProp before constructing the solver!") |
|
|
|
|
"Nodes meta information hasn't been prepared! Please extract from graph before constructing the solver!" |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
self.free_memory = free_memory |
|
|
|
|
self.parameter_size = _get_param_size(self.graph.owning_module) |
|
|
|
|
# parameter memory = parameter size + optimizer extra weight storage |
|
|
|
|
self.free_memory = free_memory - _get_param_size(self.graph.owning_module) * (optim_multiplier + 1) |
|
|
|
|
self.cnode = cnode |
|
|
|
|
self.requires_linearize = requires_linearize |
|
|
|
|
if self.requires_linearize: |
|
|
|
@ -97,7 +101,7 @@ class CheckpointSolverBase(ABC):
|
|
|
|
|
the actual 'node' in linearized manner. |
|
|
|
|
|
|
|
|
|
Remarks: |
|
|
|
|
Do merge the inplace ops into the previous node. |
|
|
|
|
Do merge the inplace ops and shape-consistency ops into the previous node. |
|
|
|
|
""" |
|
|
|
|
|
|
|
|
|
# Common nodes are type of nodes that could be seen as attributes and remain |
|
|
|
@ -136,7 +140,7 @@ class CheckpointSolverBase(ABC):
|
|
|
|
|
""" |
|
|
|
|
|
|
|
|
|
def _is_inplace(n: Node): |
|
|
|
|
"""Get the inplace argument from torch.fx.Node |
|
|
|
|
"""Get the inplace argument from ``torch.fx.Node`` |
|
|
|
|
""" |
|
|
|
|
inplace = False |
|
|
|
|
if n.op == "call_function": |
|
|
|
|