From d6b01feb662c5482cacd61d48eee626252619b06 Mon Sep 17 00:00:00 2001
From: Boyuan Yao <70263930+Cypher30@users.noreply.github.com>
Date: Fri, 23 Sep 2022 11:04:52 +0800
Subject: [PATCH] [fx] Modify offload codegen (#1618)

* [fx] modify offload codegen

* [fx] remove repeated hook definitions

* [fx] modify offload test
---
 .../codegen/activation_checkpoint_codegen.py  | 198 +++++++++++++++---
 .../test_codegen/test_offload_codegen.py      |  54 +++--
 2 files changed, 200 insertions(+), 52 deletions(-)

diff --git a/colossalai/fx/codegen/activation_checkpoint_codegen.py b/colossalai/fx/codegen/activation_checkpoint_codegen.py
index ad143a690..4da4315d4 100644
--- a/colossalai/fx/codegen/activation_checkpoint_codegen.py
+++ b/colossalai/fx/codegen/activation_checkpoint_codegen.py
@@ -22,8 +22,14 @@ def _gen_saved_tensors_hooks():
     Generate saved tensors hooks
     """
 
-    pack_hook = """def pack_hook(self, x):
-    if getattr(x, "offload", None):
+    pack_hook = """def pack_hook_input(self, x):
+    if getattr(x, "offload", False):
+        return (x.device, x.cpu())
+    else:
+        return x
+ 
+def pack_hook_no_input(self, x):
+    if getattr(x, "offload", True):
         return (x.device, x.cpu())
     else:
         return x
@@ -40,12 +46,30 @@ def _gen_saved_tensors_hooks():
     return pack_hook, unpack_hook
 
 
-def _gen_save_tensors_hooks_context():
-    """
-    Generate save tensors hooks context
+def _gen_save_tensors_hooks_context(offload_input=True) -> str:
+    """Generate customized saved_tensors_hooks
+
+    Args:
+        offload_input (bool, optional): whether we need to offload the input; if offload_input=False,
+        we will use self.pack_hook_no_input instead. Defaults to True.
+
+    Returns:
+        str: generated context
     """
 
-    context = "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook, self.unpack_hook):\n"
+    if offload_input:
+        context = "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_input, self.unpack_hook):\n"
+    else:
+        context = "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_no_input, self.unpack_hook):\n"
+    return context
+
+
+def _gen_save_on_cpu_context():
+    """
+    Generate save on cpu context
+    """
+
+    context = "with torch.autograd.graph.save_on_cpu(pin_memory=True):\n"
     return context
 
 
@@ -118,6 +142,51 @@ def _find_ckpt_regions(nodes: List[Node]):
     return ckpt_regions
 
 
+def _find_offload_regions(nodes: List[Node]):
+    """This function is to find the offload regions.
+    In the pofo algorithm, during annotation, we annotate each offload region with a
+    tuple of the form (idx, offload_input, offload_bar). idx indicates the offload
+    region's index, offload_input is a bool indicating whether we need to offload
+    the input, and offload_bar is a bool indicating whether we need to offload all the
+    intermediate x_bars of this region.
+    """
+    offload_regions = []
+    offload_labels = []
+    start = -1
+    end = -1
+    current_region = None
+
+    for idx, node in enumerate(nodes):
+        if hasattr(node, 'activation_offload') and isinstance(getattr(node, 'activation_offload', False), tuple):
+            act_offload_label = node.activation_offload
+
+            if current_region == None:
+                current_region = act_offload_label
+                start = idx
+                offload_labels.append(act_offload_label)
+
+            if act_offload_label != current_region:
+                assert start != -1
+                offload_regions.append((start, idx - 1))
+                offload_labels.append(act_offload_label)
+                current_region = act_offload_label
+                start = idx
+                end = -1
+
+        else:
+            if current_region is not None:
+                end = idx - 1
+                assert start != -1 and end != -1
+                offload_regions.append((start, end))
+                start = end = -1
+                current_region = None
+
+            else:
+                pass
+
+    return offload_regions, offload_labels
+
+
 def _gen_ckpt_fn_def(label, free_vars: List[str]) -> str:
     """
     Generate the checkpoint function definition
@@ -322,8 +391,23 @@ def emit_code_with_nested_activation_checkpoint(body, ckpt_func, nodes, emit_nod
     start_idx = [item[0] for item in ckpt_regions]
     end_idx = [item[1] for item in ckpt_regions]
 
+    # find the offload regions
+    offload_regions, offload_labels = _find_offload_regions(nodes)
+    offload_starts = [item[0] for item in offload_regions]
+    offload_ends = [item[1] for item in offload_regions]
+    offload_inputs = []
+    offload_outputs = []
+    within_offload_region = False
+
     node_list = list(nodes)
 
+    # find the input and output var names for each offload region
+    for idx, (start, end) in enumerate(offload_regions):
+        offload_node_list = node_list[start:end + 1]
+        inputs, outputs = _find_input_and_output_nodes(offload_node_list)
+        offload_inputs.append(inputs)
+        offload_outputs.append(outputs)
+
     # this flag is to prevent repeated insert of save tensors
     # hooks definition in ckpt_func
     is_hook_inserted = False
@@ -343,19 +427,31 @@ def emit_code_with_nested_activation_checkpoint(body, ckpt_func, nodes, emit_nod
         else:
             node = node_list[node_idx]
 
-            # if a node is outside of checkpoint region and want to offload
-            # it's input activation, we will use torch.saved_tensors_hooks
-            # to complete the offload process.
-            if getattr(node, "activation_offload", False):
+            if node_idx in offload_starts:
+                offload_label = offload_labels[offload_starts.index(node_idx)]
+                _, offload_input, offload_bar = offload_label
+                within_offload_region = True
+
+                # insert hook functions if needed
                 if not is_hook_inserted:
                     pack_hook, unpack_hook = _gen_saved_tensors_hooks()
                     ckpt_func.insert(0, "\n".join([pack_hook, unpack_hook]) + "\n")
+                    is_hook_inserted = True
 
-                for par in node.all_input_nodes:
-                    # annotate the input tensor for pack hook
-                    body.append(f"setattr({repr(par)}, 'offload', True)\n")
+                if offload_input and offload_bar:
+                    body.append(_gen_save_on_cpu_context())
 
-                body.append(_gen_save_tensors_hooks_context())
+                elif offload_input:
+                    for par in offload_inputs[offload_label[0]]:
+                        body.append(f"setattr({par}, 'offload', True)\n")
+                    body.append(_gen_save_tensors_hooks_context(offload_input=True))
+
+                else:
+                    for par in offload_inputs[offload_label[0]]:
+                        body.append(f"setattr({par}, 'offload', False)\n")
+                    body.append(_gen_save_tensors_hooks_context(offload_input=False))
+
+            if within_offload_region:
                 emit_node_func(node, body)
                 body[-1] = '    ' + body[-1]
                 delete_unused_value_func(node, body)
@@ -363,6 +459,10 @@ def emit_code_with_nested_activation_checkpoint(body, ckpt_func, nodes, emit_nod
             else:
                 emit_node_func(node, body)
                 delete_unused_value_func(node, body)
+
+            if node_idx in offload_ends:
+                within_offload_region = False
+
             node_idx += 1
 
 
@@ -375,6 +475,14 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,
     output_vars = []
     within_ckpt_region = False
 
+    # find the offload regions
+    offload_regions, offload_labels = _find_offload_regions(nodes)
+    offload_starts = [item[0] for item in offload_regions]
+    offload_ends = [item[1] for item in offload_regions]
+    offload_inputs = []
+    offload_outputs = []
+    within_offload_region = False
+
     node_list = list(nodes)
 
     # use this variable to avoid inserting hook functions
@@ -388,6 +496,13 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,
         input_vars.append(inputs)
         output_vars.append(outputs)
 
+    # find the input and output var names for each offload region
+    for idx, (start, end) in enumerate(offload_regions):
+        offload_node_list = node_list[start:end + 1]
+        inputs, outputs = _find_input_and_output_nodes(offload_node_list)
+        offload_inputs.append(inputs)
+        offload_outputs.append(outputs)
+
     # append code text to body
     for idx, node in enumerate(node_list):
         # if this is the first node of the ckpt region
@@ -398,6 +513,30 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,
             ckpt_func.append(f'{ckpt_fn_def}\n')
             within_ckpt_region = True
 
+        if idx in offload_starts:
+            offload_label = offload_labels[offload_starts.index(idx)]
+            _, offload_input, offload_bar = offload_label
+            within_offload_region = True
+
+            # insert hook functions if needed
+            if not is_hook_inserted:
+                pack_hook, unpack_hook = _gen_saved_tensors_hooks()
+                ckpt_func.insert(0, "\n".join([pack_hook, unpack_hook]) + "\n")
+                is_hook_inserted = True
+
+            if offload_input and offload_bar:
+                body.append(_gen_save_on_cpu_context())
+
+            elif offload_input:
+                for par in offload_inputs[offload_label[0]]:
+                    body.append(f"setattr({par}, 'offload', True)\n")
+                body.append(_gen_save_tensors_hooks_context(offload_input=True))
+
+            else:
+                for par in offload_inputs[offload_label[0]]:
+                    body.append(f"setattr({par}, 'offload', False)\n")
+                body.append(_gen_save_tensors_hooks_context(offload_input=False))
+
         # NOTE: emit_node does not emit a string with newline. It depends
         # on delete_unused_values to append one
         # NOTE: currently we separate body and ckpt_func definition
@@ -405,27 +544,15 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,
             emit_node_func(node, ckpt_func)
             ckpt_func[-1] = '    ' + ckpt_func[-1]
             delete_unused_value_func(node, ckpt_func)
+
+        elif within_offload_region:
+            emit_node_func(node, body)
+            body[-1] = '    ' + body[-1]
+            delete_unused_value_func(node, body)
+
         else:
-            # if a node is outside of checkpoint region wants to offload
-            # it's input activation, we will use torch.saved_tensors_hooks
-            # to complete the offload process.
-            if getattr(node, "activation_offload", False):
-                if not is_hook_inserted:
-                    pack_hook, unpack_hook = _gen_saved_tensors_hooks()
-                    ckpt_func.insert(0, "\n".join([pack_hook, unpack_hook]) + "\n")
-
-                for par in node.all_input_nodes:
-                    # annotate the input tensor for pack hook
-                    body.append(f"setattr({repr(par)}, 'offload', True)\n")
-
-                body.append(_gen_save_tensors_hooks_context())
-                emit_node_func(node, body)
-                body[-1] = '    ' + body[-1]
-                delete_unused_value_func(node, body)
-
-            else:
-                emit_node_func(node, body)
-                delete_unused_value_func(node, body)
+            emit_node_func(node, body)
+            delete_unused_value_func(node, body)
 
         if idx in end_idx:
             # if this is the last node of the ckpt region
@@ -470,6 +597,9 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,
             body.append(usage)
             within_ckpt_region = False
 
+        if idx in offload_ends:
+            within_offload_region = False
+
 
 if CODEGEN_AVAILABLE:
 
diff --git a/tests/test_fx/test_codegen/test_offload_codegen.py b/tests/test_fx/test_codegen/test_offload_codegen.py
index a9d1455b8..271d86a80 100644
--- a/tests/test_fx/test_codegen/test_offload_codegen.py
+++ b/tests/test_fx/test_codegen/test_offload_codegen.py
@@ -23,18 +23,22 @@ class MyNet(torch.nn.Module):
 
     def __init__(self) -> None:
         super().__init__()
+        self.linear0 = torch.nn.Linear(4, 4)
         self.linear1 = torch.nn.Linear(4, 4)
         self.linear2 = torch.nn.Linear(4, 4)
         self.linear3 = torch.nn.Linear(4, 4)
         self.linear4 = torch.nn.Linear(4, 4)
         self.linear5 = torch.nn.Linear(4, 4)
+        self.linear6 = torch.nn.Linear(4, 4)
 
     def forward(self, x):
+        x = self.linear0(x)
         x = self.linear1(x)
         x = self.linear2(x)
         x = self.linear3(x)
         x = self.linear4(x)
         x = self.linear5(x)
+        x = self.linear6(x)
         return x
 
 
@@ -78,25 +82,32 @@ def _run_offload_codegen(rank):
     # also annotate the activation_checkpoint so we could test both types
     # of input offload
     for node in graph.nodes:
+        if node.name == "linear0":
+            setattr(node, "activation_offload", (0, True, False))
+        if node.name == "linear1":
+            setattr(node, "activation_offload", (0, True, False))
         if node.name == "linear2":
-            setattr(node, "activation_offload", True)
-        if node.name == "linear3":
-            setattr(node, "activation_offload", True)
-            setattr(node, "activation_checkpoint", [0])
+            setattr(node, "activation_offload", (1, True, True))
         if node.name == "linear4":
+            setattr(node, "activation_offload", (2, False, True))
+        if node.name == "linear5":
             setattr(node, "activation_checkpoint", [0])
+            setattr(node, "activation_offload", True)
 
     gm = ColoGraphModule(copy.deepcopy(model), graph)
     gm.recompile()
-    print(gm)
 
     # assert we have all the components
     code = graph.python_code("self").src
-    assert "def pack_hook(self, x):" in code and \
+    assert "def pack_hook_input(self, x):" in code and \
     "def unpack_hook(self, packed):" in code and \
-    "setattr(linear1, 'offload', True)" in code and \
-    "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook, self.unpack_hook):" in code and \
-    "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, True, linear2, use_reentrant=False)" in code
+    "def pack_hook_no_input(self, x):" in code and \
+    "setattr(x, 'offload', True)" in code and \
+    "setattr(linear3, 'offload', False)" in code and \
+    "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_input, self.unpack_hook):" in code and \
+    "with torch.autograd.graph.save_on_cpu(pin_memory=True):" in code and \
+    "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_no_input, self.unpack_hook):" in code and \
+    "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, True, linear4, use_reentrant=False)" in code
 
     _test_fwd_and_bwd(model, gm, data)
     gpc.destroy()
@@ -126,25 +137,32 @@ def _run_offload_codegen_torch11(rank):
     # also annotate the activation_checkpoint so we could test both types
     # of input offload
     for node in graph.nodes:
+        if node.name == "linear0":
+            setattr(node, "activation_offload", (0, True, False))
+        if node.name == "linear1":
+            setattr(node, "activation_offload", (0, True, False))
         if node.name == "linear2":
-            setattr(node, "activation_offload", True)
-        if node.name == "linear3":
-            setattr(node, "activation_offload", True)
-            setattr(node, "activation_checkpoint", [0])
+            setattr(node, "activation_offload", (1, True, True))
         if node.name == "linear4":
+            setattr(node, "activation_offload", (2, False, True))
+        if node.name == "linear5":
             setattr(node, "activation_checkpoint", [0])
+            setattr(node, "activation_offload", True)
 
     gm = ColoGraphModule(copy.deepcopy(model), graph)
     gm.recompile()
-    print(gm)
 
     # assert we have all the components
     code = graph.python_code("self").src
-    assert "def pack_hook(self, x):" in code and \
+    assert "def pack_hook_input(self, x):" in code and \
     "def unpack_hook(self, packed):" in code and \
-    "setattr(linear1, 'offload', True)" in code and \
-    "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook, self.unpack_hook):" in code and \
-    "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, True, linear2, use_reentrant=False)" in code
+    "def pack_hook_no_input(self, x):" in code and \
+    "setattr(x, 'offload', True)" in code and \
+    "setattr(linear3, 'offload', False)" in code and \
+    "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_input, self.unpack_hook):" in code and \
+    "with torch.autograd.graph.save_on_cpu(pin_memory=True):" in code and \
+    "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_no_input, self.unpack_hook):" in code and \
+    "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, True, linear4, use_reentrant=False)" in code
 
     _test_fwd_and_bwd(model, gm, data)
     gpc.destroy()