mirror of https://github.com/hpcaitech/ColossalAI
[fx] Add use_reentrant=False to checkpoint in codegen (#1463)
* [utils] Add use_reentrant=False into colossalai checkpoint
* [utils] add some annotation in utils.activation_checkpoint
* [test] add reset_seed at the beginning of tests in test_activation_checkpointing.py
* [test] modify test_activation_checkpoint.py
* [test] modify test for reentrant=False
* [fx] Add use_reentrant=False to checkpoint in codegen
parent 47fd8e4a02
commit 092b9c8f49
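For context on the flag this change threads through the codegen: torch.utils.checkpoint.checkpoint can run either a reentrant or a non-reentrant implementation. A minimal sketch, assuming a PyTorch version that already exposes the use_reentrant keyword; the module and tensor below are illustrative and not taken from the diff.

import torch
from torch.utils.checkpoint import checkpoint

# Illustrative module: the inplace ReLU is the kind of op the codegen change
# below looks for when deciding to emit use_reentrant=False.
layer = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU(inplace=True))
x = torch.randn(2, 4, requires_grad=True)

# Reentrant variant: re-runs the forward inside a custom autograd Function.
y_reentrant = checkpoint(layer, x, use_reentrant=True)

# Non-reentrant variant: implemented with saved-tensor hooks; this is what the
# generated code switches to when an in-place op is found in the region.
y_non_reentrant = checkpoint(layer, x, use_reentrant=False)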
@@ -99,13 +99,13 @@ def _gen_ckpt_output(output_vars: List[str]) -> str:
     return f"return {', '.join(output_vars)}"


-def _gen_ckpt_usage(label, activation_offload, input_vars, output_vars):
+def _gen_ckpt_usage(label, activation_offload, input_vars, output_vars, use_reentrant=True):
     """
     Generate the checkpoint function call code text
     """
     outputs = ', '.join(output_vars)
     inputs = ', '.join(input_vars)
-    return f'{outputs} = colossalai.utils.activation_checkpoint.checkpoint(checkpoint_{label}, {activation_offload}, {inputs})'
+    return f'{outputs} = colossalai.utils.activation_checkpoint.checkpoint(checkpoint_{label}, {activation_offload}, {inputs}, use_reentrant={use_reentrant})'


 def emit_code_with_activation_checkpoint(body, nodes, emit_node_func, delete_unused_value_func):
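As an illustration of the new helper signature, a hypothetical call with made-up label and variable names would now render the extra keyword:

# Hypothetical example of what the modified helper returns:
_gen_ckpt_usage('0', False, ['x'], ['out'], use_reentrant=False)
# -> "out = colossalai.utils.activation_checkpoint.checkpoint(checkpoint_0, False, x, use_reentrant=False)"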
@@ -162,8 +162,24 @@ def emit_code_with_activation_checkpoint(body, nodes, emit_node_func, delete_unused_value_func):
             else:
                 activation_offload = False

+            # we need to check if the checkpoint need use_reentrant=False
+            use_reentrant = True
+            for var in input_vars[label]:
+                input_node = [item for item in node_list if item.name == var]
+                input_node = input_node[0]
+                for user in input_node.users:
+                    if hasattr(user, "activation_checkpoint"):
+                        if user.activation_checkpoint == label:
+                            if user.op == "call_module":
+                                if hasattr(user.graph.owning_module.get_submodule(user.target), "inplace"):
+                                    use_reentrant = not user.graph.owning_module.get_submodule(user.target).inplace
+
+                            elif user.op == "call_function":
+                                if "inplace" in user.kwargs:
+                                    use_reentrant = not user.kwargs["inplace"]
+
             # generate checkpoint function call in a new line
-            usage = _gen_ckpt_usage(label, activation_offload, input_vars[label], output_vars[label])
+            usage = _gen_ckpt_usage(label, activation_offload, input_vars[label], output_vars[label], use_reentrant)
             usage += '\n'
             body.append(usage)
             within_ckpt_region = False
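The call_module and call_function branches above cover the two ways an in-place op shows up in a torch.fx graph. A minimal standalone sketch of that distinction, using torch.fx.symbolic_trace purely for illustration; the class and names are made up:

import torch
import torch.nn.functional as F
from torch.fx import symbolic_trace

class Tiny(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # call_module case: the submodule object carries an `inplace` attribute
        self.act = torch.nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.act(x)
        # call_function case: `inplace` appears in the node's kwargs
        return F.relu(x, inplace=True)

gm = symbolic_trace(Tiny())
for node in gm.graph.nodes:
    if node.op == "call_module":
        print(node.name, getattr(gm.get_submodule(node.target), "inplace", None))
    elif node.op == "call_function":
        print(node.name, node.kwargs.get("inplace", None))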
@@ -1,5 +1,6 @@
 from operator import mod
 import torch
+import torch.nn.functional as F
 import pytest
 import torch.multiprocessing as mp
 from torch.utils.checkpoint import checkpoint
@@ -26,7 +27,17 @@ class MLP(torch.nn.Module):
         self.linear2 = torch.nn.Linear(4, 4)

     def forward(self, x):
-        return self.linear1(x), self.linear1(x)
+        return self.linear1(x), self.linear2(x)


+class relu(torch.nn.Module):
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.relu = torch.nn.ReLU(inplace=True)
+
+    def forward(self, x):
+        return self.relu(x)
+
+
 class MyModule(torch.nn.Module):
@@ -34,12 +45,17 @@ class MyModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
         self.mlp1 = MLP()
-        self.mlp2 = MLP()
+        self.relu = relu()
         self.linear3 = torch.nn.Linear(4, 4)

     def forward(self, x):
         y1, y2 = checkpoint(self.mlp1, x)
-        y3, y4 = checkpoint(self.mlp2, x)
+        y3 = checkpoint(self.relu, x)
+
+        def ckpt2(x):
+            return F.relu(x, inplace=True)
+
+        y4 = checkpoint(ckpt2, x)
         return y1 + y2 + y3 + y4

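Reading this forward pass against the assertions in the test hunks below, a plausible mapping of the three checkpoint regions (inferred, not stated explicitly in the diff) is:

# checkpoint_0 -> self.mlp1      : no in-place op        -> use_reentrant=True, offloaded via 'mlp1_linear1'
# checkpoint_1 -> self.relu      : nn.ReLU(inplace=True) -> use_reentrant=False
# checkpoint_2 -> ckpt2 (F.relu) : inplace=True kwarg    -> use_reentrant=False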
@@ -65,8 +81,8 @@ def _run_act_ckpt_codegen(rank):

     # check ops are annotated with ckpt
     # also annotate the selected node for offloading
-    ckpt_nodes = ['mlp1_linear1', 'mlp1_linear1_1', 'mlp2_linear1', 'mlp2_linear1_1']
-    offload_starts = ['mlp2_linear1']
+    ckpt_nodes = ['mlp1_linear1', 'mlp1_linear2', 'relu_relu', 'relu']
+    offload_starts = ['mlp1_linear1']
     for node in graph.nodes:
         if node.name in ckpt_nodes:
             assert hasattr(node, 'activation_checkpoint')
@@ -75,15 +91,17 @@ def _run_act_ckpt_codegen(rank):
         if node.name in offload_starts:
             setattr(node, 'activation_offload', True)

+    gm = GraphModule(model, graph)
+    gm.recompile()
+
     # assert checkpoint function will be generated and
     # the offload option is correct
     code = graph.python_code('self').src
-    assert 'colossalai.utils.activation_checkpoint.checkpoint(checkpoint_0, False, x)' in code and \
-        'colossalai.utils.activation_checkpoint.checkpoint(checkpoint_1, True, x)' in code
+    assert 'colossalai.utils.activation_checkpoint.checkpoint(checkpoint_0, True, x, use_reentrant=True)' in code and \
+        'colossalai.utils.activation_checkpoint.checkpoint(checkpoint_1, False, x, use_reentrant=False)' in code and \
+        'colossalai.utils.activation_checkpoint.checkpoint(checkpoint_2, False, x, use_reentrant=False)' in code

     # recompile and verify the outputs are consistent
-    gm = GraphModule(model, graph)
-    gm.recompile()
     fx_out = gm(data)
     assert torch.equal(non_fx_out, fx_out)

@@ -117,8 +135,8 @@ def _run_act_ckpt_python_code_torch11(rank):
     graph._python_code = python_code_with_activation_checkpoint.__get__(graph)

     # check ops are annotated with ckpt
-    ckpt_nodes = ['mlp1_linear1', 'mlp1_linear1_1', 'mlp2_linear1', 'mlp2_linear1_1']
-    offload_starts = ['mlp2_linear1']
+    ckpt_nodes = ['mlp1_linear1', 'mlp1_linear2', 'relu_relu', 'relu']
+    offload_starts = ['mlp1_linear1']
     for node in graph.nodes:
         if node.name in ckpt_nodes:
             assert hasattr(node, 'activation_checkpoint')
@@ -127,15 +145,16 @@ def _run_act_ckpt_python_code_torch11(rank):
         if node.name in offload_starts:
             setattr(node, 'activation_offload', True)

+    gm = GraphModule(model, graph)
+    gm.recompile()
     # assert checkpoint function will be generated and
     # the offload option is correct
     code = graph.python_code('self').src
-    assert 'colossalai.utils.activation_checkpoint.checkpoint(checkpoint_0, False, x)' in code and \
-        'colossalai.utils.activation_checkpoint.checkpoint(checkpoint_1, True, x)' in code
+    assert 'colossalai.utils.activation_checkpoint.checkpoint(checkpoint_0, True, x, use_reentrant=True)' in code and \
+        'colossalai.utils.activation_checkpoint.checkpoint(checkpoint_1, False, x, use_reentrant=False)' in code and \
+        'colossalai.utils.activation_checkpoint.checkpoint(checkpoint_2, False, x, use_reentrant=False)' in code

     # recompile and verify the outputs are consistent
-    gm = GraphModule(model, graph)
-    gm.recompile()
     fx_out = gm(data)
     assert torch.equal(non_fx_out, fx_out)
