mirror of https://github.com/hpcaitech/ColossalAI
[fx] add rules to linearize computation graphs for searching. (#1461)
* [fx] polish ckpt_test. * [fx] add rules to linearize computation graphs for searching. * [fx] remove chen_sqrt for sake of simplicity * [fx] fix inconsistencies.pull/1467/head
parent
a7a3d55114
commit
e7383f578b
|
@ -1 +1 @@
|
||||||
from .ckpt_solver_chen import chen_greedy, chen_sqrtn
|
from .ckpt_solver_chen import chen_greedy
|
||||||
|
|
|
@ -1,16 +1,33 @@
|
||||||
from typing import List, Set, Tuple
|
from typing import List, Set, Tuple
|
||||||
import torch
|
import torch
|
||||||
from torch.fx import GraphModule
|
from torch.fx import GraphModule, Node
|
||||||
import math
|
import math
|
||||||
|
|
||||||
__all__ = ['chen_greedy', 'chen_sqrtn']
|
__all__ = ['chen_greedy']
|
||||||
|
CKPT_OP = ['call_module', 'call_method', 'call_function', 'get_attr']
|
||||||
|
|
||||||
|
|
||||||
def _all_potential_ckpt_nodes(gm: GraphModule) -> List:
|
def _all_potential_ckpt_nodes(gm: GraphModule) -> List:
|
||||||
|
"""
|
||||||
|
In most existing frameworks of activation checkpoint, the forward graph is assumed to be linearized.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def is_sink():
|
||||||
|
"""
|
||||||
|
If we can free all memories when executing a certain node, it is a sink.
|
||||||
|
"""
|
||||||
|
return not sum((v for k, v in deps.items()))
|
||||||
|
|
||||||
|
deps = {}
|
||||||
ckpt_nodes = []
|
ckpt_nodes = []
|
||||||
for n in gm.graph.nodes:
|
for n in gm.graph.nodes:
|
||||||
if n.op == 'call_module':
|
for n_par in n._input_nodes:
|
||||||
|
deps[n_par] -= 1 # free memory and dependencies
|
||||||
|
|
||||||
|
# We can only put act_ckpt on these nodes
|
||||||
|
if n.op in CKPT_OP and is_sink():
|
||||||
ckpt_nodes.append(n)
|
ckpt_nodes.append(n)
|
||||||
|
deps[n] = len(n.users) # add dependencies for future executions
|
||||||
return ckpt_nodes
|
return ckpt_nodes
|
||||||
|
|
||||||
|
|
||||||
|
@ -71,32 +88,7 @@ def chen_greedy(gm: GraphModule) -> GraphModule:
|
||||||
for i, seg in enumerate(ckpt):
|
for i, seg in enumerate(ckpt):
|
||||||
for idx in range(*seg):
|
for idx in range(*seg):
|
||||||
n = node_list[idx]
|
n = node_list[idx]
|
||||||
if n.op in ['call_module', 'call_method', 'call_function']:
|
if n.op in CKPT_OP:
|
||||||
setattr(n, 'activation_checkpoint', str(i))
|
setattr(n, 'activation_checkpoint', i)
|
||||||
gm.recompile()
|
|
||||||
return gm
|
|
||||||
|
|
||||||
|
|
||||||
def chen_sqrtn(gm: GraphModule) -> GraphModule:
|
|
||||||
"""
|
|
||||||
This is the theoretical optimal strategy in https://arxiv.org/abs/1604.06174.
|
|
||||||
|
|
||||||
Usage:
|
|
||||||
model = resnet18()
|
|
||||||
input_sample = torch.rand(4, 3, 224, 224)
|
|
||||||
gm = symbolic_trace(model)
|
|
||||||
MetaInfoProp(gm).run(input_sample)
|
|
||||||
gm = chen_sqrtn(gm)
|
|
||||||
|
|
||||||
Args:
|
|
||||||
gm (GraphModule): The module to add checkpoints
|
|
||||||
"""
|
|
||||||
gm.graph.lint() # make sure nodes are in topological order
|
|
||||||
k = int(len(gm.graph.nodes)**0.5) # take approximately sqrt(n) checkpoints
|
|
||||||
for idx, n in enumerate(gm.graph.nodes):
|
|
||||||
# We should not add act_ckpt to the placeholder
|
|
||||||
# The last segment should not be checkpointed
|
|
||||||
if n.op != 'placeholder' and (idx + 1) // k < k:
|
|
||||||
setattr(n, 'activation_checkpoint', str((idx + 1) // k))
|
|
||||||
gm.recompile()
|
gm.recompile()
|
||||||
return gm
|
return gm
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
from typing import Callable
|
from typing import Callable
|
||||||
import copy
|
import copy
|
||||||
|
import re
|
||||||
import torch
|
import torch
|
||||||
import torch.multiprocessing as mp
|
import torch.multiprocessing as mp
|
||||||
import torchvision.models as tm
|
import torchvision.models as tm
|
||||||
|
@ -7,7 +8,7 @@ from torch.fx import GraphModule
|
||||||
import colossalai
|
import colossalai
|
||||||
from colossalai.fx import ColoTracer
|
from colossalai.fx import ColoTracer
|
||||||
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
|
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
|
||||||
from colossalai.fx.passes.algorithms import chen_greedy, chen_sqrtn
|
from colossalai.fx.passes.algorithms import chen_greedy
|
||||||
from colossalai.utils import free_port
|
from colossalai.utils import free_port
|
||||||
from colossalai.core import global_context as gpc
|
from colossalai.core import global_context as gpc
|
||||||
import pytest
|
import pytest
|
||||||
|
@ -20,7 +21,7 @@ except:
|
||||||
from colossalai.fx.codegen import python_code_with_activation_checkpoint
|
from colossalai.fx.codegen import python_code_with_activation_checkpoint
|
||||||
with_codegen = False
|
with_codegen = False
|
||||||
|
|
||||||
SOLVERS = [chen_greedy, chen_sqrtn]
|
SOLVERS = [chen_greedy]
|
||||||
|
|
||||||
|
|
||||||
def _is_activation_checkpoint_available(gm: GraphModule):
|
def _is_activation_checkpoint_available(gm: GraphModule):
|
||||||
|
@ -36,6 +37,16 @@ def _is_all_gradient_close(m: torch.nn.Module, gm: GraphModule):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def _is_graph_linearized(gm: GraphModule):
|
||||||
|
code = gm.code
|
||||||
|
# find patterns like r' return output_1, output_2', which is not expected on a linearized graph
|
||||||
|
pattern = re.compile(r' return [a-zA-Z0-9_]+(, [a-zA-Z0-9_]+)+')
|
||||||
|
if pattern.findall(code):
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
def check_backward_consistency(m: torch.nn.Module, gm: GraphModule, solver: Callable[[GraphModule], GraphModule],
|
def check_backward_consistency(m: torch.nn.Module, gm: GraphModule, solver: Callable[[GraphModule], GraphModule],
|
||||||
model_cls: Callable[[], torch.nn.Module]):
|
model_cls: Callable[[], torch.nn.Module]):
|
||||||
criterion = torch.nn.MSELoss()
|
criterion = torch.nn.MSELoss()
|
||||||
|
@ -66,12 +77,13 @@ def _run_ckpt_solver(rank):
|
||||||
codegen = ActivationCheckpointCodeGen()
|
codegen = ActivationCheckpointCodeGen()
|
||||||
gm.graph.set_codegen(codegen)
|
gm.graph.set_codegen(codegen)
|
||||||
gm = solver(gm)
|
gm = solver(gm)
|
||||||
|
assert _is_graph_linearized(gm), f"Solver {solver} did not solve {model_cls} in a linearized manner."
|
||||||
assert _is_activation_checkpoint_available(
|
assert _is_activation_checkpoint_available(
|
||||||
gm), f"Solver {solver} did not annotate {model_cls} with any activation checkpoints"
|
gm), f"Solver {solver} did not annotate {model_cls} with any activation checkpoints"
|
||||||
check_backward_consistency(m, gm, solver, model_cls)
|
check_backward_consistency(m, gm, solver, model_cls)
|
||||||
|
gpc.destroy()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip
|
|
||||||
@pytest.mark.skipif(not with_codegen, reason='torch version is lower than 1.12.0')
|
@pytest.mark.skipif(not with_codegen, reason='torch version is lower than 1.12.0')
|
||||||
def test_ckpt_solver():
|
def test_ckpt_solver():
|
||||||
mp.spawn(_run_ckpt_solver, nprocs=1)
|
mp.spawn(_run_ckpt_solver, nprocs=1)
|
||||||
|
@ -94,12 +106,13 @@ def _run_ckpt_solver_torch11(rank):
|
||||||
MetaInfoProp(gm).run(data)
|
MetaInfoProp(gm).run(data)
|
||||||
gm.graph._python_code = python_code_with_activation_checkpoint.__get__(graph)
|
gm.graph._python_code = python_code_with_activation_checkpoint.__get__(graph)
|
||||||
gm = solver(gm)
|
gm = solver(gm)
|
||||||
|
assert _is_graph_linearized(gm), f"Solver {solver} did not solve {model_cls} in a linearized manner."
|
||||||
assert _is_activation_checkpoint_available(
|
assert _is_activation_checkpoint_available(
|
||||||
gm), f"Solver {solver} did not annotate {model_cls} with any activation checkpoints"
|
gm), f"Solver {solver} did not annotate {model_cls} with any activation checkpoints"
|
||||||
check_backward_consistency(m, gm, solver, model_cls)
|
check_backward_consistency(m, gm, solver, model_cls)
|
||||||
|
gpc.destroy()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip
|
|
||||||
@pytest.mark.skipif(with_codegen, reason='torch version is equal to or higher than 1.12.0')
|
@pytest.mark.skipif(with_codegen, reason='torch version is equal to or higher than 1.12.0')
|
||||||
def test_ckpt_solver_torch11():
|
def test_ckpt_solver_torch11():
|
||||||
mp.spawn(_run_ckpt_solver_torch11, nprocs=1)
|
mp.spawn(_run_ckpt_solver_torch11, nprocs=1)
|
||||||
|
|
Loading…
Reference in New Issue