mirror of https://github.com/hpcaitech/ColossalAI
[fx] fix the false interpretation of algorithm 3 in https://arxiv.org/abs/1604.06174. (#1446)
* [fx] modify the calculation of node_size in MetaInfoProp for activation checkpointing usages * [fx] modify the calculation of node_size in MetaInfoProp for activation checkpointing usages * [fx] modify the calculation of node_size in MetaInfoProp for activation checkpointing usages * [fx] activation checkpointing using Chen strategies. * [fx] add test for ckpt_solver_chen * mend * [fx] add vanilla activation checkpoint search with test on resnet and densenet * [fx] add vanilla activation checkpoint search with test on resnet and densenet * [fx] add a namespace code for solver_chen. * [fx] fix the false interpretation of algorithm 3 in https://arxiv.org/abs/1604.06174. * [fx] fix lowercase naming conventions.pull/1439/head
parent
821c6172e2
commit
d40a9392ba
|
@ -1,45 +1,71 @@
|
|||
from typing import Set, Tuple
|
||||
import torch
|
||||
from torch.fx import GraphModule
|
||||
import math
|
||||
|
||||
__all__ = ['chen_greedy', 'chen_sqrtn']
|
||||
|
||||
|
||||
def chen_greedy(gm: GraphModule, B: int):
|
||||
def chen_greedy(gm: GraphModule) -> GraphModule:
|
||||
"""
|
||||
This is the simple implementation of Algorithm 3 in https://arxiv.org/abs/1604.06174.
|
||||
Note that this algorithm targets at memory optimization only, using techniques in appendix A.
|
||||
|
||||
Usage:
|
||||
B = 5 * 1024 * 1024 * 1024 # An approximate memory budget of 5GB
|
||||
model = resnet18()
|
||||
input_sample = torch.rand(4, 3, 224, 224)
|
||||
gm = symbolic_trace(model)
|
||||
MetaInfoProp(gm).run(input_sample)
|
||||
gm = chen_greedy(gm, B)
|
||||
gm = chen_greedy(gm)
|
||||
|
||||
Args:
|
||||
gm (GraphModule): The module to add checkpoints
|
||||
B (int): The approximate memory budget for this module.
|
||||
"""
|
||||
|
||||
def grid_search(num_grids: int = 6) -> Set:
|
||||
"""
|
||||
Search ckpt strategy with b = 0, then run the allocation algorithm again with b = √xy.
|
||||
Grid search over [√2/2 b, √2 b] for ckpt_opt over num_grids as in appendix A.
|
||||
"""
|
||||
_, b_approx = run_chen_greedy(0)
|
||||
b_min, b_max = math.floor(b_approx / math.sqrt(2)), math.ceil(b_approx * math.sqrt(2))
|
||||
b_opt = math.inf
|
||||
for b in range(b_min, b_max, (b_max - b_min) // num_grids):
|
||||
ckpt, b_approx = run_chen_greedy(b)
|
||||
if b_approx < b_opt:
|
||||
b_opt = b_approx
|
||||
ckpt_opt = ckpt
|
||||
return ckpt_opt
|
||||
|
||||
def run_chen_greedy(b: int = 0) -> Tuple[Set, int]:
|
||||
"""
|
||||
This is the simple implementation of Algorithm 3 in https://arxiv.org/abs/1604.06174.
|
||||
"""
|
||||
ckpt = set()
|
||||
temp = 0
|
||||
x = 0
|
||||
y = 0
|
||||
for (idx, n) in enumerate(gm.graph.nodes):
|
||||
temp += getattr(n, 'activation_size')
|
||||
y = max(y, temp)
|
||||
if temp > b:
|
||||
x += getattr(n, 'activation_size')
|
||||
temp = 0
|
||||
ckpt.add(idx)
|
||||
return ckpt, math.floor(math.sqrt(x * y))
|
||||
|
||||
gm.graph.lint() # make sure nodes are in topological order
|
||||
temp = 0
|
||||
x = 0
|
||||
idx = 0
|
||||
budget = B
|
||||
for n in gm.graph.nodes:
|
||||
B -= getattr(n, 'param_size')
|
||||
assert B > 0, f'The memory budget {budget / 1024 ** 3:.2f} GB is not enough for model parameters of {gm}'
|
||||
for n in gm.graph.nodes:
|
||||
temp += getattr(n, 'activation_size')
|
||||
if temp > B:
|
||||
x += getattr(n, 'activation_size')
|
||||
temp = x
|
||||
setattr(n, 'activation_checkpoint', str(idx))
|
||||
idx += 1
|
||||
ckpt = grid_search(num_grids=6)
|
||||
i = 0
|
||||
for idx, n in enumerate(gm.graph.nodes):
|
||||
if idx in ckpt:
|
||||
setattr(n, 'activation_checkpoint', str(i))
|
||||
i += 1
|
||||
gm.recompile()
|
||||
return gm
|
||||
|
||||
|
||||
def chen_sqrtn(gm: GraphModule):
|
||||
def chen_sqrtn(gm: GraphModule) -> GraphModule:
|
||||
"""
|
||||
This is the theoretical optimal strategy in https://arxiv.org/abs/1604.06174.
|
||||
|
||||
|
|
|
@ -1,13 +1,13 @@
|
|||
from ctypes import Union
|
||||
from colossalai.fx.passes.algorithms import chen_greedy, chen_sqrtn
|
||||
import torch
|
||||
import torchvision.models as tm
|
||||
from colossalai.fx import ColoTracer
|
||||
from torch.fx import GraphModule
|
||||
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
|
||||
from functools import partial
|
||||
import pytest
|
||||
|
||||
SOLVERS = [partial(chen_greedy, B=1024 * 1024 * 64), chen_sqrtn]
|
||||
SOLVERS = [chen_greedy, chen_sqrtn]
|
||||
|
||||
|
||||
def _is_activation_checkpoint_available(gm: GraphModule):
|
||||
|
@ -16,6 +16,13 @@ def _is_activation_checkpoint_available(gm: GraphModule):
|
|||
return True
|
||||
|
||||
|
||||
def _is_all_gradient_close(m: torch.nn.Module, gm: GraphModule):
|
||||
for m_p, gm_p in zip(m.parameters(), gm.parameters()):
|
||||
if not torch.allclose(m_p, gm_p):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def test_ckpt_solver():
|
||||
MODEL_LIST = [tm.resnet18, tm.densenet121]
|
||||
|
||||
|
@ -23,17 +30,24 @@ def test_ckpt_solver():
|
|||
|
||||
tracer = ColoTracer()
|
||||
data = torch.rand(1, 3, 224, 224)
|
||||
label = torch.rand(1, 1000)
|
||||
|
||||
for solver in SOLVERS:
|
||||
for model_cls in MODEL_LIST:
|
||||
model = model_cls()
|
||||
criterion = torch.nn.MSELoss()
|
||||
graph = tracer.trace(root=model)
|
||||
gm = GraphModule(model, graph, model.__class__.__name__)
|
||||
MetaInfoProp(gm).run(data)
|
||||
gm = solver(gm)
|
||||
assert _is_activation_checkpoint_available(
|
||||
gm), f"Solver {solver} did not annotate {model_cls} with any activation checkpoints"
|
||||
assert torch.allclose(gm(data), model(data))
|
||||
loss = criterion(model(data), label)
|
||||
loss.backward()
|
||||
loss = criterion(gm(data), label)
|
||||
loss.backward()
|
||||
assert _is_all_gradient_close(model,
|
||||
gm), f'Solver {solver} did not work correctly in backward pass on {model_cls}'
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
Loading…
Reference in New Issue