[autoparallel] user-friendly API for CheckpointSolver. (#1879)

Merge for SC tutorial
pull/1797/head
Super Daniel 2022-11-10 20:59:28 +08:00 committed by GitHub
parent 448248b27c
commit cc55ff0aa4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 16 additions and 17 deletions

View File

@ -2,6 +2,7 @@ from abc import ABC, abstractmethod
from copy import deepcopy from copy import deepcopy
from typing import Any, List from typing import Any, List
import torch
from torch.fx import Graph, Node from torch.fx import Graph, Node
from colossalai.fx.codegen.activation_checkpoint_codegen import ActivationCheckpointCodeGen from colossalai.fx.codegen.activation_checkpoint_codegen import ActivationCheckpointCodeGen
@ -17,13 +18,17 @@ def _copy_output(src: Graph, dst: Graph):
n_dst.meta = n_src.meta n_dst.meta = n_src.meta
def _get_param_size(module: torch.nn.Module):
"""Get the size of the parameters in the module"""
return sum([p.numel() * torch.tensor([], dtype=p.dtype).element_size() for p in module.parameters()])
class CheckpointSolverBase(ABC): class CheckpointSolverBase(ABC):
def __init__( def __init__(
self, self,
graph: Graph, graph: Graph,
memory_budget: float = -1.0, free_memory: float = -1.0,
parameter_size: float = 0,
requires_linearize: bool = False, requires_linearize: bool = False,
cnode: List[str] = None, cnode: List[str] = None,
): ):
@ -37,8 +42,7 @@ class CheckpointSolverBase(ABC):
Args: Args:
graph (Graph): The computing graph to be optimized. graph (Graph): The computing graph to be optimized.
memory_budget (float): Memory constraint for the solution. free_memory (float): Memory constraint for the solution.
parameter_size (float): The size of parameter of this model. Use `parameter_size(model)` to estimate.
requires_linearize (bool): Whether the graph needs to be linearized. requires_linearize (bool): Whether the graph needs to be linearized.
cnode (List[str], optional): Common node List, should be the subset of input. Default to None. cnode (List[str], optional): Common node List, should be the subset of input. Default to None.
@ -58,8 +62,8 @@ class CheckpointSolverBase(ABC):
raise RuntimeError( raise RuntimeError(
"Nodes meta information hasn't been prepared! Please run MetaInfoProp before constructing the solver!") "Nodes meta information hasn't been prepared! Please run MetaInfoProp before constructing the solver!")
self.memory_budget = memory_budget self.free_memory = free_memory
self.parameter_size = parameter_size self.parameter_size = _get_param_size(self.graph.owning_module)
self.cnode = cnode self.cnode = cnode
self.requires_linearize = requires_linearize self.requires_linearize = requires_linearize
if self.requires_linearize: if self.requires_linearize:

View File

@ -22,12 +22,7 @@ __all__ = ['CheckpointSolverRotor']
class CheckpointSolverRotor(CheckpointSolverBase): class CheckpointSolverRotor(CheckpointSolverBase):
def __init__(self, def __init__(self, graph: Graph, free_memory: float = -1, cnode: List[str] = None, memory_slots: int = 500):
graph: Graph,
memory_budget: float = -1,
parameter_size: float = 0,
cnode: List[str] = None,
memory_slots: int = 500):
"""This is the simple implementation of dynamic programming algorithm rotor """This is the simple implementation of dynamic programming algorithm rotor
in https://hal.inria.fr/hal-02352969. Some code are adapted from in https://hal.inria.fr/hal-02352969. Some code are adapted from
https://gitlab.inria.fr/hiepacs/rotor. https://gitlab.inria.fr/hiepacs/rotor.
@ -36,22 +31,22 @@ class CheckpointSolverRotor(CheckpointSolverBase):
Assume that we have a `GraphModule`, and we already applied the `MetaInfoProp` Assume that we have a `GraphModule`, and we already applied the `MetaInfoProp`
to the graph to retrieve all information needed, then we could use the following to the graph to retrieve all information needed, then we could use the following
code to find a solution using `CheckpointSolverRotor`: code to find a solution using `CheckpointSolverRotor`:
>>> solver = CheckpointSolverRotor(gm.graph, memory_budget=memory_budget, parameter_size=parameter_size) >>> solver = CheckpointSolverRotor(gm.graph, free_memory=torch.cuda.mem_get_info(device=0)[0])
>>> rotor_graph = solver.solve(force_python=True) # otherwise use C solver >>> rotor_graph = solver.solve(force_python=True) # otherwise use C solver
>>> gm.graph = rotor_graph # set the graph to a new graph >>> gm.graph = rotor_graph # set the graph to a new graph
Args: Args:
graph (Graph): The computing graph to be optimized. graph (Graph): The computing graph to be optimized.
memory_budget (float, optional): Memory constraint for the solution, unit is byte. free_memory (float, optional): Memory constraint for the solution, unit is byte.
parameter_size (float, optional): The size of parameter of this model, unit is byte. Use `parameter_size(model)` to estimate. Use ``torch.cuda.mem_get_info(device=0)[0]`` to estimate the free_memory. Defaults to -1.
cnode (List[str], optional): Common node List, should be the subset of input. Defaults to None. cnode (List[str], optional): Common node List, should be the subset of input. Defaults to None.
memory_slots (int, optional): Number of slots for discretizing memory budget. Defaults to 500. memory_slots (int, optional): Number of slots for discretizing memory budget. Defaults to 500.
""" """
super().__init__(graph, memory_budget, parameter_size, True, cnode) super().__init__(graph, free_memory, True, cnode)
self.memory_slots = memory_slots self.memory_slots = memory_slots
# construct chain # construct chain
unit = self.memory_budget // self.memory_slots unit = self.free_memory // self.memory_slots
self.chain = self._construct_chain(self.graph, self.node_list) self.chain = self._construct_chain(self.graph, self.node_list)
self.chain.discretize_all(unit) self.chain.discretize_all(unit)