mirror of https://github.com/hpcaitech/ColossalAI
parent
448248b27c
commit
cc55ff0aa4
|
@ -2,6 +2,7 @@ from abc import ABC, abstractmethod
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from typing import Any, List
|
from typing import Any, List
|
||||||
|
|
||||||
|
import torch
|
||||||
from torch.fx import Graph, Node
|
from torch.fx import Graph, Node
|
||||||
|
|
||||||
from colossalai.fx.codegen.activation_checkpoint_codegen import ActivationCheckpointCodeGen
|
from colossalai.fx.codegen.activation_checkpoint_codegen import ActivationCheckpointCodeGen
|
||||||
|
@ -17,13 +18,17 @@ def _copy_output(src: Graph, dst: Graph):
|
||||||
n_dst.meta = n_src.meta
|
n_dst.meta = n_src.meta
|
||||||
|
|
||||||
|
|
||||||
|
def _get_param_size(module: torch.nn.Module) -> int:
    """Get the size of the parameters in the module.

    Args:
        module (torch.nn.Module): The module whose parameters are measured.

    Returns:
        int: Sum over ``module.parameters()`` of ``numel() * element_size()``,
            i.e. the parameter size in bytes.
    """
    # ``Tensor.element_size()`` already reports the per-element byte width for the
    # parameter's dtype, so no empty tensor has to be allocated just to query it.
    return sum(p.numel() * p.element_size() for p in module.parameters())
|
||||||
|
|
||||||
|
|
||||||
class CheckpointSolverBase(ABC):
|
class CheckpointSolverBase(ABC):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
graph: Graph,
|
graph: Graph,
|
||||||
memory_budget: float = -1.0,
|
free_memory: float = -1.0,
|
||||||
parameter_size: float = 0,
|
|
||||||
requires_linearize: bool = False,
|
requires_linearize: bool = False,
|
||||||
cnode: List[str] = None,
|
cnode: List[str] = None,
|
||||||
):
|
):
|
||||||
|
@ -37,8 +42,7 @@ class CheckpointSolverBase(ABC):
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
graph (Graph): The computing graph to be optimized.
|
graph (Graph): The computing graph to be optimized.
|
||||||
memory_budget (float): Memory constraint for the solution.
|
free_memory (float): Memory constraint for the solution.
|
||||||
parameter_size (float): The size of parameter of this model. Use `parameter_size(model)` to estimate.
|
|
||||||
requires_linearize (bool): Whether the graph needs to be linearized.
|
requires_linearize (bool): Whether the graph needs to be linearized.
|
||||||
cnode (List[str], optional): Common node List, should be the subset of input. Default to None.
|
cnode (List[str], optional): Common node List, should be the subset of input. Default to None.
|
||||||
|
|
||||||
|
@ -58,8 +62,8 @@ class CheckpointSolverBase(ABC):
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
"Nodes meta information hasn't been prepared! Please run MetaInfoProp before constructing the solver!")
|
"Nodes meta information hasn't been prepared! Please run MetaInfoProp before constructing the solver!")
|
||||||
|
|
||||||
self.memory_budget = memory_budget
|
self.free_memory = free_memory
|
||||||
self.parameter_size = parameter_size
|
self.parameter_size = _get_param_size(self.graph.owning_module)
|
||||||
self.cnode = cnode
|
self.cnode = cnode
|
||||||
self.requires_linearize = requires_linearize
|
self.requires_linearize = requires_linearize
|
||||||
if self.requires_linearize:
|
if self.requires_linearize:
|
||||||
|
|
|
@ -22,12 +22,7 @@ __all__ = ['CheckpointSolverRotor']
|
||||||
|
|
||||||
class CheckpointSolverRotor(CheckpointSolverBase):
|
class CheckpointSolverRotor(CheckpointSolverBase):
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self, graph: Graph, free_memory: float = -1, cnode: List[str] = None, memory_slots: int = 500):
|
||||||
graph: Graph,
|
|
||||||
memory_budget: float = -1,
|
|
||||||
parameter_size: float = 0,
|
|
||||||
cnode: List[str] = None,
|
|
||||||
memory_slots: int = 500):
|
|
||||||
"""This is the simple implementation of dynamic programming algorithm rotor
|
"""This is the simple implementation of dynamic programming algorithm rotor
|
||||||
in https://hal.inria.fr/hal-02352969. Some code is adapted from
|
in https://hal.inria.fr/hal-02352969. Some code is adapted from
|
||||||
https://gitlab.inria.fr/hiepacs/rotor.
|
https://gitlab.inria.fr/hiepacs/rotor.
|
||||||
|
@ -36,22 +31,22 @@ class CheckpointSolverRotor(CheckpointSolverBase):
|
||||||
Assume that we have a `GraphModule`, and we already applied the `MetaInfoProp`
|
Assume that we have a `GraphModule`, and we already applied the `MetaInfoProp`
|
||||||
to the graph to retrieve all information needed, then we could use the following
|
to the graph to retrieve all information needed, then we could use the following
|
||||||
code to find a solution using `CheckpointSolverRotor`:
|
code to find a solution using `CheckpointSolverRotor`:
|
||||||
>>> solver = CheckpointSolverRotor(gm.graph, memory_budget=memory_budget, parameter_size=parameter_size)
|
>>> solver = CheckpointSolverRotor(gm.graph, free_memory=torch.cuda.mem_get_info(device=0)[0])
|
||||||
>>> rotor_graph = solver.solve(force_python=True) # otherwise use C solver
|
>>> rotor_graph = solver.solve(force_python=True) # otherwise use C solver
|
||||||
>>> gm.graph = rotor_graph # set the graph to a new graph
|
>>> gm.graph = rotor_graph # set the graph to a new graph
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
graph (Graph): The computing graph to be optimized.
|
graph (Graph): The computing graph to be optimized.
|
||||||
memory_budget (float, optional): Memory constraint for the solution, unit is byte.
|
free_memory (float, optional): Memory constraint for the solution, unit is byte.
|
||||||
parameter_size (float, optional): The size of parameter of this model, unit is byte. Use `parameter_size(model)` to estimate.
|
Use ``torch.cuda.mem_get_info(device=0)[0]`` to estimate the free_memory. Defaults to -1.
|
||||||
cnode (List[str], optional): Common node List, should be the subset of input. Defaults to None.
|
cnode (List[str], optional): Common node List, should be the subset of input. Defaults to None.
|
||||||
memory_slots (int, optional): Number of slots for discretizing memory budget. Defaults to 500.
|
memory_slots (int, optional): Number of slots for discretizing memory budget. Defaults to 500.
|
||||||
"""
|
"""
|
||||||
super().__init__(graph, memory_budget, parameter_size, True, cnode)
|
super().__init__(graph, free_memory, True, cnode)
|
||||||
self.memory_slots = memory_slots
|
self.memory_slots = memory_slots
|
||||||
|
|
||||||
# construct chain
|
# construct chain
|
||||||
unit = self.memory_budget // self.memory_slots
|
unit = self.free_memory // self.memory_slots
|
||||||
self.chain = self._construct_chain(self.graph, self.node_list)
|
self.chain = self._construct_chain(self.graph, self.node_list)
|
||||||
self.chain.discretize_all(unit)
|
self.chain.discretize_all(unit)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue