mirror of https://github.com/hpcaitech/ColossalAI
[fx] add rules to linearize computation graphs for searching. (#1461)
* [fx] polish ckpt_test.
* [fx] add rules to linearize computation graphs for searching.
* [fx] remove chen_sqrt for sake of simplicity
* [fx] fix inconsistencies.

Branch: pull/1467/head
parent a7a3d55114
commit e7383f578b
colossalai/fx/passes/algorithms/__init__.py
@@ -1 +1 @@
-from .ckpt_solver_chen import chen_greedy, chen_sqrtn
+from .ckpt_solver_chen import chen_greedy
colossalai/fx/passes/algorithms/ckpt_solver_chen.py
@@ -1,16 +1,33 @@
 from typing import List, Set, Tuple
 import torch
-from torch.fx import GraphModule
+from torch.fx import GraphModule, Node
 import math
 
-__all__ = ['chen_greedy', 'chen_sqrtn']
+__all__ = ['chen_greedy']
+CKPT_OP = ['call_module', 'call_method', 'call_function', 'get_attr']
+
+
+def _all_potential_ckpt_nodes(gm: GraphModule) -> List:
+    """
+    In most existing frameworks of activation checkpoint, the forward graph is assumed to be linearized.
+    """
+
+    def is_sink():
+        """
+        If we can free all memories when executing a certain node, it is a sink.
+        """
+        return not sum((v for k, v in deps.items()))
+
+    deps = {}
+    ckpt_nodes = []
+    for n in gm.graph.nodes:
+        if n.op == 'call_module':
+            for n_par in n._input_nodes:
+                deps[n_par] -= 1    # free memory and dependencies
+
+        # We can only put act_ckpt on these nodes
+        if n.op in CKPT_OP and is_sink():
+            ckpt_nodes.append(n)
+        deps[n] = len(n.users)    # add dependencies for future executions
+    return ckpt_nodes
+
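The sink rule above is what effectively linearizes the graph: a node can serve as a checkpoint boundary only if no other activation is still pending when it executes. Below is a minimal standalone sketch of the idea (ours, not part of the commit; unlike the diff it decrements the pending-use count for every consumer, not only call_module nodes):

import torch.nn as nn
from torch.fx import symbolic_trace

CKPT_OP = ['call_module', 'call_method', 'call_function', 'get_attr']

def potential_ckpt_nodes(gm):
    deps = {}    # node -> uses of its output still pending
    sinks = []
    for n in gm.graph.nodes:
        for n_par in n._input_nodes:
            deps[n_par] -= 1    # one fewer pending use of this input
        if n.op in CKPT_OP and not sum(deps.values()):
            sinks.append(n)    # nothing else is live here: a sink
        deps[n] = len(n.users)    # n itself stays alive for its users
    return sinks

gm = symbolic_trace(nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 4)))
print([n.name for n in potential_ckpt_nodes(gm)])    # every call_module node is a sink

On a branched graph (e.g. y = f(x) + x), x stays pending while f(x) runs, so the nodes inside the branch are never sinks; that is exactly the property the new _is_graph_linearized test below checks from the generated code.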
@@ -71,32 +88,7 @@ def chen_greedy(gm: GraphModule) -> GraphModule:
     for i, seg in enumerate(ckpt):
         for idx in range(*seg):
             n = node_list[idx]
-            if n.op in ['call_module', 'call_method', 'call_function']:
-                setattr(n, 'activation_checkpoint', str(i))
-    gm.recompile()
-    return gm
-
-
-def chen_sqrtn(gm: GraphModule) -> GraphModule:
-    """
-    This is the theoretical optimal strategy in https://arxiv.org/abs/1604.06174.
-
-    Usage:
-        model = resnet18()
-        input_sample = torch.rand(4, 3, 224, 224)
-        gm = symbolic_trace(model)
-        MetaInfoProp(gm).run(input_sample)
-        gm = chen_sqrtn(gm)
-
-    Args:
-        gm (GraphModule): The module to add checkpoints
-    """
-    gm.graph.lint()    # make sure nodes are in topological order
-    k = int(len(gm.graph.nodes)**0.5)    # take approximately sqrt(n) checkpoints
-    for idx, n in enumerate(gm.graph.nodes):
-        # We should not add act_ckpt to the placeholder
-        # The last segment should not be checkpointed
-        if n.op != 'placeholder' and (idx + 1) // k < k:
-            setattr(n, 'activation_checkpoint', str((idx + 1) // k))
+            if n.op in CKPT_OP:
+                setattr(n, 'activation_checkpoint', i)
     gm.recompile()
     return gm
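Note the annotation convention used here: the checkpoint segment id is attached directly to the fx Node as an activation_checkpoint attribute, which a checkpoint-aware codegen (such as the ActivationCheckpointCodeGen exercised by the tests below) can read back when emitting the forward. A small sketch of the convention, with a toy model and a segment size of our own choosing:

import torch.nn as nn
from torch.fx import symbolic_trace

gm = symbolic_trace(nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 4)))
module_nodes = [n for n in gm.graph.nodes if n.op == 'call_module']
for i, n in enumerate(module_nodes):
    setattr(n, 'activation_checkpoint', i // 2)    # two op nodes per segment
print({n.name: getattr(n, 'activation_checkpoint', None) for n in gm.graph.nodes})
# e.g. {'input': None, '_0': 0, '_1': 0, '_2': 1, 'output': None}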
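The removed chen_sqrtn is the O(sqrt(n)) strategy from the same paper (https://arxiv.org/abs/1604.06174): keep roughly sqrt(n) evenly spaced activations and recompute one segment at a time during backward, so peak memory is about k segment boundaries plus one in-flight segment of n/k nodes. A worked illustration of its segmenting arithmetic (the numbers are our own example):

n = 16    # nodes in the linearized graph
k = int(n ** 0.5)    # k = 4, approximately sqrt(n) checkpoints
for idx in range(n):
    seg = (idx + 1) // k
    print(f"node {idx:2d} ->", f"segment {seg}" if seg < k else "not checkpointed")
# nodes 0-2 fall in segment 0, 3-6 in segment 1, ..., node 15 (the tail) is left out,
# matching the 'last segment should not be checkpointed' comment above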
tests/test_fx/test_ckpt_solvers.py
@@ -1,5 +1,6 @@
 from typing import Callable
 import copy
+import re
 import torch
 import torch.multiprocessing as mp
 import torchvision.models as tm
@@ -7,7 +8,7 @@ from torch.fx import GraphModule
 import colossalai
 from colossalai.fx import ColoTracer
 from colossalai.fx.passes.meta_info_prop import MetaInfoProp
-from colossalai.fx.passes.algorithms import chen_greedy, chen_sqrtn
+from colossalai.fx.passes.algorithms import chen_greedy
 from colossalai.utils import free_port
 from colossalai.core import global_context as gpc
 import pytest
@@ -20,7 +21,7 @@ except:
     from colossalai.fx.codegen import python_code_with_activation_checkpoint
     with_codegen = False
 
-SOLVERS = [chen_greedy, chen_sqrtn]
+SOLVERS = [chen_greedy]
 
 
 def _is_activation_checkpoint_available(gm: GraphModule):
@@ -36,6 +37,16 @@ def _is_all_gradient_close(m: torch.nn.Module, gm: GraphModule):
     return True
 
 
+def _is_graph_linearized(gm: GraphModule):
+    code = gm.code
+    # find patterns like r'    return output_1, output_2', which is not expected on a linearized graph
+    pattern = re.compile(r'    return [a-zA-Z0-9_]+(, [a-zA-Z0-9_]+)+')
+    if pattern.findall(code):
+        return False
+    else:
+        return True
+
+
 def check_backward_consistency(m: torch.nn.Module, gm: GraphModule, solver: Callable[[GraphModule], GraphModule],
                                model_cls: Callable[[], torch.nn.Module]):
     criterion = torch.nn.MSELoss()
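The check leans on how fx prints a graph: a linearized graph's generated forward returns a single value, while a graph with live parallel branches returns several. A self-contained demonstration against hand-written stand-ins for gm.code (the code strings are ours):

import re

pattern = re.compile(r'    return [a-zA-Z0-9_]+(, [a-zA-Z0-9_]+)+')

linear_code = """def forward(self, x):
    linear_1 = self.linear(x)
    return linear_1
"""
branched_code = """def forward(self, x):
    add_1 = x + 1
    mul_1 = x * 2
    return add_1, mul_1
"""
print(bool(pattern.findall(linear_code)))      # False: one return value, linearized
print(bool(pattern.findall(branched_code)))    # True: tuple return, not linearized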
@@ -66,12 +77,13 @@ def _run_ckpt_solver(rank):
         codegen = ActivationCheckpointCodeGen()
         gm.graph.set_codegen(codegen)
         gm = solver(gm)
+        assert _is_graph_linearized(gm), f"Solver {solver} did not solve {model_cls} in a linearized manner."
         assert _is_activation_checkpoint_available(
             gm), f"Solver {solver} did not annotate {model_cls} with any activation checkpoints"
         check_backward_consistency(m, gm, solver, model_cls)
     gpc.destroy()
 
 
 @pytest.mark.skip
 @pytest.mark.skipif(not with_codegen, reason='torch version is lower than 1.12.0')
 def test_ckpt_solver():
     mp.spawn(_run_ckpt_solver, nprocs=1)
@@ -94,12 +106,13 @@ def _run_ckpt_solver_torch11(rank):
         MetaInfoProp(gm).run(data)
         gm.graph._python_code = python_code_with_activation_checkpoint.__get__(graph)
         gm = solver(gm)
+        assert _is_graph_linearized(gm), f"Solver {solver} did not solve {model_cls} in a linearized manner."
         assert _is_activation_checkpoint_available(
             gm), f"Solver {solver} did not annotate {model_cls} with any activation checkpoints"
         check_backward_consistency(m, gm, solver, model_cls)
     gpc.destroy()
 
 
 @pytest.mark.skip
 @pytest.mark.skipif(with_codegen, reason='torch version is equal to or higher than 1.12.0')
 def test_ckpt_solver_torch11():
     mp.spawn(_run_ckpt_solver_torch11, nprocs=1)