[fx] fix test and algorithm bugs in activation checkpointing. (#1451)

* [fx] modify the calculation of node_size in MetaInfoProp for activation checkpointing usage

* [fx] merge development into main (#1)

* [fx] implement activation checkpointing using Chen's strategies.

* [fx] add test for ckpt_solver_chen

* [fx] add vanilla activation checkpoint search with test on resnet and densenet

* [fx] add a namespace declaration for solver_chen.

* [fx] fix the misinterpretation of Algorithm 3 in https://arxiv.org/abs/1604.06174.

* [fx] fix lowercase naming conventions.

* [fx] simplify test for ckpt.

* [fx] fix test and algorithm bugs in activation checkpointing.

* mend

* [fx] polish ckpt_test.
Super Daniel committed 2022-08-15 19:09:19 +08:00 (via GitHub)
commit 0dbd61c29b, parent b1553fdf96
2 changed files with 101 additions and 31 deletions
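
As context for the solver change below: Algorithm 3 of https://arxiv.org/abs/1604.06174 greedily closes a checkpoint segment whenever the running activation memory exceeds a budget b, then grid-searches over b. A minimal sketch of that rule, assuming a plain list of per-node activation sizes in place of a torch.fx graph (greedy_segments is a hypothetical helper for illustration, not part of the codebase):

import math
from typing import List, Tuple

def greedy_segments(activation_sizes: List[int], b: int) -> Tuple[List[Tuple[int, int]], int]:
    segments = []      # half-open (start, end) node-index intervals, like ckpt_intv below
    temp = x = y = 0   # temp: running activation memory; x, y: terms of the cost sqrt(x * y)
    prev_idx = 0
    for idx, size in enumerate(activation_sizes):
        temp += size
        y = max(y, temp)    # y records the peak within-segment memory
        if temp > b:        # budget exceeded: close the segment at this node
            x += size       # x accumulates the checkpointed activation sizes
            temp = 0
            segments.append((prev_idx, idx + 1))
            prev_idx = idx + 1
    return segments, math.floor(math.sqrt(x * y))

# e.g. greedy_segments([4, 4, 4, 4, 4, 4], b=8) returns ([(0, 3), (3, 6)], 9)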

Changed file 1 of 2: the Chen activation checkpointing solver (ckpt_solver_chen)
@@ -1,4 +1,4 @@
-from typing import Set, Tuple
+from typing import List, Set, Tuple
 import torch
 from torch.fx import GraphModule
 import math
@@ -6,6 +6,14 @@ import math
 __all__ = ['chen_greedy', 'chen_sqrtn']
 
 
+def _all_potential_ckpt_nodes(gm: GraphModule) -> List:
+    ckpt_nodes = []
+    for n in gm.graph.nodes:
+        if n.op == 'call_module':
+            ckpt_nodes.append(n)
+    return ckpt_nodes
+
+
 def chen_greedy(gm: GraphModule) -> GraphModule:
     """
     This is the simple implementation of Algorithm 3 in https://arxiv.org/abs/1604.06174.
@@ -31,36 +39,40 @@ def chen_greedy(gm: GraphModule) -> GraphModule:
         b_min, b_max = math.floor(b_approx / math.sqrt(2)), math.ceil(b_approx * math.sqrt(2))
         b_opt = math.inf
         for b in range(b_min, b_max, (b_max - b_min) // num_grids):
-            ckpt, b_approx = run_chen_greedy(b)
+            ckpt_intv, b_approx = run_chen_greedy(b)
             if b_approx < b_opt:
                 b_opt = b_approx
-                ckpt_opt = ckpt
+                ckpt_opt = ckpt_intv
         return ckpt_opt
 
     def run_chen_greedy(b: int = 0) -> Tuple[Set, int]:
         """
         This is the simple implementation of Algorithm 3 in https://arxiv.org/abs/1604.06174.
         """
-        ckpt = set()
+        ckpt_nodes = _all_potential_ckpt_nodes(gm)
+        ckpt_intv = []
         temp = 0
        x = 0
        y = 0
+        prev_idx = 2
         for (idx, n) in enumerate(gm.graph.nodes):
             temp += getattr(n, 'activation_size')
             y = max(y, temp)
-            if temp > b:
+            if temp > b and n in ckpt_nodes:
                 x += getattr(n, 'activation_size')
                 temp = 0
-                ckpt.add(idx)
-        return ckpt, math.floor(math.sqrt(x * y))
+                ckpt_intv.append((prev_idx, idx + 1))
+                prev_idx = idx + 1
+        return ckpt_intv, math.floor(math.sqrt(x * y))
 
     gm.graph.lint()    # make sure nodes are in topological order
     ckpt = grid_search(num_grids=6)
-    i = 0
-    for idx, n in enumerate(gm.graph.nodes):
-        if idx in ckpt:
-            setattr(n, 'activation_checkpoint', str(i))
-            i += 1
+    node_list = list(gm.graph.nodes)
+    for i, seg in enumerate(ckpt):
+        for idx in range(*seg):
+            n = node_list[idx]
+            if n.op in ['call_module', 'call_method', 'call_function']:
+                setattr(n, 'activation_checkpoint', str(i))
     gm.recompile()
     return gm
@@ -82,7 +94,9 @@ def chen_sqrtn(gm: GraphModule) -> GraphModule:
     gm.graph.lint()    # make sure nodes are in topological order
     k = int(len(gm.graph.nodes)**0.5)    # take approximately sqrt(n) checkpoints
     for idx, n in enumerate(gm.graph.nodes):
-        if (idx + 1) % k == 0:
+        # We should not add act_ckpt to the placeholder
+        # The last segment should not be checkpointed
+        if n.op != 'placeholder' and (idx + 1) // k < k:
             setattr(n, 'activation_checkpoint', str((idx + 1) // k))
     gm.recompile()
     return gm
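
For the chen_sqrtn change above, a small worked sketch of the new segment assignment, assuming n traced nodes and k = int(n ** 0.5): each node idx lands in segment (idx + 1) // k, and the guard (idx + 1) // k < k leaves the final segment unannotated, mirroring the in-code comments (the placeholder check is omitted here):

n = 16                  # assumed node count, for illustration only
k = int(n ** 0.5)       # k = 4, so roughly sqrt(n) segments
for idx in range(n):
    seg = (idx + 1) // k
    if seg < k:         # segments 0..k-1 get annotated
        print(f"node {idx} -> activation_checkpoint '{seg}'")
    else:               # the last segment is left unannotated
        print(f"node {idx} -> not checkpointed")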

Changed file 2 of 2: the checkpoint solver test (ckpt_test)
@@ -1,12 +1,25 @@
-from ctypes import Union
-from colossalai.fx.passes.algorithms import chen_greedy, chen_sqrtn
+from typing import Callable
+import copy
 import torch
+import torch.multiprocessing as mp
 import torchvision.models as tm
-from colossalai.fx import ColoTracer
 from torch.fx import GraphModule
+import colossalai
+from colossalai.fx import ColoTracer
+from colossalai.fx.passes.meta_info_prop import MetaInfoProp
+from colossalai.fx.passes.algorithms import chen_greedy, chen_sqrtn
+from colossalai.utils import free_port
+from colossalai.core import global_context as gpc
 import pytest
 
+try:
+    from colossalai.fx.codegen import ActivationCheckpointCodeGen
+    with_codegen = True
+except:
+    # fall back to older pytorch version
+    from colossalai.fx.codegen import python_code_with_activation_checkpoint
+    with_codegen = False
+
 SOLVERS = [chen_greedy, chen_sqrtn]
@@ -18,37 +31,80 @@ def _is_activation_checkpoint_available(gm: GraphModule):
 def _is_all_gradient_close(m: torch.nn.Module, gm: GraphModule):
     for m_p, gm_p in zip(m.parameters(), gm.parameters()):
-        if not torch.allclose(m_p, gm_p):
+        if not torch.allclose(m_p.grad, gm_p.grad):
             return False
     return True
 
 
-def test_ckpt_solver():
+def check_backward_consistency(m: torch.nn.Module, gm: GraphModule, solver: Callable[[GraphModule], GraphModule],
+                               model_cls: Callable[[], torch.nn.Module]):
+    criterion = torch.nn.MSELoss()
+    data = torch.rand(2, 3, 32, 32)
+    label = torch.rand(2, 5)
+    loss = criterion(m(data), label)
+    loss.backward()
+    loss = criterion(gm(data), label)
+    loss.backward()
+    assert _is_all_gradient_close(m, gm), f'Solver {solver} did not work correctly in backward pass on {model_cls}'
+
+
+def _run_ckpt_solver(rank):
+    colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl')
     MODEL_LIST = [tm.resnet18, tm.densenet121]
 
     torch.backends.cudnn.deterministic = True
 
-    tracer = ColoTracer()
-    data = torch.rand(1, 3, 224, 224)
-    label = torch.rand(1, 1000)
+    tracer = ColoTracer(trace_act_ckpt=False)
+    data = torch.rand(2, 3, 32, 32)
     for solver in SOLVERS:
         for model_cls in MODEL_LIST:
-            model = model_cls()
-            criterion = torch.nn.MSELoss()
-            graph = tracer.trace(root=model)
-            gm = GraphModule(model, graph, model.__class__.__name__)
+            m = model_cls(num_classes=5)
+            graph = tracer.trace(root=m)
+            gm = GraphModule(copy.deepcopy(m), graph, m.__class__.__name__)
+            MetaInfoProp(gm).run(data)
+            codegen = ActivationCheckpointCodeGen()
+            gm.graph.set_codegen(codegen)
             gm = solver(gm)
             assert _is_activation_checkpoint_available(
                 gm), f"Solver {solver} did not annotate {model_cls} with any activation checkpoints"
-            loss = criterion(model(data), label)
-            loss.backward()
-            loss = criterion(gm(data), label)
-            loss.backward()
-            assert _is_all_gradient_close(model,
-                                          gm), f'Solver {solver} did not work correctly in backward pass on {model_cls}'
+            check_backward_consistency(m, gm, solver, model_cls)
+
+
+@pytest.mark.skip
+@pytest.mark.skipif(not with_codegen, reason='torch version is lower than 1.12.0')
+def test_ckpt_solver():
+    mp.spawn(_run_ckpt_solver, nprocs=1)
+
+
+def _run_ckpt_solver_torch11(rank):
+    colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl')
+    MODEL_LIST = [tm.resnet18, tm.densenet121]
+
+    torch.backends.cudnn.deterministic = True
+
+    tracer = ColoTracer(trace_act_ckpt=False)
+    data = torch.rand(2, 3, 32, 32)
+    for solver in SOLVERS:
+        for model_cls in MODEL_LIST:
+            m = model_cls(num_classes=5)
+            graph = tracer.trace(root=m)
+            gm = GraphModule(copy.deepcopy(m), graph, m.__class__.__name__)
+            MetaInfoProp(gm).run(data)
+            gm.graph._python_code = python_code_with_activation_checkpoint.__get__(graph)
+            gm = solver(gm)
+            assert _is_activation_checkpoint_available(
+                gm), f"Solver {solver} did not annotate {model_cls} with any activation checkpoints"
+            check_backward_consistency(m, gm, solver, model_cls)
+
+
+@pytest.mark.skip
+@pytest.mark.skipif(with_codegen, reason='torch version is equal to or higher than 1.12.0')
+def test_ckpt_solver_torch11():
+    mp.spawn(_run_ckpt_solver_torch11, nprocs=1)
+
 
 if __name__ == '__main__':
     test_ckpt_solver()
+    test_ckpt_solver_torch11()
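
Taken together, the pipeline the updated tests exercise reduces to the steps below, condensed from the test body above; the model, shapes, and calls all come from the test itself, while the torch >= 1.12 codegen setup and the distributed launch are elided:

import copy
import torch
import torchvision.models as tm
from torch.fx import GraphModule
from colossalai.fx import ColoTracer
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.fx.passes.algorithms import chen_sqrtn

data = torch.rand(2, 3, 32, 32)
m = tm.resnet18(num_classes=5)
graph = ColoTracer(trace_act_ckpt=False).trace(root=m)    # symbolically trace the model
gm = GraphModule(copy.deepcopy(m), graph, m.__class__.__name__)
MetaInfoProp(gm).run(data)    # attach the activation_size metadata the solvers consume
gm = chen_sqrtn(gm)           # annotate nodes and recompile with activation checkpoints
out = gm(data)                # checkpointed forward pass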