ColossalAI/colossalai/fx/passes/algorithms/ckpt_solver_chen.py

95 lines
3.1 KiB
Python
Raw Normal View History

from typing import List, Set, Tuple
import torch
from torch.fx import GraphModule, Node
import math
__all__ = ['chen_greedy']
CKPT_OP = ['call_module', 'call_method', 'call_function', 'get_attr']
def _all_potential_ckpt_nodes(gm: GraphModule) -> List:
"""
In most existing frameworks of activation checkpoint, the forward graph is assumed to be linearized.
"""
def is_sink():
"""
If we can free all memories when executing a certain node, it is a sink.
"""
return not sum((v for k, v in deps.items()))
deps = {}
ckpt_nodes = []
for n in gm.graph.nodes:
for n_par in n._input_nodes:
deps[n_par] -= 1 # free memory and dependencies
# We can only put act_ckpt on these nodes
if n.op in CKPT_OP and is_sink():
ckpt_nodes.append(n)
deps[n] = len(n.users) # add dependencies for future executions
return ckpt_nodes
def chen_greedy(gm: GraphModule) -> GraphModule:
"""
This is the simple implementation of Algorithm 3 in https://arxiv.org/abs/1604.06174.
Note that this algorithm targets at memory optimization only, using techniques in appendix A.
Usage:
model = resnet18()
input_sample = torch.rand(4, 3, 224, 224)
gm = symbolic_trace(model)
MetaInfoProp(gm).run(input_sample)
gm = chen_greedy(gm)
Args:
gm (GraphModule): The module to add checkpoints
"""
def grid_search(num_grids: int = 6) -> Set:
"""
Search ckpt strategy with b = 0, then run the allocation algorithm again with b = xy.
Grid search over [2/2 b, 2 b] for ckpt_opt over num_grids as in appendix A.
"""
_, b_approx = run_chen_greedy(0)
b_min, b_max = math.floor(b_approx / math.sqrt(2)), math.ceil(b_approx * math.sqrt(2))
b_opt = math.inf
for b in range(b_min, b_max, (b_max - b_min) // num_grids):
ckpt_intv, b_approx = run_chen_greedy(b)
if b_approx < b_opt:
b_opt = b_approx
ckpt_opt = ckpt_intv
return ckpt_opt
def run_chen_greedy(b: int = 0) -> Tuple[Set, int]:
"""
This is the simple implementation of Algorithm 3 in https://arxiv.org/abs/1604.06174.
"""
ckpt_nodes = _all_potential_ckpt_nodes(gm)
ckpt_intv = []
temp = 0
x = 0
y = 0
prev_idx = 2
for (idx, n) in enumerate(gm.graph.nodes):
[fx] add profiler for fx nodes. (#1480) * [fx] modify the calculation of node_size in MetaInfoProp for activation checkpointing usages * [fx] modify the calculation of node_size in MetaInfoProp for activation checkpointing usages * [fx] modify the calculation of node_size in MetaInfoProp for activation checkpointing usages * [fx] merge development into main (#1) * [fx] activation checkpointing using Chen strategies. * [fx] add test for ckpt_solver_chen * [fx] add vanilla activation checkpoint search with test on resnet and densenet * [fx] add a namespace code for solver_chen. * [fx] fix the false interpretation of algorithm 3 in https://arxiv.org/abs/1604.06174. * [fx] fix lowercase naming conventions. * [fx] simplify test for ckpt. * [fx] add rules to linearize computation graphs for searching. (#2) * [fx] modify the calculation of node_size in MetaInfoProp for activation checkpointing usages * [fx] modify the calculation of node_size in MetaInfoProp for activation checkpointing usages * [fx] modify the calculation of node_size in MetaInfoProp for activation checkpointing usages * [fx] merge development into main (#1) * [fx] activation checkpointing using Chen strategies. * [fx] add test for ckpt_solver_chen * [fx] add vanilla activation checkpoint search with test on resnet and densenet * [fx] add a namespace code for solver_chen. * [fx] fix the false interpretation of algorithm 3 in https://arxiv.org/abs/1604.06174. * [fx] fix lowercase naming conventions. * [fx] simplify test for ckpt. * [fx] fix test and algorithm bugs in activation checkpointing. * [fx] polish ckpt_test. * [fx] add rules to linearize computation graphs for searching. * [fx] remove chen_sqrt for sake of simplicity * [fx] remove chen_sqrt for sake of simplicity * [fx] remove chen_sqrt for sake of simplicity * [fx] remove chen_sqrt for sake of simplicity * [fx] fix inconsistencies. * [fx] fix MetaInfoProp. * [fx] fix MetaInfoProp. * [fx] consider MetaInfoProp for inplace operands. * [fx] consider MetaInfoProp for inplace operands. * [fx] consider MetaInfoProp for inplace operands. * [fx] consider MetaInfoProp for inplace operands. * [fx] consider MetaInfoProp for inplace operands. * [fx] add profiler for fx nodes. * [fx] add profiler for fx nodes. * [fx] add profiler for fx nodes. * [fx] add profiler for fx nodes. * [fx] add profiler for fx nodes. * [fx] add profiler for fx nodes. * [fx] add profiler for fx nodes. * [fx] fix error in tests. * [fx] unfix bug. * [fx] unfix bug.
2022-08-24 08:22:44 +00:00
temp += getattr(n, '__activation__')
y = max(y, temp)
if temp > b and n in ckpt_nodes:
[fx] add profiler for fx nodes. (#1480) * [fx] modify the calculation of node_size in MetaInfoProp for activation checkpointing usages * [fx] modify the calculation of node_size in MetaInfoProp for activation checkpointing usages * [fx] modify the calculation of node_size in MetaInfoProp for activation checkpointing usages * [fx] merge development into main (#1) * [fx] activation checkpointing using Chen strategies. * [fx] add test for ckpt_solver_chen * [fx] add vanilla activation checkpoint search with test on resnet and densenet * [fx] add a namespace code for solver_chen. * [fx] fix the false interpretation of algorithm 3 in https://arxiv.org/abs/1604.06174. * [fx] fix lowercase naming conventions. * [fx] simplify test for ckpt. * [fx] add rules to linearize computation graphs for searching. (#2) * [fx] modify the calculation of node_size in MetaInfoProp for activation checkpointing usages * [fx] modify the calculation of node_size in MetaInfoProp for activation checkpointing usages * [fx] modify the calculation of node_size in MetaInfoProp for activation checkpointing usages * [fx] merge development into main (#1) * [fx] activation checkpointing using Chen strategies. * [fx] add test for ckpt_solver_chen * [fx] add vanilla activation checkpoint search with test on resnet and densenet * [fx] add a namespace code for solver_chen. * [fx] fix the false interpretation of algorithm 3 in https://arxiv.org/abs/1604.06174. * [fx] fix lowercase naming conventions. * [fx] simplify test for ckpt. * [fx] fix test and algorithm bugs in activation checkpointing. * [fx] polish ckpt_test. * [fx] add rules to linearize computation graphs for searching. * [fx] remove chen_sqrt for sake of simplicity * [fx] remove chen_sqrt for sake of simplicity * [fx] remove chen_sqrt for sake of simplicity * [fx] remove chen_sqrt for sake of simplicity * [fx] fix inconsistencies. * [fx] fix MetaInfoProp. * [fx] fix MetaInfoProp. * [fx] consider MetaInfoProp for inplace operands. * [fx] consider MetaInfoProp for inplace operands. * [fx] consider MetaInfoProp for inplace operands. * [fx] consider MetaInfoProp for inplace operands. * [fx] consider MetaInfoProp for inplace operands. * [fx] add profiler for fx nodes. * [fx] add profiler for fx nodes. * [fx] add profiler for fx nodes. * [fx] add profiler for fx nodes. * [fx] add profiler for fx nodes. * [fx] add profiler for fx nodes. * [fx] add profiler for fx nodes. * [fx] fix error in tests. * [fx] unfix bug. * [fx] unfix bug.
2022-08-24 08:22:44 +00:00
x += getattr(n, '__activation__')
temp = 0
ckpt_intv.append((prev_idx, idx + 1))
prev_idx = idx + 1
return ckpt_intv, math.floor(math.sqrt(x * y))
gm.graph.lint() # make sure nodes are in topological order
ckpt = grid_search(num_grids=6)
node_list = list(gm.graph.nodes)
for i, seg in enumerate(ckpt):
for idx in range(*seg):
n = node_list[idx]
if n.op in CKPT_OP:
setattr(n, 'activation_checkpoint', i)
gm.recompile()
return gm