[fx] Use colossalai checkpoint and add offload recognition in codegen (#1439)

* [fx] Use colossalai.utils.checkpoint to replace torch.utils.checkpoint for offload activation and add offload annotation recognition in codegen

* [fx] Use colossalai.utils.checkpoint to replace torch.utils.checkpoint for offload activation and add offload annotation recognition in codegen

* Modification of test and add TODO in codegen

* [fx] Modification of colossal ckpt usage

* [fx] add gpc.destroy() to test_codegen
pull/1450/head
Boyuan Yao 2022-08-12 12:23:30 +08:00 committed by GitHub
parent e9460b45c8
commit 5774fe0270
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 63 additions and 7 deletions

View File

@ -99,13 +99,13 @@ def _gen_ckpt_output(output_vars: List[str]) -> str:
return f"return {', '.join(output_vars)}"
def _gen_ckpt_usage(label, input_vars, output_vars):
def _gen_ckpt_usage(label, activation_offload, input_vars, output_vars):
"""
Generate the checkpoint function call code text
"""
outputs = ', '.join(output_vars)
inputs = ', '.join(input_vars)
return f'{outputs} = torch.utils.checkpoint.checkpoint(checkpoint_{label}, {inputs})'
return f'{outputs} = colossalai.utils.activation_checkpoint.checkpoint(checkpoint_{label}, {activation_offload}, {inputs})'
def emit_code_with_activation_checkpoint(body, nodes, emit_node_func, delete_unused_value_func):
@ -155,8 +155,15 @@ def emit_code_with_activation_checkpoint(body, nodes, emit_node_func, delete_unu
return_statement = f' {return_statement}\n'
body.append(return_statement)
# we need to check if the checkpoint need to offload the input
start_node_idx = start_idx[label]
if hasattr(node_list[start_node_idx], 'activation_offload'):
activation_offload = node_list[start_node_idx].activation_offload
else:
activation_offload = False
# generate checkpoint function call in a new line
usage = _gen_ckpt_usage(label, input_vars[label], output_vars[label])
usage = _gen_ckpt_usage(label, activation_offload, input_vars[label], output_vars[label])
usage += '\n'
body.append(usage)
within_ckpt_region = False
@ -368,7 +375,11 @@ if codegen_available:
for name, value in self.additional_globals():
add_global(name, value)
# as we need colossalai.utils.checkpoint, we need to import colossalai
# in forward function
# TODO: Remove inline import
prologue = self.gen_fn_def(free_vars, maybe_return_annotation[0])
prologue = prologue + "\n import colossalai"
code = ''.join(body)
code = '\n'.join(' ' + line for line in code.split('\n'))
@ -566,9 +577,14 @@ else:
orig_args.insert(0, 'self')
code = ''.join(body)
code = '\n'.join(' ' + line for line in code.split('\n'))
# as we need colossalai.utils.checkpoint, we need to import colossalai
# in forward function
# TODO: Remove inline import
fn_code = f"""
{wrap_stmts}
def forward({', '.join(orig_args)}){maybe_return_annotation[0]}:
import colossalai
{code}"""
return PythonCode(fn_code, globals_)

View File

@ -1,8 +1,12 @@
from operator import mod
import torch
import pytest
from torch.utils.checkpoint import checkpoint
from torch.fx import GraphModule
from colossalai.fx import ColoTracer
import colossalai
from colossalai.utils import free_port
from colossalai.core import global_context as gpc
try:
from colossalai.fx.codegen import ActivationCheckpointCodeGen
@ -40,9 +44,17 @@ class MyModule(torch.nn.Module):
@pytest.mark.skipif(not with_codegen, reason='torch version is lower than 1.12.0')
def test_act_ckpt_codegen():
# launch colossalai to make sure we could execute colossalai.utils.checkpoint currectly
colossalai.launch(config={}, rank=0, world_size=1, host='localhost', port=free_port(), backend='nccl')
# build model and run forward
model = MyModule()
data = torch.rand(4, 4)
# copy model to cuda
model = model.to(device="cuda")
data = data.to(device="cuda")
non_fx_out = model(data)
# trace the module and replace codegen
@ -52,14 +64,22 @@ def test_act_ckpt_codegen():
graph.set_codegen(codegen)
# check ops are annotated with ckpt
# also annotate the selected node for offloading
ckpt_nodes = ['mlp1_linear1', 'mlp1_linear1_1', 'mlp2_linear1', 'mlp2_linear1_1']
offload_starts = ['mlp2_linear1']
for node in graph.nodes:
if node.name in ckpt_nodes:
assert hasattr(node, 'activation_checkpoint')
# assert checkpoint function will be generated
# annotate the selected node for offload
if node.name in offload_starts:
setattr(node, 'activation_offload', True)
# assert checkpoint function will be generated and
# the offload option is correct
code = graph.python_code('self').src
assert 'checkpoint_0' in code and 'checkpoint_1' in code
assert 'colossalai.utils.activation_checkpoint.checkpoint(checkpoint_0, False, x)' in code and \
'colossalai.utils.activation_checkpoint.checkpoint(checkpoint_1, True, x)' in code
# recompile and verify the outputs are consistent
gm = GraphModule(model, graph)
@ -67,12 +87,22 @@ def test_act_ckpt_codegen():
fx_out = gm(data)
assert torch.equal(non_fx_out, fx_out)
gpc.destroy()
@pytest.mark.skipif(with_codegen, reason='torch version is equal to or higher than 1.12.0')
def test_act_ckpt_python_code_torch11():
# launch colossalai to make sure we could execute colossalai.utils.checkpoint currectly
colossalai.launch(config={}, rank=0, world_size=1, host='localhost', port=free_port(), backend='nccl')
# build model and run forward
model = MyModule()
data = torch.rand(4, 4)
# copy model to cuda
model = model.to(device="cuda")
data = data.to(device="cuda")
non_fx_out = model(data)
# trace the module and replace codegen
@ -84,13 +114,20 @@ def test_act_ckpt_python_code_torch11():
# check ops are annotated with ckpt
ckpt_nodes = ['mlp1_linear1', 'mlp1_linear1_1', 'mlp2_linear1', 'mlp2_linear1_1']
offload_starts = ['mlp2_linear1']
for node in graph.nodes:
if node.name in ckpt_nodes:
assert hasattr(node, 'activation_checkpoint')
# assert checkpoint function will be generated
# annotate the selected node for offload
if node.name in offload_starts:
setattr(node, 'activation_offload', True)
# assert checkpoint function will be generated and
# the offload option is correct
code = graph.python_code('self').src
assert 'checkpoint_0' in code and 'checkpoint_1' in code
assert 'colossalai.utils.activation_checkpoint.checkpoint(checkpoint_0, False, x)' in code and \
'colossalai.utils.activation_checkpoint.checkpoint(checkpoint_1, True, x)' in code
# recompile and verify the outputs are consistent
gm = GraphModule(model, graph)
@ -98,7 +135,10 @@ def test_act_ckpt_python_code_torch11():
fx_out = gm(data)
assert torch.equal(non_fx_out, fx_out)
gpc.destroy()
if __name__ == '__main__':
test_act_ckpt_codegen()
test_act_ckpt_python_code_torch11()