[fx] add rules to linearize computation graphs for searching. (#1461)

* [fx] polish ckpt_test.

* [fx] add rules to linearize computation graphs for searching.

* [fx] remove chen_sqrtn for the sake of simplicity

* [fx] fix inconsistencies.
Super Daniel 2022-08-17 14:47:12 +08:00 committed by GitHub
parent a7a3d55114
commit e7383f578b
3 changed files with 40 additions and 35 deletions

View File

@@ -1 +1 @@
-from .ckpt_solver_chen import chen_greedy, chen_sqrtn
+from .ckpt_solver_chen import chen_greedy

View File

@@ -1,16 +1,33 @@
 from typing import List, Set, Tuple
 import torch
-from torch.fx import GraphModule
+from torch.fx import GraphModule, Node
 import math
 
-__all__ = ['chen_greedy', 'chen_sqrtn']
+__all__ = ['chen_greedy']
+
+CKPT_OP = ['call_module', 'call_method', 'call_function', 'get_attr']
 
 
 def _all_potential_ckpt_nodes(gm: GraphModule) -> List:
+    """
+    In most existing frameworks of activation checkpoint, the forward graph is assumed to be linearized.
+    """
+
+    def is_sink():
+        """
+        If we can free all memories when executing a certain node, it is a sink.
+        """
+        return not sum((v for k, v in deps.items()))
+
+    deps = {}
     ckpt_nodes = []
     for n in gm.graph.nodes:
-        if n.op == 'call_module':
+        for n_par in n._input_nodes:
+            deps[n_par] -= 1    # free memory and dependencies
+
+        # We can only put act_ckpt on these nodes
+        if n.op in CKPT_OP and is_sink():
             ckpt_nodes.append(n)
+        deps[n] = len(n.users)    # add dependencies for future executions
     return ckpt_nodes
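
Note: the new rule only treats a node as a checkpoint candidate when it is a "sink", i.e. every activation produced earlier has already been consumed by the time the node runs. Below is a minimal sketch of that bookkeeping on a plain torch.fx trace; `SmallNet` and the final print are illustrative only, and the snippet reuses the same private `Node._input_nodes` attribute the patch relies on.

```python
import torch
import torch.fx
import torch.nn as nn

CKPT_OP = ['call_module', 'call_method', 'call_function', 'get_attr']


class SmallNet(nn.Module):
    # hypothetical toy model: the residual branch keeps `x` alive across several nodes
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 3, 3, padding=1)
        self.relu = nn.ReLU()

    def forward(self, x):
        y = self.conv(x)
        return self.relu(y) + x    # `x` is needed again here, so `conv`/`relu` are not sinks


gm = torch.fx.symbolic_trace(SmallNet())

ckpt_nodes = []
deps = {}
for n in gm.graph.nodes:    # graph nodes are in topological order
    for n_par in n._input_nodes:
        deps[n_par] -= 1                             # an input was just consumed: one fewer pending user
    is_sink = not sum(v for v in deps.values())      # nothing upstream is still waiting to be used
    if n.op in CKPT_OP and is_sink:
        ckpt_nodes.append(n)
    deps[n] = len(n.users)                           # this node must stay alive for its own users

print([n.name for n in ckpt_nodes])    # only the residual `add` qualifies; `conv`/`relu` overlap with the live `x`
```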
@@ -71,32 +88,7 @@ def chen_greedy(gm: GraphModule) -> GraphModule:
     for i, seg in enumerate(ckpt):
         for idx in range(*seg):
             n = node_list[idx]
-            if n.op in ['call_module', 'call_method', 'call_function']:
-                setattr(n, 'activation_checkpoint', str(i))
+            if n.op in CKPT_OP:
+                setattr(n, 'activation_checkpoint', i)
-    gm.recompile()
-    return gm
-
-
-def chen_sqrtn(gm: GraphModule) -> GraphModule:
-    """
-    This is the theoretical optimal strategy in https://arxiv.org/abs/1604.06174.
-
-    Usage:
-        model = resnet18()
-        input_sample = torch.rand(4, 3, 224, 224)
-        gm = symbolic_trace(model)
-        MetaInfoProp(gm).run(input_sample)
-        gm = chen_sqrtn(gm)
-
-    Args:
-        gm (GraphModule): The module to add checkpoints
-    """
-    gm.graph.lint()    # make sure nodes are in topological order
-    k = int(len(gm.graph.nodes)**0.5)    # take approximately sqrt(n) checkpoints
-    for idx, n in enumerate(gm.graph.nodes):
-        # We should not add act_ckpt to the placeholder
-        # The last segment should not be checkpointed
-        if n.op != 'placeholder' and (idx + 1) // k < k:
-            setattr(n, 'activation_checkpoint', str((idx + 1) // k))
     gm.recompile()
     return gm
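
For reference, the removed `chen_sqrtn` docstring documented the intended solver workflow; the same steps apply to the remaining `chen_greedy` solver. A sketch of that usage, assuming the ColossalAI imports already used by the test file below:

```python
import torch
import torchvision.models as tm
from torch.fx import symbolic_trace

from colossalai.fx.passes.meta_info_prop import MetaInfoProp
from colossalai.fx.passes.algorithms import chen_greedy

model = tm.resnet18()
input_sample = torch.rand(4, 3, 224, 224)

gm = symbolic_trace(model)            # build the fx computation graph
MetaInfoProp(gm).run(input_sample)    # annotate nodes with the shape/memory metadata the solver reads
gm = chen_greedy(gm)                  # segment the graph and tag nodes with `activation_checkpoint` ids
```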

View File

@@ -1,5 +1,6 @@
 from typing import Callable
 import copy
+import re
 import torch
 import torch.multiprocessing as mp
 import torchvision.models as tm
@@ -7,7 +8,7 @@ from torch.fx import GraphModule
 import colossalai
 from colossalai.fx import ColoTracer
 from colossalai.fx.passes.meta_info_prop import MetaInfoProp
-from colossalai.fx.passes.algorithms import chen_greedy, chen_sqrtn
+from colossalai.fx.passes.algorithms import chen_greedy
 from colossalai.utils import free_port
 from colossalai.core import global_context as gpc
 import pytest
@@ -20,7 +21,7 @@ except:
     from colossalai.fx.codegen import python_code_with_activation_checkpoint
     with_codegen = False
 
-SOLVERS = [chen_greedy, chen_sqrtn]
+SOLVERS = [chen_greedy]
 
 
 def _is_activation_checkpoint_available(gm: GraphModule):
@@ -36,6 +37,16 @@ def _is_all_gradient_close(m: torch.nn.Module, gm: GraphModule):
     return True
 
 
+def _is_graph_linearized(gm: GraphModule):
+    code = gm.code
+    # find patterns like r'    return output_1, output_2', which is not expected on a linearized graph
+    pattern = re.compile(r'    return [a-zA-Z0-9_]+(, [a-zA-Z0-9_]+)+')
+    if pattern.findall(code):
+        return False
+    else:
+        return True
+
+
 def check_backward_consistency(m: torch.nn.Module, gm: GraphModule, solver: Callable[[GraphModule], GraphModule],
                                model_cls: Callable[[], torch.nn.Module]):
     criterion = torch.nn.MSELoss()
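
The linearization check simply scans the code generated for `GraphModule.forward` for multi-value returns, which only appear when a segment has to hand several live tensors to the next one. A self-contained illustration of the pattern (the two code strings are made up for the example):

```python
import re

pattern = re.compile(r'    return [a-zA-Z0-9_]+(, [a-zA-Z0-9_]+)+')

linear_code = """
def forward(self, x):
    conv_1 = self.conv(x)
    return conv_1
"""
branched_code = """
def forward(self, x):
    conv_1 = self.conv(x)
    return conv_1, x
"""

print(bool(pattern.findall(linear_code)))      # False: single-value return, graph looks linearized
print(bool(pattern.findall(branched_code)))    # True: `return a, b` means live values cross the segment boundary
```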
@@ -66,12 +77,13 @@ def _run_ckpt_solver(rank):
             codegen = ActivationCheckpointCodeGen()
             gm.graph.set_codegen(codegen)
             gm = solver(gm)
+            assert _is_graph_linearized(gm), f"Solver {solver} did not solve {model_cls} in a linearized manner."
             assert _is_activation_checkpoint_available(
                 gm), f"Solver {solver} did not annotate {model_cls} with any activation checkpoints"
             check_backward_consistency(m, gm, solver, model_cls)
+    gpc.destroy()
 
 
-@pytest.mark.skip
 @pytest.mark.skipif(not with_codegen, reason='torch version is lower than 1.12.0')
 def test_ckpt_solver():
     mp.spawn(_run_ckpt_solver, nprocs=1)
@@ -94,12 +106,13 @@ def _run_ckpt_solver_torch11(rank):
             MetaInfoProp(gm).run(data)
             gm.graph._python_code = python_code_with_activation_checkpoint.__get__(graph)
             gm = solver(gm)
+            assert _is_graph_linearized(gm), f"Solver {solver} did not solve {model_cls} in a linearized manner."
             assert _is_activation_checkpoint_available(
                 gm), f"Solver {solver} did not annotate {model_cls} with any activation checkpoints"
             check_backward_consistency(m, gm, solver, model_cls)
+    gpc.destroy()
 
 
-@pytest.mark.skip
 @pytest.mark.skipif(with_codegen, reason='torch version is equal to or higher than 1.12.0')
 def test_ckpt_solver_torch11():
     mp.spawn(_run_ckpt_solver_torch11, nprocs=1)