mirror of https://github.com/hpcaitech/ColossalAI
aibig-modeldata-parallelismdeep-learningdistributed-computingfoundation-modelsheterogeneous-traininghpcinferencelarge-scalemodel-parallelismpipeline-parallelism
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
87 lines
3.4 KiB
87 lines
3.4 KiB
import math |
|
from copy import deepcopy |
|
from typing import List, Set, Tuple |
|
|
|
from torch.fx import Graph, Node |
|
|
|
from colossalai.fx.profiler import calculate_fwd_in, calculate_fwd_tmp |
|
|
|
from .ckpt_solver_base import CheckpointSolverBase |
|
|
|
__all__ = ['CheckpointSolverChen'] |
|
|
|
|
|
class CheckpointSolverChen(CheckpointSolverBase): |
|
|
|
def __init__(self, graph: Graph, cnode: List[str] = None, num_grids: int = 6): |
|
""" |
|
This is the simple implementation of Algorithm 3 in https://arxiv.org/abs/1604.06174. |
|
Note that this algorithm targets at memory optimization only, using techniques in appendix A. |
|
|
|
Usage: |
|
Assume that we have a ``GraphModule``, and we have already done the extractions |
|
to the graph to retrieve all information needed, then we could use the following |
|
code to find a solution using ``CheckpointSolverChen``: |
|
>>> solver = CheckpointSolverChen(gm.graph) |
|
>>> chen_graph = solver.solve() |
|
>>> gm.graph = chen_graph # set the graph to a new graph |
|
|
|
Args: |
|
graph (Graph): The computing graph to be optimized. |
|
cnode (List[str], optional): Common node List, should be the subset of input. Defaults to None. |
|
num_grids (int, optional): Number of grids to search for b. Defaults to 6. |
|
""" |
|
super().__init__(graph, 0, 0, True, cnode) |
|
self.num_grids = num_grids |
|
|
|
def solve(self) -> Graph: |
|
"""Solve the checkpointing problem using Algorithm 3. |
|
|
|
Returns: |
|
graph (Graph): The optimized graph, should be a copy of the original graph. |
|
""" |
|
checkpointable_op = ['call_module', 'call_method', 'call_function', 'get_attr'] |
|
ckpt = self.grid_search() |
|
for i, seg in enumerate(ckpt): |
|
for idx in range(*seg): |
|
nodes = self.node_list[idx] |
|
for n in nodes: |
|
if n.op in checkpointable_op: |
|
n.meta['activation_checkpoint'] = i |
|
return deepcopy(self.graph) |
|
|
|
def run_chen_greedy(self, b: int = 0) -> Tuple[Set, int]: |
|
""" |
|
This is the simple implementation of Algorithm 3 in https://arxiv.org/abs/1604.06174. |
|
""" |
|
ckpt_intv = [] |
|
temp = 0 |
|
x = 0 |
|
y = 0 |
|
prev_idx = 2 |
|
for idx, nodes in enumerate(self.node_list): |
|
for n in nodes: |
|
n: Node |
|
temp += calculate_fwd_in(n) + calculate_fwd_tmp(n) |
|
y = max(y, temp) |
|
if temp > b and idx > prev_idx: |
|
x += calculate_fwd_in(nodes[0]) |
|
temp = 0 |
|
ckpt_intv.append((prev_idx, idx + 1)) |
|
prev_idx = idx + 1 |
|
return ckpt_intv, math.floor(math.sqrt(x * y)) |
|
|
|
def grid_search(self) -> Set: |
|
""" |
|
Search ckpt strategy with b = 0, then run the allocation algorithm again with b = √xy. |
|
Grid search over [√2/2 b, √2 b] for ``ckpt_opt`` over ``num_grids`` as in appendix A. |
|
""" |
|
_, b_approx = self.run_chen_greedy(0) |
|
b_min, b_max = math.floor(b_approx / math.sqrt(2)), math.ceil(b_approx * math.sqrt(2)) |
|
b_opt = math.inf |
|
for b in range(b_min, b_max, (b_max - b_min) // self.num_grids): |
|
ckpt_intv, b_approx = self.run_chen_greedy(b) |
|
if b_approx < b_opt: |
|
b_opt = b_approx |
|
ckpt_opt = ckpt_intv |
|
return ckpt_opt
|
|
|