2022-08-15 11:09:19 +00:00
|
|
|
from typing import List, Set, Tuple
|
2022-08-11 07:46:39 +00:00
|
|
|
import torch
|
2022-08-17 06:47:12 +00:00
|
|
|
from torch.fx import GraphModule, Node
|
2022-08-12 03:28:50 +00:00
|
|
|
import math
|
2022-08-11 07:46:39 +00:00
|
|
|
|
2022-08-17 06:47:12 +00:00
|
|
|
# Public API of this module.
__all__ = ["chen_greedy"]

# ``Node.op`` kinds that may be tagged with an activation-checkpoint segment;
# op kinds not listed here (e.g. ``placeholder`` / ``output``) are never tagged.
CKPT_OP = [
    "call_module",
    "call_method",
    "call_function",
    "get_attr",
]
|
2022-08-11 07:46:39 +00:00
|
|
|
|
|
|
|
|
2022-08-15 11:09:19 +00:00
|
|
|
def _all_potential_ckpt_nodes(gm: GraphModule) -> List:
    """
    Collect the nodes of ``gm`` at which an activation-checkpoint segment may end.

    In most existing frameworks of activation checkpoint, the forward graph is assumed
    to be linearized. A node qualifies when its ``op`` is in ``CKPT_OP`` and it is a
    "sink": once it has consumed its inputs, no previously visited node still has a
    user waiting to run, so all memories could be freed at that point.

    Args:
        gm (GraphModule): traced module whose ``graph.nodes`` are walked in order.

    Returns:
        List: the qualifying nodes, in graph order.
    """
    ckpt_nodes = []
    # Number of (producer -> consumer) edges whose producer has already been
    # visited but whose consumer has not. A single running total is equivalent to
    # summing a per-node dependency table, but costs O(1) per node instead of O(n).
    pending = 0
    for n in gm.graph.nodes:
        # ``n`` consumes each of its inputs once: free those dependencies.
        pending -= len(n._input_nodes)
        # We can only put act_ckpt on these nodes
        if n.op in CKPT_OP and pending == 0:
            ckpt_nodes.append(n)
        # ``n``'s output stays alive until every one of its users has executed.
        pending += len(n.users)
    return ckpt_nodes
|
|
|
|
|
|
|
|
|
2022-08-12 03:28:50 +00:00
|
|
|
def chen_greedy(gm: GraphModule) -> GraphModule:
    """
    This is the simple implementation of Algorithm 3 in https://arxiv.org/abs/1604.06174.
    Note that this algorithm targets at memory optimization only, using techniques in appendix A.

    Every node is read for ``n.meta['fwd_mem_out']`` and ``n.meta['fwd_mem_tmp']``
    (e.g. populated by ``MetaInfoProp``). Nodes inside a chosen segment whose ``op``
    is in ``CKPT_OP`` get an integer ``activation_checkpoint`` attribute equal to the
    segment index, and ``gm`` is recompiled in place.

    Usage:
        model = resnet18()
        input_sample = torch.rand(4, 3, 224, 224)
        gm = symbolic_trace(model)
        MetaInfoProp(gm).run(input_sample)
        gm = chen_greedy(gm)

    Args:
        gm (GraphModule): The module to add checkpoints

    Returns:
        GraphModule: the same ``gm`` object, annotated and recompiled.
    """

    def grid_search(num_grids: int = 6) -> List[Tuple[int, int]]:
        """
        Search ckpt strategy with b = 0, then run the allocation algorithm again with b = √xy.
        Grid search over [√2/2 b, √2 b] for ckpt_opt over num_grids as in appendix A.

        Returns the intervals of the grid point with the smallest √xy estimate, or the
        b = 0 intervals when the search window is empty.
        """
        # The b = 0 result doubles as the fallback when the grid below is empty.
        ckpt_opt, b_approx = run_chen_greedy(0)
        b_min, b_max = math.floor(b_approx / math.sqrt(2)), math.ceil(b_approx * math.sqrt(2))
        # A zero step makes ``range`` raise ValueError whenever the window is
        # narrower than ``num_grids``; always advance by at least 1.
        step = max((b_max - b_min) // max(num_grids, 1), 1)
        b_opt = math.inf
        for b in range(b_min, b_max, step):
            ckpt_intv, b_approx = run_chen_greedy(b)
            if b_approx < b_opt:
                b_opt = b_approx
                ckpt_opt = ckpt_intv
        return ckpt_opt

    def run_chen_greedy(b: int = 0) -> Tuple[List[Tuple[int, int]], int]:
        """
        This is the simple implementation of Algorithm 3 in https://arxiv.org/abs/1604.06174.

        Walk the nodes accumulating forward memory in ``temp`` (``y`` tracks its peak).
        Whenever ``temp`` exceeds the budget ``b`` at a candidate node, close the current
        segment there, add that node's output memory to ``x`` and reset ``temp``.

        Returns:
            The ``(start, stop)`` node-index intervals (for ``range(start, stop)``)
            and ``floor(sqrt(x * y))``.
        """
        ckpt_intv = []
        temp = 0
        x = 0
        y = 0
        # NOTE(review): segments start at node index 2, so the first two nodes are
        # never tagged -- presumably to skip input placeholder(s); confirm.
        prev_idx = 2
        for idx, n in enumerate(node_list):
            n: Node
            temp += n.meta['fwd_mem_out'] + n.meta['fwd_mem_tmp']
            y = max(y, temp)
            if temp > b and n in ckpt_nodes:
                x += n.meta['fwd_mem_out']
                temp = 0
                ckpt_intv.append((prev_idx, idx + 1))
                prev_idx = idx + 1
        return ckpt_intv, math.floor(math.sqrt(x * y))

    gm.graph.lint()    # check the graph is well-formed (incl. topological order); raises otherwise
    node_list = list(gm.graph.nodes)
    # Candidate boundaries do not depend on the budget ``b``: compute them once rather
    # than once per grid point, and keep them in a set for O(1) membership tests.
    ckpt_nodes = set(_all_potential_ckpt_nodes(gm))
    ckpt = grid_search(num_grids=6)
    for i, seg in enumerate(ckpt):
        for idx in range(*seg):
            n = node_list[idx]
            if n.op in CKPT_OP:
                setattr(n, 'activation_checkpoint', i)
    gm.recompile()
    return gm
|