diff --git a/colossalai/fx/codegen/activation_checkpoint_codegen.py b/colossalai/fx/codegen/activation_checkpoint_codegen.py
index e622b8d3a..ad143a690 100644
--- a/colossalai/fx/codegen/activation_checkpoint_codegen.py
+++ b/colossalai/fx/codegen/activation_checkpoint_codegen.py
@@ -17,6 +17,38 @@ else:
 
 __all__ = ['python_code_with_activation_checkpoint']
 
+def _gen_saved_tensors_hooks():
+    """
+    Generate saved tensors hooks
+    """
+
+    pack_hook = """def pack_hook(self, x):
+    if getattr(x, "offload", None):
+        return (x.device, x.cpu())
+    else:
+        return x
+"""
+
+    unpack_hook = """def unpack_hook(self, packed):
+    if isinstance(packed, tuple):
+        device, tensor = packed
+        return tensor.to(device)
+    else:
+        return packed
+"""
+
+    return pack_hook, unpack_hook
+
+
+def _gen_save_tensors_hooks_context():
+    """
+    Generate saved tensors hooks context
+    """
+
+    context = "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook, self.unpack_hook):\n"
+    return context
+
+
 def _find_input_and_output_nodes(nodes: List[Node]):
     """
     Find the input and output node names which are not found in the given list of nodes.
@@ -211,7 +243,7 @@ def emit_ckpt_func(body,
             ckpt_func[-1] = '    ' + ckpt_func[-1]
             delete_unused_value_func(node, ckpt_func)
 
-        ckpt_func.append('    ' + _gen_ckpt_output(outputs) + '\n')
+        ckpt_func.append('    ' + _gen_ckpt_output(outputs) + '\n\n')
         activation_offload = getattr(node_list[0], "activation_offload", False)
         usage = _gen_ckpt_usage(label, activation_offload, inputs, outputs, False)
         usage += "\n"
@@ -251,7 +283,7 @@ def emit_ckpt_func(body,
                 delete_unused_value_func(node, ckpt_func)
                 node_idx += 1
 
-        ckpt_func.append('    ' + _gen_ckpt_output(outputs) + '\n')
+        ckpt_func.append('    ' + _gen_ckpt_output(outputs) + '\n\n')
         ckpt_func += ckpt_func_buffer
         activation_offload = getattr(node_list[0], "activation_offload", False)
         usage = _gen_ckpt_usage(label, activation_offload, inputs, outputs, False) + '\n'
@@ -266,7 +298,7 @@ def emit_ckpt_func(body,
             ckpt_func[-1] = '    ' + ckpt_func[-1]
             delete_unused_value_func(node, ckpt_func)
 
-        ckpt_func.append('    ' + _gen_ckpt_output(outputs) + '\n')
+        ckpt_func.append('    ' + _gen_ckpt_output(outputs) + '\n\n')
         activation_offload = getattr(node_list[0], "activation_offload", False)
         usage = _gen_ckpt_usage(label, activation_offload, inputs, outputs, False) + '\n'
         if in_ckpt:
@@ -292,6 +324,9 @@ def emit_code_with_nested_activation_checkpoint(body, ckpt_func, nodes, emit_nod
 
     node_list = list(nodes)
 
+    # this flag is to prevent repeated insertion of the saved tensors
+    # hooks definition into ckpt_func
+    is_hook_inserted = False
     node_idx = 0
     while 1:
         # break if we finish the processing all the nodes
@@ -307,8 +342,27 @@ def emit_code_with_nested_activation_checkpoint(body, ckpt_func, nodes, emit_nod
         # process node in forward function
         else:
             node = node_list[node_idx]
-            emit_node_func(node, body)
-            delete_unused_value_func(node, body)
+
+            # if a node outside of a checkpoint region wants to offload its input
+            # activation, we will use torch.autograd.graph.saved_tensors_hooks
+            # to complete the offload process.
+            if getattr(node, "activation_offload", False):
+                if not is_hook_inserted:
+                    pack_hook, unpack_hook = _gen_saved_tensors_hooks()
+                    ckpt_func.insert(0, "\n".join([pack_hook, unpack_hook]) + "\n")
+
+                for par in node.all_input_nodes:
+                    # annotate the input tensor for the pack hook
+                    body.append(f"setattr({repr(par)}, 'offload', True)\n")
+
+                body.append(_gen_save_tensors_hooks_context())
+                emit_node_func(node, body)
+                body[-1] = '    ' + body[-1]
+                delete_unused_value_func(node, body)
+
+            else:
+                emit_node_func(node, body)
+                delete_unused_value_func(node, body)
 
             node_idx += 1
 
@@ -323,6 +377,10 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,
 
     node_list = list(nodes)
 
+    # use this variable to avoid inserting the hook functions
+    # into ckpt_func repeatedly
+    is_hook_inserted = False
+
     # find the input and output var names for each region
     for idx, (start, end) in enumerate(ckpt_regions):
         ckpt_node_list = node_list[start:end + 1]
@@ -348,8 +406,26 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,
             ckpt_func[-1] = '    ' + ckpt_func[-1]
             delete_unused_value_func(node, ckpt_func)
         else:
-            emit_node_func(node, body)
-            delete_unused_value_func(node, body)
+            # if a node outside of a checkpoint region wants to offload its input
+            # activation, we will use torch.autograd.graph.saved_tensors_hooks
+            # to complete the offload process.
+            if getattr(node, "activation_offload", False):
+                if not is_hook_inserted:
+                    pack_hook, unpack_hook = _gen_saved_tensors_hooks()
+                    ckpt_func.insert(0, "\n".join([pack_hook, unpack_hook]) + "\n")
+
+                for par in node.all_input_nodes:
+                    # annotate the input tensor for the pack hook
+                    body.append(f"setattr({repr(par)}, 'offload', True)\n")
+
+                body.append(_gen_save_tensors_hooks_context())
+                emit_node_func(node, body)
+                body[-1] = '    ' + body[-1]
+                delete_unused_value_func(node, body)
+
+            else:
+                emit_node_func(node, body)
+                delete_unused_value_func(node, body)
 
         if idx in end_idx:
             # if this is the last node of the ckpt region
@@ -587,10 +663,13 @@ if CODEGEN_AVAILABLE:
 
             # Modified for activation checkpointing
            ckpt_func = []
-            if all(not isinstance(getattr(node, "activation_checkpoint", None), list) for node in nodes):
-                emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node, delete_unused_values)
-            else:
+
+            # if any node has a list of labels for activation_checkpoint, we
+            # will use the nested type of activation checkpoint codegen
+            if any(isinstance(getattr(node, "activation_checkpoint", None), list) for node in nodes):
                 emit_code_with_nested_activation_checkpoint(body, ckpt_func, nodes, emit_node, delete_unused_values)
+            else:
+                emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node, delete_unused_values)
 
             if len(body) == 0:
                 # If the Graph has no non-placeholder nodes, no lines for the body
@@ -612,7 +691,6 @@ if CODEGEN_AVAILABLE:
 
             # as we need colossalai.utils.checkpoint, we need to import colossalai
             # in forward function
-            # TODO: Remove inline import
             prologue = self.gen_fn_def(free_vars, maybe_return_annotation[0])
             prologue = ''.join(ckpt_func) + prologue
             prologue = prologue
@@ -788,10 +866,13 @@ else:
 
         # Modified for activation checkpointing
        ckpt_func = []
-        if all(not isinstance(getattr(node, "activation_checkpoint", None), list) for node in self.nodes):
-            emit_code_with_activation_checkpoint(body, ckpt_func, self.nodes, emit_node, delete_unused_values)
-        else:
+
+        # if any node has a list of labels for activation_checkpoint, we
+        # will use the nested type of activation checkpoint codegen
+        if any(isinstance(getattr(node, "activation_checkpoint", None), list) for node in self.nodes):
             emit_code_with_nested_activation_checkpoint(body, ckpt_func, self.nodes, emit_node, delete_unused_values)
+        else:
+            emit_code_with_activation_checkpoint(body, ckpt_func, self.nodes, emit_node, delete_unused_values)
 
         if len(body) == 0:
             # If the Graph has no non-placeholder nodes, no lines for the body
@@ -827,7 +908,6 @@ else:
 
         # as we need colossalai.utils.checkpoint, we need to import colossalai
         # in forward function
-        # TODO: Remove inline import
 
         fn_code = f"""
 {wrap_stmts}
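For context, the pack_hook/unpack_hook strings generated above are ordinary hooks for torch.autograd.graph.saved_tensors_hooks (available since PyTorch 1.10). A minimal standalone sketch of the mechanism, outside of the generated code and with illustrative variable names, looks roughly like this:

import torch

def pack_hook(x):
    # mirror of the generated pack_hook: move tagged activations to CPU at save time
    if getattr(x, "offload", None):
        return (x.device, x.cpu())
    return x

def unpack_hook(packed):
    # mirror of the generated unpack_hook: move the tensor back before backward uses it
    if isinstance(packed, tuple):
        device, tensor = packed
        return tensor.to(device)
    return packed

x = torch.randn(4, 4, requires_grad=True)
x.offload = True    # same annotation the emitted setattr(..., 'offload', True) applies
with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
    y = (x * x).sum()    # x is saved for backward through pack_hook
y.backward()             # unpack_hook restores x to its original device

The generated versions additionally take self as the first argument because ColoGraphModule.bind (below) attaches them to the module as methods.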
diff --git a/colossalai/fx/graph_module.py b/colossalai/fx/graph_module.py
index 78f719852..fbafd326c 100644
--- a/colossalai/fx/graph_module.py
+++ b/colossalai/fx/graph_module.py
@@ -22,14 +22,20 @@ if COLOGM:
             super().__init__(root, graph, class_name)
 
         def bind(self, ckpt_def, globals):
-            """Bind checkpoint functions to ColoGraphModule
-            We need to bind our checkpoint functions to the GraphModule so
-            that we could correctly use self.checkpoint for GraphModule forward
+            """Bind the functions needed to correctly execute the GraphModule's forward
+
+            We need to bind the checkpoint functions and the saved_tensors_hooks functions
+            to the GraphModule so that we can correctly execute its forward
+
+            Args:
+                ckpt_def (List[str]): definitions placed before the forward function
+                globals (dict): global variables
             """
+
             ckpt_code = "\n".join(ckpt_def)
             globals_copy = globals.copy()
             _exec_with_source(ckpt_code, globals_copy)
-            func_list = [func for func in globals_copy.keys() if "checkpoint" in func]
+            func_list = [func for func in globals_copy.keys() if "checkpoint" in func or "pack" in func]
             for func in func_list:
                 tmp_func = globals_copy[func]
                 setattr(self, func, tmp_func.__get__(self, self.__class__))
diff --git a/tests/test_fx/test_codegen/test_activation_checkpoint_codegen.py b/tests/test_fx/test_codegen/test_activation_checkpoint_codegen.py
index 54a11bb48..08044c687 100644
--- a/tests/test_fx/test_codegen/test_activation_checkpoint_codegen.py
+++ b/tests/test_fx/test_codegen/test_activation_checkpoint_codegen.py
@@ -1,4 +1,3 @@
-from operator import mod
 import torch
 import torch.nn.functional as F
 import pytest
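The updated bind relies on the standard descriptor protocol: func.__get__(self, self.__class__) turns a plain generated function into a method bound to the GraphModule instance, so the emitted forward can call self.pack_hook, self.unpack_hook, and self.checkpoint_*. A toy illustration of that binding step (the class and values here are made up for the example):

class Host:
    pass

def pack_hook(self, x):
    # once bound, `self` is the Host instance, just like a normal method
    return x

host = Host()
setattr(host, "pack_hook", pack_hook.__get__(host, Host))
print(host.pack_hook("activation"))    # called like a bound method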
diff --git a/tests/test_fx/test_codegen/test_offload_codegen.py b/tests/test_fx/test_codegen/test_offload_codegen.py
new file mode 100644
index 000000000..a9d1455b8
--- /dev/null
+++ b/tests/test_fx/test_codegen/test_offload_codegen.py
@@ -0,0 +1,159 @@
+import copy
+import torch
+import torch.nn.functional as F
+import pytest
+import torch.multiprocessing as mp
+from torch.fx import GraphModule
+from colossalai.fx import ColoTracer
+import colossalai
+from colossalai.utils import free_port
+from colossalai.core import global_context as gpc
+from colossalai.fx.graph_module import ColoGraphModule
+
+try:
+    from colossalai.fx.codegen import ActivationCheckpointCodeGen
+    with_codegen = True
+except:
+    # fall back to older pytorch version
+    from colossalai.fx.codegen import python_code_with_activation_checkpoint
+    with_codegen = False
+
+
+class MyNet(torch.nn.Module):
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.linear1 = torch.nn.Linear(4, 4)
+        self.linear2 = torch.nn.Linear(4, 4)
+        self.linear3 = torch.nn.Linear(4, 4)
+        self.linear4 = torch.nn.Linear(4, 4)
+        self.linear5 = torch.nn.Linear(4, 4)
+
+    def forward(self, x):
+        x = self.linear1(x)
+        x = self.linear2(x)
+        x = self.linear3(x)
+        x = self.linear4(x)
+        x = self.linear5(x)
+        return x
+
+
+def _is_all_gradient_close(m: torch.nn.Module, gm: GraphModule) -> bool:
+    for m_p, gm_p in zip(m.parameters(), gm.parameters()):
+        if not torch.allclose(m_p.grad, gm_p.grad):
+            return False
+    return True
+
+
+def _test_fwd_and_bwd(model: torch.nn.Module, gm: ColoGraphModule, data: torch.Tensor):
+
+    # test forward
+    non_fx_out = model(data)
+    fx_out = gm(data)
+    assert torch.equal(non_fx_out, fx_out), "fx_out doesn't comply with original output"
+
+    # test backward
+    loss0 = non_fx_out.sum()
+    loss0.backward()
+    loss1 = fx_out.sum()
+    loss1.backward()
+    assert _is_all_gradient_close(model, gm), "gm doesn't have the same gradient as original one"
+
+
+def _run_offload_codegen(rank):
+    # launch colossalai to make sure we could execute colossalai.utils.checkpoint correctly
+    colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl')
+
+    # build model and input
+    model = MyNet().cuda()
+    data = torch.rand(4, 4).cuda()
+
+    # trace the module and replace codegen
+    tracer = ColoTracer(trace_act_ckpt=True)
+    graph = tracer.trace(model)
+    codegen = ActivationCheckpointCodeGen()
+    graph.set_codegen(codegen)
+
+    # annotate the activation offload part
+    # also annotate the activation_checkpoint so we could test both types
+    # of input offload
+    for node in graph.nodes:
+        if node.name == "linear2":
+            setattr(node, "activation_offload", True)
+        if node.name == "linear3":
+            setattr(node, "activation_offload", True)
+            setattr(node, "activation_checkpoint", [0])
+        if node.name == "linear4":
+            setattr(node, "activation_checkpoint", [0])
+
+    gm = ColoGraphModule(copy.deepcopy(model), graph)
+    gm.recompile()
+    print(gm)
+
+    # assert we have all the components
+    code = graph.python_code("self").src
+    assert "def pack_hook(self, x):" in code and \
+    "def unpack_hook(self, packed):" in code and \
+    "setattr(linear1, 'offload', True)" in code and \
+    "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook, self.unpack_hook):" in code and \
+    "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, True, linear2, use_reentrant=False)" in code
+
+    _test_fwd_and_bwd(model, gm, data)
+    gpc.destroy()
+
+
+@pytest.mark.skipif(not with_codegen, reason='torch version is lower than 1.12.0')
+def test_act_ckpt_codegen():
+    mp.spawn(_run_offload_codegen, nprocs=1)
+
+
+def _run_offload_codegen_torch11(rank):
+    # launch colossalai to make sure we could execute colossalai.utils.checkpoint correctly
+    colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl')
+
+    # build model and input
+    model = MyNet().cuda()
+    data = torch.rand(4, 4).cuda()
+
+    # trace the module and replace codegen
+    tracer = ColoTracer(trace_act_ckpt=True)
+    graph = tracer.trace(model)
+
+    # replace a bound method of an object
+    graph._python_code = python_code_with_activation_checkpoint.__get__(graph)
+
+    # annotate the activation offload part
+    # also annotate the activation_checkpoint so we could test both types
+    # of input offload
+    for node in graph.nodes:
+        if node.name == "linear2":
+            setattr(node, "activation_offload", True)
+        if node.name == "linear3":
+            setattr(node, "activation_offload", True)
+            setattr(node, "activation_checkpoint", [0])
+        if node.name == "linear4":
+            setattr(node, "activation_checkpoint", [0])
+
+    gm = ColoGraphModule(copy.deepcopy(model), graph)
+    gm.recompile()
+    print(gm)
+
+    # assert we have all the components
+    code = graph.python_code("self").src
+    assert "def pack_hook(self, x):" in code and \
+    "def unpack_hook(self, packed):" in code and \
+    "setattr(linear1, 'offload', True)" in code and \
+    "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook, self.unpack_hook):" in code and \
+    "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, True, linear2, use_reentrant=False)" in code
+
+    _test_fwd_and_bwd(model, gm, data)
+    gpc.destroy()
+
+
+@pytest.mark.skip(reason="currently torch11 ColoGraphModule is not implemented")
+def test_act_ckpt_python_code_torch11():
+    mp.spawn(_run_offload_codegen_torch11, nprocs=1)
+
+
+if __name__ == "__main__":
+    _run_offload_codegen(0)
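Based on the string assertions in the test above, the forward emitted for MyNet should look roughly like the following hand-written approximation (variable names and the exact layout may differ from the real codegen output):

def forward(self, x):
    # the emitted forward imports colossalai inline so the checkpoint wrapper resolves
    import colossalai
    linear1 = self.linear1(x)
    # annotation emitted for the input of the offloaded node
    setattr(linear1, 'offload', True)
    with torch.autograd.graph.saved_tensors_hooks(self.pack_hook, self.unpack_hook):
        linear2 = self.linear2(linear1)
    # checkpoint_0 wraps the linear3/linear4 region; the True flag comes from the
    # activation_offload annotation on the checkpointed region
    linear4 = colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, True, linear2, use_reentrant=False)
    linear5 = self.linear5(linear4)
    return linear5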