From 092b9c8f49375d2603f021680ab192df64664c4c Mon Sep 17 00:00:00 2001
From: Boyuan Yao <70263930+Cypher30@users.noreply.github.com>
Date: Wed, 17 Aug 2022 10:34:50 +0800
Subject: [PATCH] [fx] Add use_reentrant=False to checkpoint in codegen (#1463)

* [utils] Add use_reentrant=False into colossalai checkpoint
* [utils] add some annotations in utils.activation_checkpoint
* [test] add reset_seed at the beginning of tests in test_activation_checkpointing.py
* [test] modify test_activation_checkpoint.py
* [test] modify test for reentrant=False
* [fx] Add use_reentrant=False for checkpoint in codegen
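
Why in-place ops need the non-reentrant path (an illustrative sketch
against plain torch.utils.checkpoint, not part of this patch): the
default reentrant checkpoint replays the region in backward from the
saved input tensors, and an op like ReLU(inplace=True) mutates one of
those tensors, which can trip autograd's version-counter check during
recomputation. use_reentrant=False is the mode that tolerates such
regions, e.g.:

    import torch
    import torch.nn.functional as F
    from torch.utils.checkpoint import checkpoint

    lin = torch.nn.Linear(4, 4)
    x = torch.randn(2, 4, requires_grad=True)
    h = lin(x)    # non-leaf tensor, so the in-place ReLU below is legal
    # the checkpointed region contains an in-place op
    y = checkpoint(lambda t: F.relu(t, inplace=True), h, use_reentrant=False)
    y.sum().backward()
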
---
 .../codegen/activation_checkpoint_codegen.py |   22 +++++++--
 .../test_activation_checkpoint_codegen.py    |   49 +++++++++++++------
 2 files changed, 53 insertions(+), 18 deletions(-)

diff --git a/colossalai/fx/codegen/activation_checkpoint_codegen.py b/colossalai/fx/codegen/activation_checkpoint_codegen.py
index 4a4bbef4c..53eb46529 100644
--- a/colossalai/fx/codegen/activation_checkpoint_codegen.py
+++ b/colossalai/fx/codegen/activation_checkpoint_codegen.py
@@ -99,13 +99,13 @@ def _gen_ckpt_output(output_vars: List[str]) -> str:
     return f"return {', '.join(output_vars)}"
 
 
-def _gen_ckpt_usage(label, activation_offload, input_vars, output_vars):
+def _gen_ckpt_usage(label, activation_offload, input_vars, output_vars, use_reentrant=True):
     """
     Generate the checkpoint function call code text
     """
     outputs = ', '.join(output_vars)
     inputs = ', '.join(input_vars)
-    return f'{outputs} = colossalai.utils.activation_checkpoint.checkpoint(checkpoint_{label}, {activation_offload}, {inputs})'
+    return f'{outputs} = colossalai.utils.activation_checkpoint.checkpoint(checkpoint_{label}, {activation_offload}, {inputs}, use_reentrant={use_reentrant})'
 
 
 def emit_code_with_activation_checkpoint(body, nodes, emit_node_func, delete_unused_value_func):
@@ -162,8 +162,24 @@ def emit_code_with_activation_checkpoint(body, nodes, emit_node_func, delete_unu
             else:
                 activation_offload = False
 
+            # we need to check if the checkpoint needs use_reentrant=False
+            use_reentrant = True
+            for var in input_vars[label]:
+                input_node = [item for item in node_list if item.name == var]
+                input_node = input_node[0]
+                for user in input_node.users:
+                    if hasattr(user, "activation_checkpoint"):
+                        if user.activation_checkpoint == label:
+                            if user.op == "call_module":
+                                if hasattr(user.graph.owning_module.get_submodule(user.target), "inplace"):
+                                    use_reentrant = not user.graph.owning_module.get_submodule(user.target).inplace
+
+                            elif user.op == "call_function":
+                                if "inplace" in user.kwargs:
+                                    use_reentrant = not user.kwargs["inplace"]
+
             # generate checkpoint function call in a new line
-            usage = _gen_ckpt_usage(label, activation_offload, input_vars[label], output_vars[label])
+            usage = _gen_ckpt_usage(label, activation_offload, input_vars[label], output_vars[label], use_reentrant)
             usage += '\n'
             body.append(usage)
             within_ckpt_region = False
diff --git a/tests/test_fx/test_codegen/test_activation_checkpoint_codegen.py b/tests/test_fx/test_codegen/test_activation_checkpoint_codegen.py
index fe5c638b2..9c1bc57a3 100644
--- a/tests/test_fx/test_codegen/test_activation_checkpoint_codegen.py
+++ b/tests/test_fx/test_codegen/test_activation_checkpoint_codegen.py
@@ -1,5 +1,6 @@
 from operator import mod
 import torch
+import torch.nn.functional as F
 import pytest
 import torch.multiprocessing as mp
 from torch.utils.checkpoint import checkpoint
@@ -26,7 +27,17 @@ class MLP(torch.nn.Module):
         self.linear2 = torch.nn.Linear(4, 4)
 
     def forward(self, x):
-        return self.linear1(x), self.linear1(x)
+        return self.linear1(x), self.linear2(x)
+
+
+class relu(torch.nn.Module):
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.relu = torch.nn.ReLU(inplace=True)
+
+    def forward(self, x):
+        return self.relu(x)
 
 
 class MyModule(torch.nn.Module):
@@ -34,12 +45,17 @@ class MyModule(torch.nn.Module):
 
     def __init__(self):
         super().__init__()
         self.mlp1 = MLP()
-        self.mlp2 = MLP()
+        self.relu = relu()
         self.linear3 = torch.nn.Linear(4, 4)
 
     def forward(self, x):
         y1, y2 = checkpoint(self.mlp1, x)
-        y3, y4 = checkpoint(self.mlp2, x)
+        y3 = checkpoint(self.relu, x)
+
+        def ckpt2(x):
+            return F.relu(x, inplace=True)
+
+        y4 = checkpoint(ckpt2, x)
         return y1 + y2 + y3 + y4
 
@@ -65,8 +81,8 @@ def _run_act_ckpt_codegen(rank):
 
     # check ops are annotated with ckpt
     # also annotate the selected node for offloading
-    ckpt_nodes = ['mlp1_linear1', 'mlp1_linear1_1', 'mlp2_linear1', 'mlp2_linear1_1']
-    offload_starts = ['mlp2_linear1']
+    ckpt_nodes = ['mlp1_linear1', 'mlp1_linear2', 'relu_relu', 'relu']
+    offload_starts = ['mlp1_linear1']
     for node in graph.nodes:
         if node.name in ckpt_nodes:
             assert hasattr(node, 'activation_checkpoint')
@@ -75,15 +91,17 @@ def _run_act_ckpt_codegen(rank):
         if node.name in offload_starts:
             setattr(node, 'activation_offload', True)
 
+    gm = GraphModule(model, graph)
+    gm.recompile()
+
     # assert checkpoint function will be generated and
     # the offload option is correct
     code = graph.python_code('self').src
-    assert 'colossalai.utils.activation_checkpoint.checkpoint(checkpoint_0, False, x)' in code and \
-        'colossalai.utils.activation_checkpoint.checkpoint(checkpoint_1, True, x)' in code
+    assert 'colossalai.utils.activation_checkpoint.checkpoint(checkpoint_0, True, x, use_reentrant=True)' in code and \
+        'colossalai.utils.activation_checkpoint.checkpoint(checkpoint_1, False, x, use_reentrant=False)' in code and \
+        'colossalai.utils.activation_checkpoint.checkpoint(checkpoint_2, False, x, use_reentrant=False)' in code
 
     # recompile and verify the outputs are consistent
-    gm = GraphModule(model, graph)
-    gm.recompile()
     fx_out = gm(data)
     assert torch.equal(non_fx_out, fx_out)
 
@@ -117,8 +135,8 @@ def _run_act_ckpt_python_code_torch11(rank):
     graph._python_code = python_code_with_activation_checkpoint.__get__(graph)
 
     # check ops are annotated with ckpt
-    ckpt_nodes = ['mlp1_linear1', 'mlp1_linear1_1', 'mlp2_linear1', 'mlp2_linear1_1']
-    offload_starts = ['mlp2_linear1']
+    ckpt_nodes = ['mlp1_linear1', 'mlp1_linear2', 'relu_relu', 'relu']
+    offload_starts = ['mlp1_linear1']
     for node in graph.nodes:
         if node.name in ckpt_nodes:
             assert hasattr(node, 'activation_checkpoint')
@@ -127,15 +145,16 @@ def _run_act_ckpt_python_code_torch11(rank):
         if node.name in offload_starts:
             setattr(node, 'activation_offload', True)
 
+    gm = GraphModule(model, graph)
+    gm.recompile()
     # assert checkpoint function will be generated and
     # the offload option is correct
     code = graph.python_code('self').src
-    assert 'colossalai.utils.activation_checkpoint.checkpoint(checkpoint_0, False, x)' in code and \
-        'colossalai.utils.activation_checkpoint.checkpoint(checkpoint_1, True, x)' in code
+    assert 'colossalai.utils.activation_checkpoint.checkpoint(checkpoint_0, True, x, use_reentrant=True)' in code and \
+        'colossalai.utils.activation_checkpoint.checkpoint(checkpoint_1, False, x, use_reentrant=False)' in code and \
+        'colossalai.utils.activation_checkpoint.checkpoint(checkpoint_2, False, x, use_reentrant=False)' in code
 
     # recompile and verify the outputs are consistent
-    gm = GraphModule(model, graph)
-    gm.recompile()
     fx_out = gm(data)
     assert torch.equal(non_fx_out, fx_out)
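
A quick sketch of what the patched helper now emits (the helper is
private, and the argument values below are made up for illustration;
the output string follows the return statement in the diff above):

    from colossalai.fx.codegen.activation_checkpoint_codegen import _gen_ckpt_usage

    # ckpt label '0', offload disabled, one input 'x', one output 'out'
    print(_gen_ckpt_usage('0', False, ['x'], ['out'], use_reentrant=False))
    # out = colossalai.utils.activation_checkpoint.checkpoint(checkpoint_0, False, x, use_reentrant=False)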