From 5774fe02706427532597acb94147d1134a40c2f8 Mon Sep 17 00:00:00 2001
From: Boyuan Yao <70263930+Cypher30@users.noreply.github.com>
Date: Fri, 12 Aug 2022 12:23:30 +0800
Subject: [PATCH] [fx] Use colossalai checkpoint and add offload recognition in codegen (#1439)

* [fx] Replace torch.utils.checkpoint with colossalai.utils.checkpoint for
  activation offload, and add offload annotation recognition in codegen
* Modify tests and add a TODO in codegen
* [fx] Fix colossalai checkpoint usage
* [fx] Add gpc.destroy() to test_codegen
---
 .../codegen/activation_checkpoint_codegen.py | 22 +++++++--
 .../test_activation_checkpoint_codegen.py    | 47 ++++++++++++++++--
 2 files changed, 62 insertions(+), 7 deletions(-)
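Note for reviewers: for a module with two checkpoint regions where the second
region's start node carries activation_offload = True, the forward emitted by
this codegen would look roughly like the sketch below. The region bodies and
variable names are illustrative stand-ins; only the inline import and the
shape of the checkpoint call mirror what the patch actually emits.

    def forward(self, x):
        import colossalai  # inline import injected into the prologue (see the TODO in the diff)

        def checkpoint_0(x):
            # body of the first ckpt region; its start node has no offload annotation
            return self.mlp1(x)

        def checkpoint_1(x):
            # body of the second ckpt region; its start node is annotated for offload
            return self.mlp2(x)

        # the second positional argument is the activation_offload flag read
        # from the region's start node (False when the attribute is absent)
        x = colossalai.utils.activation_checkpoint.checkpoint(checkpoint_0, False, x)
        x = colossalai.utils.activation_checkpoint.checkpoint(checkpoint_1, True, x)
        return x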
diff --git a/colossalai/fx/codegen/activation_checkpoint_codegen.py b/colossalai/fx/codegen/activation_checkpoint_codegen.py
index 1bbb1418f..11e194fc0 100644
--- a/colossalai/fx/codegen/activation_checkpoint_codegen.py
+++ b/colossalai/fx/codegen/activation_checkpoint_codegen.py
@@ -99,13 +99,13 @@ def _gen_ckpt_output(output_vars: List[str]) -> str:
     return f"return {', '.join(output_vars)}"
 
 
-def _gen_ckpt_usage(label, input_vars, output_vars):
+def _gen_ckpt_usage(label, activation_offload, input_vars, output_vars):
     """
     Generate the checkpoint function call code text
     """
     outputs = ', '.join(output_vars)
     inputs = ', '.join(input_vars)
-    return f'{outputs} = torch.utils.checkpoint.checkpoint(checkpoint_{label}, {inputs})'
+    return f'{outputs} = colossalai.utils.activation_checkpoint.checkpoint(checkpoint_{label}, {activation_offload}, {inputs})'
 
 
 def emit_code_with_activation_checkpoint(body, nodes, emit_node_func, delete_unused_value_func):
@@ -155,8 +155,15 @@ def emit_code_with_activation_checkpoint(body, nodes, emit_node_func, delete_unu
             return_statement = f'    {return_statement}\n'
             body.append(return_statement)
 
+            # we need to check if the checkpoint needs to offload the input
+            start_node_idx = start_idx[label]
+            if hasattr(node_list[start_node_idx], 'activation_offload'):
+                activation_offload = node_list[start_node_idx].activation_offload
+            else:
+                activation_offload = False
+
             # generate checkpoint function call in a new line
-            usage = _gen_ckpt_usage(label, input_vars[label], output_vars[label])
+            usage = _gen_ckpt_usage(label, activation_offload, input_vars[label], output_vars[label])
             usage += '\n'
             body.append(usage)
             within_ckpt_region = False
@@ -368,7 +375,11 @@ if codegen_available:
             for name, value in self.additional_globals():
                 add_global(name, value)
 
+            # as we need colossalai.utils.checkpoint, we need to import colossalai
+            # in the forward function
+            # TODO: Remove inline import
             prologue = self.gen_fn_def(free_vars, maybe_return_annotation[0])
+            prologue = prologue + "\n    import colossalai"
 
             code = ''.join(body)
             code = '\n'.join('    ' + line for line in code.split('\n'))
@@ -566,9 +577,14 @@ else:
             orig_args.insert(0, 'self')
             code = ''.join(body)
             code = '\n'.join('    ' + line for line in code.split('\n'))
+
+            # as we need colossalai.utils.checkpoint, we need to import colossalai
+            # in the forward function
+            # TODO: Remove inline import
             fn_code = f"""
 {wrap_stmts}
 
 def forward({', '.join(orig_args)}){maybe_return_annotation[0]}:
+    import colossalai
 {code}"""
             return PythonCode(fn_code, globals_)
diff --git a/tests/test_fx/test_codegen/test_activation_checkpoint_codegen.py b/tests/test_fx/test_codegen/test_activation_checkpoint_codegen.py
index 302307776..411ec0083 100644
--- a/tests/test_fx/test_codegen/test_activation_checkpoint_codegen.py
+++ b/tests/test_fx/test_codegen/test_activation_checkpoint_codegen.py
@@ -1,8 +1,11 @@
 import torch
 import pytest
 from torch.utils.checkpoint import checkpoint
 from torch.fx import GraphModule
 from colossalai.fx import ColoTracer
+import colossalai
+from colossalai.utils import free_port
+from colossalai.core import global_context as gpc
 
 try:
     from colossalai.fx.codegen import ActivationCheckpointCodeGen
@@ -40,9 +44,17 @@ class MyModule(torch.nn.Module):
 
 @pytest.mark.skipif(not with_codegen, reason='torch version is lower than 1.12.0')
 def test_act_ckpt_codegen():
+    # launch colossalai to make sure we can execute colossalai.utils.checkpoint correctly
+    colossalai.launch(config={}, rank=0, world_size=1, host='localhost', port=free_port(), backend='nccl')
+
     # build model and run forward
     model = MyModule()
     data = torch.rand(4, 4)
+
+    # move the model and data to cuda
+    model = model.to(device="cuda")
+    data = data.to(device="cuda")
+
     non_fx_out = model(data)
 
     # trace the module and replace codegen
@@ -52,14 +64,22 @@ def test_act_ckpt_codegen():
     graph.set_codegen(codegen)
 
     # check ops are annotated with ckpt
+    # also annotate the selected node for offloading
     ckpt_nodes = ['mlp1_linear1', 'mlp1_linear1_1', 'mlp2_linear1', 'mlp2_linear1_1']
+    offload_starts = ['mlp2_linear1']
     for node in graph.nodes:
         if node.name in ckpt_nodes:
             assert hasattr(node, 'activation_checkpoint')
 
-    # assert checkpoint function will be generated
+        # annotate the selected node for offload
+        if node.name in offload_starts:
+            setattr(node, 'activation_offload', True)
+
+    # assert checkpoint function will be generated and
+    # the offload option is correct
     code = graph.python_code('self').src
-    assert 'checkpoint_0' in code and 'checkpoint_1' in code
+    assert 'colossalai.utils.activation_checkpoint.checkpoint(checkpoint_0, False, x)' in code and \
+        'colossalai.utils.activation_checkpoint.checkpoint(checkpoint_1, True, x)' in code
 
     # recompile and verify the outputs are consistent
     gm = GraphModule(model, graph)
@@ -67,12 +87,22 @@ def test_act_ckpt_codegen():
     fx_out = gm(data)
     assert torch.equal(non_fx_out, fx_out)
 
+    gpc.destroy()
+
 
 @pytest.mark.skipif(with_codegen, reason='torch version is equal to or higher than 1.12.0')
 def test_act_ckpt_python_code_torch11():
+    # launch colossalai to make sure we can execute colossalai.utils.checkpoint correctly
+    colossalai.launch(config={}, rank=0, world_size=1, host='localhost', port=free_port(), backend='nccl')
+
     # build model and run forward
     model = MyModule()
     data = torch.rand(4, 4)
+
+    # move the model and data to cuda
+    model = model.to(device="cuda")
+    data = data.to(device="cuda")
+
     non_fx_out = model(data)
 
     # trace the module and replace codegen
@@ -84,13 +114,20 @@ def test_act_ckpt_python_code_torch11():
 
     # check ops are annotated with ckpt
     ckpt_nodes = ['mlp1_linear1', 'mlp1_linear1_1', 'mlp2_linear1', 'mlp2_linear1_1']
+    offload_starts = ['mlp2_linear1']
     for node in graph.nodes:
         if node.name in ckpt_nodes:
             assert hasattr(node, 'activation_checkpoint')
 
-    # assert checkpoint function will be generated
+        # annotate the selected node for offload
+        if node.name in offload_starts:
+            setattr(node, 'activation_offload', True)
+
+    # assert checkpoint function will be generated and
+    # the offload option is correct
     code = graph.python_code('self').src
-    assert 'checkpoint_0' in code and 'checkpoint_1' in code
+    assert 'colossalai.utils.activation_checkpoint.checkpoint(checkpoint_0, False, x)' in code and \
+        'colossalai.utils.activation_checkpoint.checkpoint(checkpoint_1, True, x)' in code
 
     # recompile and verify the outputs are consistent
     gm = GraphModule(model, graph)
@@ -98,7 +135,10 @@ def test_act_ckpt_python_code_torch11():
     fx_out = gm(data)
     assert torch.equal(non_fx_out, fx_out)
 
+    gpc.destroy()
+
 
 if __name__ == '__main__':
+    test_act_ckpt_codegen()
     test_act_ckpt_python_code_torch11()
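For downstream users, the annotation flow exercised by the tests boils down to
the sketch below. It reuses MyModule from the test file above, assumes a CUDA
device, and only calls APIs already imported by the test; treat it as a usage
outline rather than a verified script.

    import colossalai
    import torch
    from torch.fx import GraphModule
    from colossalai.core import global_context as gpc
    from colossalai.fx import ColoTracer
    from colossalai.fx.codegen import ActivationCheckpointCodeGen
    from colossalai.utils import free_port

    # colossalai's activation checkpoint utility needs a launched global context
    colossalai.launch(config={}, rank=0, world_size=1, host='localhost', port=free_port(), backend='nccl')

    model = MyModule().to(device="cuda")    # MyModule as defined in the test file
    graph = ColoTracer().trace(model)
    graph.set_codegen(ActivationCheckpointCodeGen())

    # mark the start node of the second checkpoint region for offload;
    # codegen reads node.activation_offload and emits it as the flag argument
    for node in graph.nodes:
        if node.name == 'mlp2_linear1':
            setattr(node, 'activation_offload', True)

    gm = GraphModule(model, graph)
    gm.recompile()    # forward now calls colossalai.utils.activation_checkpoint.checkpoint(..., True, ...)
    fx_out = gm(torch.rand(4, 4, device="cuda"))
    gpc.destroy()

Keeping the flag on the region's start node makes offload a pure graph-level
annotation, so toggling it requires no change to the model code itself.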