mirror of https://github.com/hpcaitech/ColossalAI
[fx] fix the false interpretation of algorithm 3 in https://arxiv.org/abs/1604.06174. (#1446)
* [fx] modify the calculation of node_size in MetaInfoProp for activation checkpointing usages * [fx] modify the calculation of node_size in MetaInfoProp for activation checkpointing usages * [fx] modify the calculation of node_size in MetaInfoProp for activation checkpointing usages * [fx] activation checkpointing using Chen strategies. * [fx] add test for ckpt_solver_chen * mend * [fx] add vanilla activation checkpoint search with test on resnet and densenet * [fx] add vanilla activation checkpoint search with test on resnet and densenet * [fx] add a namespace code for solver_chen. * [fx] fix the false interpretation of algorithm 3 in https://arxiv.org/abs/1604.06174. * [fx] fix lowercase naming conventions.pull/1439/head
parent
821c6172e2
commit
d40a9392ba
|
@ -1,45 +1,71 @@
|
||||||
|
from typing import Set, Tuple
|
||||||
import torch
|
import torch
|
||||||
from torch.fx import GraphModule
|
from torch.fx import GraphModule
|
||||||
|
import math
|
||||||
|
|
||||||
__all__ = ['chen_greedy', 'chen_sqrtn']
|
__all__ = ['chen_greedy', 'chen_sqrtn']
|
||||||
|
|
||||||
|
|
||||||
def chen_greedy(gm: GraphModule, B: int):
|
def chen_greedy(gm: GraphModule) -> GraphModule:
|
||||||
"""
|
"""
|
||||||
This is the simple implementation of Algorithm 3 in https://arxiv.org/abs/1604.06174.
|
This is the simple implementation of Algorithm 3 in https://arxiv.org/abs/1604.06174.
|
||||||
|
Note that this algorithm targets at memory optimization only, using techniques in appendix A.
|
||||||
|
|
||||||
Usage:
|
Usage:
|
||||||
B = 5 * 1024 * 1024 * 1024 # An approximate memory budget of 5GB
|
|
||||||
model = resnet18()
|
model = resnet18()
|
||||||
input_sample = torch.rand(4, 3, 224, 224)
|
input_sample = torch.rand(4, 3, 224, 224)
|
||||||
gm = symbolic_trace(model)
|
gm = symbolic_trace(model)
|
||||||
MetaInfoProp(gm).run(input_sample)
|
MetaInfoProp(gm).run(input_sample)
|
||||||
gm = chen_greedy(gm, B)
|
gm = chen_greedy(gm)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
gm (GraphModule): The module to add checkpoints
|
gm (GraphModule): The module to add checkpoints
|
||||||
B (int): The approximate memory budget for this module.
|
|
||||||
"""
|
"""
|
||||||
gm.graph.lint() # make sure nodes are in topological order
|
|
||||||
|
def grid_search(num_grids: int = 6) -> Set:
|
||||||
|
"""
|
||||||
|
Search ckpt strategy with b = 0, then run the allocation algorithm again with b = √xy.
|
||||||
|
Grid search over [√2/2 b, √2 b] for ckpt_opt over num_grids as in appendix A.
|
||||||
|
"""
|
||||||
|
_, b_approx = run_chen_greedy(0)
|
||||||
|
b_min, b_max = math.floor(b_approx / math.sqrt(2)), math.ceil(b_approx * math.sqrt(2))
|
||||||
|
b_opt = math.inf
|
||||||
|
for b in range(b_min, b_max, (b_max - b_min) // num_grids):
|
||||||
|
ckpt, b_approx = run_chen_greedy(b)
|
||||||
|
if b_approx < b_opt:
|
||||||
|
b_opt = b_approx
|
||||||
|
ckpt_opt = ckpt
|
||||||
|
return ckpt_opt
|
||||||
|
|
||||||
|
def run_chen_greedy(b: int = 0) -> Tuple[Set, int]:
|
||||||
|
"""
|
||||||
|
This is the simple implementation of Algorithm 3 in https://arxiv.org/abs/1604.06174.
|
||||||
|
"""
|
||||||
|
ckpt = set()
|
||||||
temp = 0
|
temp = 0
|
||||||
x = 0
|
x = 0
|
||||||
idx = 0
|
y = 0
|
||||||
budget = B
|
for (idx, n) in enumerate(gm.graph.nodes):
|
||||||
for n in gm.graph.nodes:
|
|
||||||
B -= getattr(n, 'param_size')
|
|
||||||
assert B > 0, f'The memory budget {budget / 1024 ** 3:.2f} GB is not enough for model parameters of {gm}'
|
|
||||||
for n in gm.graph.nodes:
|
|
||||||
temp += getattr(n, 'activation_size')
|
temp += getattr(n, 'activation_size')
|
||||||
if temp > B:
|
y = max(y, temp)
|
||||||
|
if temp > b:
|
||||||
x += getattr(n, 'activation_size')
|
x += getattr(n, 'activation_size')
|
||||||
temp = x
|
temp = 0
|
||||||
setattr(n, 'activation_checkpoint', str(idx))
|
ckpt.add(idx)
|
||||||
idx += 1
|
return ckpt, math.floor(math.sqrt(x * y))
|
||||||
|
|
||||||
|
gm.graph.lint() # make sure nodes are in topological order
|
||||||
|
ckpt = grid_search(num_grids=6)
|
||||||
|
i = 0
|
||||||
|
for idx, n in enumerate(gm.graph.nodes):
|
||||||
|
if idx in ckpt:
|
||||||
|
setattr(n, 'activation_checkpoint', str(i))
|
||||||
|
i += 1
|
||||||
gm.recompile()
|
gm.recompile()
|
||||||
return gm
|
return gm
|
||||||
|
|
||||||
|
|
||||||
def chen_sqrtn(gm: GraphModule):
|
def chen_sqrtn(gm: GraphModule) -> GraphModule:
|
||||||
"""
|
"""
|
||||||
This is the theoretical optimal strategy in https://arxiv.org/abs/1604.06174.
|
This is the theoretical optimal strategy in https://arxiv.org/abs/1604.06174.
|
||||||
|
|
||||||
|
|
|
@ -1,13 +1,13 @@
|
||||||
|
from ctypes import Union
|
||||||
from colossalai.fx.passes.algorithms import chen_greedy, chen_sqrtn
|
from colossalai.fx.passes.algorithms import chen_greedy, chen_sqrtn
|
||||||
import torch
|
import torch
|
||||||
import torchvision.models as tm
|
import torchvision.models as tm
|
||||||
from colossalai.fx import ColoTracer
|
from colossalai.fx import ColoTracer
|
||||||
from torch.fx import GraphModule
|
from torch.fx import GraphModule
|
||||||
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
|
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
|
||||||
from functools import partial
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
SOLVERS = [partial(chen_greedy, B=1024 * 1024 * 64), chen_sqrtn]
|
SOLVERS = [chen_greedy, chen_sqrtn]
|
||||||
|
|
||||||
|
|
||||||
def _is_activation_checkpoint_available(gm: GraphModule):
|
def _is_activation_checkpoint_available(gm: GraphModule):
|
||||||
|
@ -16,6 +16,13 @@ def _is_activation_checkpoint_available(gm: GraphModule):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def _is_all_gradient_close(m: torch.nn.Module, gm: GraphModule):
|
||||||
|
for m_p, gm_p in zip(m.parameters(), gm.parameters()):
|
||||||
|
if not torch.allclose(m_p, gm_p):
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
def test_ckpt_solver():
|
def test_ckpt_solver():
|
||||||
MODEL_LIST = [tm.resnet18, tm.densenet121]
|
MODEL_LIST = [tm.resnet18, tm.densenet121]
|
||||||
|
|
||||||
|
@ -23,17 +30,24 @@ def test_ckpt_solver():
|
||||||
|
|
||||||
tracer = ColoTracer()
|
tracer = ColoTracer()
|
||||||
data = torch.rand(1, 3, 224, 224)
|
data = torch.rand(1, 3, 224, 224)
|
||||||
|
label = torch.rand(1, 1000)
|
||||||
|
|
||||||
for solver in SOLVERS:
|
for solver in SOLVERS:
|
||||||
for model_cls in MODEL_LIST:
|
for model_cls in MODEL_LIST:
|
||||||
model = model_cls()
|
model = model_cls()
|
||||||
|
criterion = torch.nn.MSELoss()
|
||||||
graph = tracer.trace(root=model)
|
graph = tracer.trace(root=model)
|
||||||
gm = GraphModule(model, graph, model.__class__.__name__)
|
gm = GraphModule(model, graph, model.__class__.__name__)
|
||||||
MetaInfoProp(gm).run(data)
|
MetaInfoProp(gm).run(data)
|
||||||
gm = solver(gm)
|
gm = solver(gm)
|
||||||
assert _is_activation_checkpoint_available(
|
assert _is_activation_checkpoint_available(
|
||||||
gm), f"Solver {solver} did not annotate {model_cls} with any activation checkpoints"
|
gm), f"Solver {solver} did not annotate {model_cls} with any activation checkpoints"
|
||||||
assert torch.allclose(gm(data), model(data))
|
loss = criterion(model(data), label)
|
||||||
|
loss.backward()
|
||||||
|
loss = criterion(gm(data), label)
|
||||||
|
loss.backward()
|
||||||
|
assert _is_all_gradient_close(model,
|
||||||
|
gm), f'Solver {solver} did not work correctly in backward pass on {model_cls}'
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
Loading…
Reference in New Issue