[fx] fix the false interpretation of algorithm 3 in https://arxiv.org/abs/1604.06174. (#1446)

* [fx] modify the calculation of node_size in MetaInfoProp for activation checkpointing usages

* [fx] modify the calculation of node_size in MetaInfoProp for activation checkpointing usages

* [fx] modify the calculation of node_size in MetaInfoProp for activation checkpointing usages

* [fx] activation checkpointing using Chen strategies.

* [fx] add test for ckpt_solver_chen

* mend

* [fx] add vanilla activation checkpoint search with test on resnet and densenet

* [fx] add vanilla activation checkpoint search with test on resnet and densenet

* [fx] add a namespace code for solver_chen.

* [fx] fix the false interpretation of algorithm 3 in https://arxiv.org/abs/1604.06174.

* [fx] fix lowercase naming conventions.
pull/1439/head
Super Daniel 2 years ago committed by GitHub
parent 821c6172e2
commit d40a9392ba
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1,45 +1,71 @@
from typing import Set, Tuple
import torch import torch
from torch.fx import GraphModule from torch.fx import GraphModule
import math
__all__ = ['chen_greedy', 'chen_sqrtn'] __all__ = ['chen_greedy', 'chen_sqrtn']
def chen_greedy(gm: GraphModule, B: int): def chen_greedy(gm: GraphModule) -> GraphModule:
""" """
This is the simple implementation of Algorithm 3 in https://arxiv.org/abs/1604.06174. This is the simple implementation of Algorithm 3 in https://arxiv.org/abs/1604.06174.
Note that this algorithm targets at memory optimization only, using techniques in appendix A.
Usage: Usage:
B = 5 * 1024 * 1024 * 1024 # An approximate memory budget of 5GB
model = resnet18() model = resnet18()
input_sample = torch.rand(4, 3, 224, 224) input_sample = torch.rand(4, 3, 224, 224)
gm = symbolic_trace(model) gm = symbolic_trace(model)
MetaInfoProp(gm).run(input_sample) MetaInfoProp(gm).run(input_sample)
gm = chen_greedy(gm, B) gm = chen_greedy(gm)
Args: Args:
gm (GraphModule): The module to add checkpoints gm (GraphModule): The module to add checkpoints
B (int): The approximate memory budget for this module.
""" """
def grid_search(num_grids: int = 6) -> Set:
    """
    Search ckpt strategy with b = 0, then run the allocation algorithm again with b = sqrt(x * y).
    Grid search over [sqrt(2)/2 * b, sqrt(2) * b] for ckpt_opt over num_grids as in appendix A.

    Args:
        num_grids (int): Number of intervals the budget range is split into.

    Returns:
        Set: The checkpoint set whose estimate returned by ``run_chen_greedy`` is the
            smallest among the budgets tried; the b = 0 result if no budget could be tried.
    """
    # The b = 0 run also serves as the fallback answer, so ckpt_opt is always bound
    # even when the range below turns out to be empty.
    ckpt_opt, b_approx = run_chen_greedy(0)
    b_min, b_max = math.floor(b_approx / math.sqrt(2)), math.ceil(b_approx * math.sqrt(2))
    # range() raises ValueError on a zero step, which happens whenever the search
    # interval is narrower than num_grids; clamp the step to at least 1.
    step = max(1, (b_max - b_min) // num_grids)
    b_opt = math.inf
    for b in range(b_min, b_max, step):
        ckpt, b_approx = run_chen_greedy(b)
        if b_approx < b_opt:
            b_opt = b_approx
            ckpt_opt = ckpt
    return ckpt_opt
def run_chen_greedy(b: int = 0) -> Tuple[Set, int]:
    """
    This is the simple implementation of Algorithm 3 in https://arxiv.org/abs/1604.06174.

    Walks ``gm.graph.nodes`` in order, accumulating each node's ``activation_size``.
    Whenever the running total exceeds ``b`` the node index is recorded as a
    checkpoint and the running total restarts from zero.

    Args:
        b (int): Threshold on the running activation size.

    Returns:
        Tuple[Set, int]: The set of checkpointed node indices and
            ``floor(sqrt(x * y))``, where ``x`` is the summed size of the
            checkpointed nodes and ``y`` is the largest running total observed.
    """
    checkpoints = set()
    running = 0    # activation size accumulated since the last checkpoint
    kept = 0    # x: total activation size of the checkpointed nodes
    peak = 0    # y: largest running total seen so far
    for position, node in enumerate(gm.graph.nodes):
        size = getattr(node, 'activation_size')
        running += size
        if running > peak:
            peak = running
        if running > b:
            kept += size
            running = 0
            checkpoints.add(position)
    estimate = math.floor(math.sqrt(kept * peak))
    return checkpoints, estimate
gm.graph.lint() # make sure nodes are in topological order gm.graph.lint() # make sure nodes are in topological order
temp = 0 ckpt = grid_search(num_grids=6)
x = 0 i = 0
idx = 0 for idx, n in enumerate(gm.graph.nodes):
budget = B if idx in ckpt:
for n in gm.graph.nodes: setattr(n, 'activation_checkpoint', str(i))
B -= getattr(n, 'param_size') i += 1
assert B > 0, f'The memory budget {budget / 1024 ** 3:.2f} GB is not enough for model parameters of {gm}'
for n in gm.graph.nodes:
temp += getattr(n, 'activation_size')
if temp > B:
x += getattr(n, 'activation_size')
temp = x
setattr(n, 'activation_checkpoint', str(idx))
idx += 1
gm.recompile() gm.recompile()
return gm return gm
def chen_sqrtn(gm: GraphModule): def chen_sqrtn(gm: GraphModule) -> GraphModule:
""" """
This is the theoretical optimal strategy in https://arxiv.org/abs/1604.06174. This is the theoretical optimal strategy in https://arxiv.org/abs/1604.06174.

@ -1,13 +1,13 @@
from ctypes import Union
from colossalai.fx.passes.algorithms import chen_greedy, chen_sqrtn from colossalai.fx.passes.algorithms import chen_greedy, chen_sqrtn
import torch import torch
import torchvision.models as tm import torchvision.models as tm
from colossalai.fx import ColoTracer from colossalai.fx import ColoTracer
from torch.fx import GraphModule from torch.fx import GraphModule
from colossalai.fx.passes.meta_info_prop import MetaInfoProp from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from functools import partial
import pytest import pytest
SOLVERS = [partial(chen_greedy, B=1024 * 1024 * 64), chen_sqrtn] SOLVERS = [chen_greedy, chen_sqrtn]
def _is_activation_checkpoint_available(gm: GraphModule): def _is_activation_checkpoint_available(gm: GraphModule):
@ -16,6 +16,13 @@ def _is_activation_checkpoint_available(gm: GraphModule):
return True return True
def _is_all_gradient_close(m: torch.nn.Module, gm: GraphModule):
for m_p, gm_p in zip(m.parameters(), gm.parameters()):
if not torch.allclose(m_p, gm_p):
return False
return True
def test_ckpt_solver(): def test_ckpt_solver():
MODEL_LIST = [tm.resnet18, tm.densenet121] MODEL_LIST = [tm.resnet18, tm.densenet121]
@ -23,17 +30,24 @@ def test_ckpt_solver():
tracer = ColoTracer() tracer = ColoTracer()
data = torch.rand(1, 3, 224, 224) data = torch.rand(1, 3, 224, 224)
label = torch.rand(1, 1000)
for solver in SOLVERS: for solver in SOLVERS:
for model_cls in MODEL_LIST: for model_cls in MODEL_LIST:
model = model_cls() model = model_cls()
criterion = torch.nn.MSELoss()
graph = tracer.trace(root=model) graph = tracer.trace(root=model)
gm = GraphModule(model, graph, model.__class__.__name__) gm = GraphModule(model, graph, model.__class__.__name__)
MetaInfoProp(gm).run(data) MetaInfoProp(gm).run(data)
gm = solver(gm) gm = solver(gm)
assert _is_activation_checkpoint_available( assert _is_activation_checkpoint_available(
gm), f"Solver {solver} did not annotate {model_cls} with any activation checkpoints" gm), f"Solver {solver} did not annotate {model_cls} with any activation checkpoints"
assert torch.allclose(gm(data), model(data)) loss = criterion(model(data), label)
loss.backward()
loss = criterion(gm(data), label)
loss.backward()
assert _is_all_gradient_close(model,
gm), f'Solver {solver} did not work correctly in backward pass on {model_cls}'
if __name__ == '__main__': if __name__ == '__main__':

Loading…
Cancel
Save