[fx] Use colossalai checkpoint and add offload recognition in codegen (#1439)

* [fx] Use colossalai.utils.checkpoint to replace torch.utils.checkpoint for offload activation and add offload annotation recognition in codegen

* Modify the test and add a TODO in codegen

* [fx] Modify colossalai checkpoint usage

* [fx] Add gpc.destroy() to test_codegen
Boyuan Yao 2022-08-12 12:23:30 +08:00 committed by GitHub
parent e9460b45c8
commit 5774fe0270
2 changed files with 63 additions and 7 deletions
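For context, torch.utils.checkpoint.checkpoint has no offload parameter, while the colossalai wrapper takes the offload flag positionally between the checkpointed function and its inputs; this is why the codegen below has to thread activation_offload into the call it emits. A minimal usage sketch under that API, mirroring the single-process setup the updated tests use (block and x are illustrative names):

    import torch
    import colossalai
    from colossalai.utils import free_port

    # single-process setup, as in the tests below
    colossalai.launch(config={}, rank=0, world_size=1,
                      host='localhost', port=free_port(), backend='nccl')

    def block(x):
        return torch.relu(x @ x)

    x = torch.rand(4, 4, device='cuda', requires_grad=True)
    # the second argument is the activation_offload flag: True moves the
    # saved inputs to CPU until the backward pass needs them again
    out = colossalai.utils.activation_checkpoint.checkpoint(block, True, x)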

Changed file 1 of 2: the activation checkpoint codegen module

@@ -99,13 +99,13 @@ def _gen_ckpt_output(output_vars: List[str]) -> str:
     return f"return {', '.join(output_vars)}"


-def _gen_ckpt_usage(label, input_vars, output_vars):
+def _gen_ckpt_usage(label, activation_offload, input_vars, output_vars):
     """
     Generate the checkpoint function call code text
     """
     outputs = ', '.join(output_vars)
     inputs = ', '.join(input_vars)
-    return f'{outputs} = torch.utils.checkpoint.checkpoint(checkpoint_{label}, {inputs})'
+    return f'{outputs} = colossalai.utils.activation_checkpoint.checkpoint(checkpoint_{label}, {activation_offload}, {inputs})'


 def emit_code_with_activation_checkpoint(body, nodes, emit_node_func, delete_unused_value_func):
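To make the new signature concrete, here is a hypothetical call to the helper and the line of source it now emits (the label and variable names are illustrative):

    # one region (label 0) with offload enabled, input 'x', output 'out'
    _gen_ckpt_usage(0, True, ['x'], ['out'])
    # -> "out = colossalai.utils.activation_checkpoint.checkpoint(checkpoint_0, True, x)"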
@@ -155,8 +155,15 @@ def emit_code_with_activation_checkpoint(body, nodes, emit_node_func, delete_unused_value_func):
             return_statement = f'    {return_statement}\n'
             body.append(return_statement)

+            # we need to check if the checkpoint needs to offload the input
+            start_node_idx = start_idx[label]
+            if hasattr(node_list[start_node_idx], 'activation_offload'):
+                activation_offload = node_list[start_node_idx].activation_offload
+            else:
+                activation_offload = False
+
             # generate checkpoint function call in a new line
-            usage = _gen_ckpt_usage(label, input_vars[label], output_vars[label])
+            usage = _gen_ckpt_usage(label, activation_offload, input_vars[label], output_vars[label])
             usage += '\n'
             body.append(usage)
             within_ckpt_region = False
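The flag read here is a plain attribute on the fx node that opens the checkpoint region; a minimal sketch of the annotation side, matching what the updated test does (the node name comes from the test module):

    for node in graph.nodes:
        if node.name == 'mlp2_linear1':     # first node of the region to offload
            node.activation_offload = True  # picked up via hasattr above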
@@ -368,7 +375,11 @@ if codegen_available:
             for name, value in self.additional_globals():
                 add_global(name, value)

+            # as we need colossalai.utils.checkpoint, we need to import colossalai
+            # in the forward function
+            # TODO: Remove inline import
             prologue = self.gen_fn_def(free_vars, maybe_return_annotation[0])
+            prologue = prologue + "\n    import colossalai"

             code = ''.join(body)
             code = '\n'.join('    ' + line for line in code.split('\n'))
@@ -566,9 +577,14 @@ else:
         orig_args.insert(0, 'self')

         code = ''.join(body)
         code = '\n'.join('    ' + line for line in code.split('\n'))

+        # as we need colossalai.utils.checkpoint, we need to import colossalai
+        # in the forward function
+        # TODO: Remove inline import
         fn_code = f"""
 {wrap_stmts}

 def forward({', '.join(orig_args)}){maybe_return_annotation[0]}:
+    import colossalai
 {code}"""
         return PythonCode(fn_code, globals_)
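Putting both code paths together, the emitted forward now carries the inline import and a per-region offload flag. A rough sketch of the generated source for a module with one checkpointed region (names illustrative, structure inferred from the helpers above):

    def forward(self, x):
        import colossalai    # injected inline (see TODO above)

        def checkpoint_0(x):
            # body of the checkpointed region, emitted by the codegen
            linear1 = self.mlp1_linear1(x)
            return linear1

        # the second argument is the region's activation_offload flag
        linear1 = colossalai.utils.activation_checkpoint.checkpoint(checkpoint_0, False, x)
        return linear1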

Changed file 2 of 2: the activation checkpoint codegen test

@@ -1,8 +1,12 @@
+from operator import mod
 import torch
 import pytest
 from torch.utils.checkpoint import checkpoint
 from torch.fx import GraphModule
 from colossalai.fx import ColoTracer
+import colossalai
+from colossalai.utils import free_port
+from colossalai.core import global_context as gpc

 try:
     from colossalai.fx.codegen import ActivationCheckpointCodeGen
@@ -40,9 +44,17 @@ class MyModule(torch.nn.Module):

 @pytest.mark.skipif(not with_codegen, reason='torch version is lower than 1.12.0')
 def test_act_ckpt_codegen():
+    # launch colossalai to make sure we can execute colossalai.utils.checkpoint correctly
+    colossalai.launch(config={}, rank=0, world_size=1, host='localhost', port=free_port(), backend='nccl')
+
     # build model and run forward
     model = MyModule()
     data = torch.rand(4, 4)
+
+    # copy the model and data to cuda
+    model = model.to(device="cuda")
+    data = data.to(device="cuda")
     non_fx_out = model(data)

     # trace the module and replace codegen
@@ -52,14 +64,22 @@ def test_act_ckpt_codegen():
     graph.set_codegen(codegen)

     # check ops are annotated with ckpt
+    # also annotate the selected node for offloading
     ckpt_nodes = ['mlp1_linear1', 'mlp1_linear1_1', 'mlp2_linear1', 'mlp2_linear1_1']
+    offload_starts = ['mlp2_linear1']
     for node in graph.nodes:
         if node.name in ckpt_nodes:
             assert hasattr(node, 'activation_checkpoint')
+            # annotate the selected node for offload
+            if node.name in offload_starts:
+                setattr(node, 'activation_offload', True)

-    # assert checkpoint function will be generated
+    # assert checkpoint function will be generated and
+    # the offload option is correct
     code = graph.python_code('self').src
-    assert 'checkpoint_0' in code and 'checkpoint_1' in code
+    assert 'colossalai.utils.activation_checkpoint.checkpoint(checkpoint_0, False, x)' in code and \
+           'colossalai.utils.activation_checkpoint.checkpoint(checkpoint_1, True, x)' in code

     # recompile and verify the outputs are consistent
     gm = GraphModule(model, graph)
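The asserted flags follow from the annotation above: 'mlp2_linear1' opens the second checkpoint region, so checkpoint_1 is emitted with True while checkpoint_0 keeps the default False. A hypothetical spot check of the same property on the start node:

    start = next(n for n in graph.nodes if n.name == 'mlp2_linear1')
    assert getattr(start, 'activation_offload', False)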
@@ -67,12 +87,22 @@ def test_act_ckpt_codegen():
     fx_out = gm(data)
     assert torch.equal(non_fx_out, fx_out)

+    gpc.destroy()
+

 @pytest.mark.skipif(with_codegen, reason='torch version is equal to or higher than 1.12.0')
 def test_act_ckpt_python_code_torch11():
+    # launch colossalai to make sure we can execute colossalai.utils.checkpoint correctly
+    colossalai.launch(config={}, rank=0, world_size=1, host='localhost', port=free_port(), backend='nccl')
+
     # build model and run forward
     model = MyModule()
     data = torch.rand(4, 4)
+
+    # copy the model and data to cuda
+    model = model.to(device="cuda")
+    data = data.to(device="cuda")
     non_fx_out = model(data)

     # trace the module and replace codegen
@@ -84,13 +114,20 @@ def test_act_ckpt_python_code_torch11():

     # check ops are annotated with ckpt
     ckpt_nodes = ['mlp1_linear1', 'mlp1_linear1_1', 'mlp2_linear1', 'mlp2_linear1_1']
+    offload_starts = ['mlp2_linear1']
     for node in graph.nodes:
         if node.name in ckpt_nodes:
             assert hasattr(node, 'activation_checkpoint')
+            # annotate the selected node for offload
+            if node.name in offload_starts:
+                setattr(node, 'activation_offload', True)

-    # assert checkpoint function will be generated
+    # assert checkpoint function will be generated and
+    # the offload option is correct
     code = graph.python_code('self').src
-    assert 'checkpoint_0' in code and 'checkpoint_1' in code
+    assert 'colossalai.utils.activation_checkpoint.checkpoint(checkpoint_0, False, x)' in code and \
+           'colossalai.utils.activation_checkpoint.checkpoint(checkpoint_1, True, x)' in code

     # recompile and verify the outputs are consistent
     gm = GraphModule(model, graph)
@@ -98,7 +135,10 @@ def test_act_ckpt_python_code_torch11():
     fx_out = gm(data)
     assert torch.equal(non_fx_out, fx_out)

+    gpc.destroy()
+

 if __name__ == '__main__':
     test_act_ckpt_codegen()
     test_act_ckpt_python_code_torch11()
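Since each test pairs colossalai.launch with gpc.destroy(), the two can run back to back in a single interpreter, as the __main__ block does. A hypothetical helper that captures the same setup/teardown pattern, reusing the imports at the top of the file:

    def run_with_colossalai(test_fn):
        # bring up a single-process NCCL group, run the test body,
        # and always tear the group down so the next launch in this
        # process starts clean
        colossalai.launch(config={}, rank=0, world_size=1,
                          host='localhost', port=free_port(), backend='nccl')
        try:
            test_fn()
        finally:
            gpc.destroy()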