From d6b01feb662c5482cacd61d48eee626252619b06 Mon Sep 17 00:00:00 2001
From: Boyuan Yao <70263930+Cypher30@users.noreply.github.com>
Date: Fri, 23 Sep 2022 11:04:52 +0800
Subject: [PATCH] [fx] Modify offload codegen (#1618)

* [fx] modify offload codegen

* [fx] remove repeated hook definitions

* [fx] modify offload test
---
 .../codegen/activation_checkpoint_codegen.py | 198 +++++++++++++++---
 .../test_codegen/test_offload_codegen.py     |  54 +++--
 2 files changed, 200 insertions(+), 52 deletions(-)

diff --git a/colossalai/fx/codegen/activation_checkpoint_codegen.py b/colossalai/fx/codegen/activation_checkpoint_codegen.py
index ad143a690..4da4315d4 100644
--- a/colossalai/fx/codegen/activation_checkpoint_codegen.py
+++ b/colossalai/fx/codegen/activation_checkpoint_codegen.py
@@ -22,8 +22,14 @@ def _gen_saved_tensors_hooks():
     Generate saved tensors hooks
     """
 
-    pack_hook = """def pack_hook(self, x):
-    if getattr(x, "offload", None):
+    pack_hook = """def pack_hook_input(self, x):
+    if getattr(x, "offload", False):
+        return (x.device, x.cpu())
+    else:
+        return x
+
+def pack_hook_no_input(self, x):
+    if getattr(x, "offload", True):
         return (x.device, x.cpu())
     else:
         return x
@@ -40,12 +46,30 @@ def _gen_saved_tensors_hooks():
     return pack_hook, unpack_hook
 
 
-def _gen_save_tensors_hooks_context():
-    """
-    Generate save tensors hooks context
+def _gen_save_tensors_hooks_context(offload_input=True) -> str:
+    """Generate a customized saved_tensors_hooks context
+
+    Args:
+        offload_input (bool, optional): whether the input needs to be offloaded;
+            if offload_input=False, self.pack_hook_no_input is used instead. Defaults to True.
+
+    Returns:
+        str: the generated context
     """
 
-    context = "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook, self.unpack_hook):\n"
+    if offload_input:
+        context = "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_input, self.unpack_hook):\n"
+    else:
+        context = "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_no_input, self.unpack_hook):\n"
+    return context
+
+
+def _gen_save_on_cpu_context():
+    """
+    Generate a save_on_cpu context
+    """
+
+    context = "with torch.autograd.graph.save_on_cpu(pin_memory=True):\n"
     return context
 
 
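For reference, a minimal standalone sketch (not part of the patch) of the saved-tensors-hooks mechanism the generated pack_hook_input/unpack_hook pair relies on; the tensor name t and its offload attribute here are illustrative:

import torch

def pack_hook(x):
    # mirrors the generated pack_hook_input: a tensor annotated for offload
    # is moved to CPU when autograd saves it for backward
    if getattr(x, "offload", False):
        return (x.device, x.cpu())
    return x

def unpack_hook(packed):
    # move an offloaded tensor back to its original device when backward runs
    if isinstance(packed, tuple):
        device, tensor = packed
        return tensor.to(device)
    return packed

t = torch.randn(4, 4, requires_grad=True)
t.offload = True    # plays the role of the generated setattr(<name>, 'offload', True)
with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
    loss = t.pow(2).sum()
loss.backward()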
@@ -118,6 +142,51 @@ def _find_ckpt_regions(nodes: List[Node]):
     return ckpt_regions
 
 
+def _find_offload_regions(nodes: List[Node]):
+    """Find the offload regions.
+
+    In the pofo algorithm, nodes of an offload region are annotated with a tuple
+    of the form (idx, offload_input, offload_bar): idx is the index of the offload
+    region, offload_input is a bool indicating whether the region's input needs to
+    be offloaded, and offload_bar is a bool indicating whether all the intermediate
+    x_bars of this region need to be offloaded.
+    """
+    offload_regions = []
+    offload_labels = []
+    start = -1
+    end = -1
+    current_region = None
+
+    for idx, node in enumerate(nodes):
+        if hasattr(node, 'activation_offload') and isinstance(getattr(node, 'activation_offload', False), tuple):
+            act_offload_label = node.activation_offload
+
+            if current_region is None:
+                current_region = act_offload_label
+                start = idx
+                offload_labels.append(act_offload_label)
+
+            if act_offload_label != current_region:
+                assert start != -1
+                offload_regions.append((start, idx - 1))
+                offload_labels.append(act_offload_label)
+                current_region = act_offload_label
+                start = idx
+                end = -1
+
+        else:
+            if current_region is not None:
+                end = idx - 1
+                assert start != -1 and end != -1
+                offload_regions.append((start, end))
+                start = end = -1
+                current_region = None
+
+            else:
+                pass
+
+    return offload_regions, offload_labels
+
+
 def _gen_ckpt_fn_def(label, free_vars: List[str]) -> str:
     """
     Generate the checkpoint function definition
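As a sanity check on the region detection above, a hypothetical walk-through with stand-in nodes (SimpleNamespace substitutes for fx.Node, and _find_offload_regions is assumed importable from colossalai.fx.codegen.activation_checkpoint_codegen); the annotations mirror the updated test below:

from types import SimpleNamespace

nodes = [
    SimpleNamespace(activation_offload=(0, True, False)),    # linear0
    SimpleNamespace(activation_offload=(0, True, False)),    # linear1
    SimpleNamespace(activation_offload=(1, True, True)),     # linear2
    SimpleNamespace(),                                       # linear3, not offloaded
    SimpleNamespace(activation_offload=(2, False, True)),    # linear4
    SimpleNamespace(),                                       # output
]
regions, labels = _find_offload_regions(nodes)
assert regions == [(0, 1), (2, 2), (4, 4)]
assert labels == [(0, True, False), (1, True, True), (2, False, True)]

Note that a region is only closed when an unannotated node follows it, which is why this toy node list ends with one.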
@@ -322,8 +391,23 @@ def emit_code_with_nested_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,
     start_idx = [item[0] for item in ckpt_regions]
     end_idx = [item[1] for item in ckpt_regions]
 
+    # find the offload regions
+    offload_regions, offload_labels = _find_offload_regions(nodes)
+    offload_starts = [item[0] for item in offload_regions]
+    offload_ends = [item[1] for item in offload_regions]
+    offload_inputs = []
+    offload_outputs = []
+    within_offload_region = False
+
     node_list = list(nodes)
 
+    # find the input and output var names for each offload region
+    for idx, (start, end) in enumerate(offload_regions):
+        offload_node_list = node_list[start:end + 1]
+        inputs, outputs = _find_input_and_output_nodes(offload_node_list)
+        offload_inputs.append(inputs)
+        offload_outputs.append(outputs)
+
     # this flag is to prevent repeated insert of save tensors
     # hooks definition in ckpt_func
     is_hook_inserted = False
@@ -343,19 +427,31 @@ def emit_code_with_nested_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,
         else:
             node = node_list[node_idx]
 
-            # if a node is outside of checkpoint region and want to offload
-            # it's input activation, we will use torch.saved_tensors_hooks
-            # to complete the offload process.
-            if getattr(node, "activation_offload", False):
+            if node_idx in offload_starts:
+                offload_label = offload_labels[offload_starts.index(node_idx)]
+                _, offload_input, offload_bar = offload_label
+                within_offload_region = True
+
+                # insert hook functions if needed
                 if not is_hook_inserted:
                     pack_hook, unpack_hook = _gen_saved_tensors_hooks()
                     ckpt_func.insert(0, "\n".join([pack_hook, unpack_hook]) + "\n")
+                    is_hook_inserted = True
 
-                for par in node.all_input_nodes:
-                    # annotate the input tensor for pack hook
-                    body.append(f"setattr({repr(par)}, 'offload', True)\n")
+                if offload_input and offload_bar:
+                    body.append(_gen_save_on_cpu_context())
 
-                body.append(_gen_save_tensors_hooks_context())
+                elif offload_input:
+                    for par in offload_inputs[offload_label[0]]:
+                        body.append(f"setattr({par}, 'offload', True)\n")
+                    body.append(_gen_save_tensors_hooks_context(offload_input=True))
+
+                else:
+                    for par in offload_inputs[offload_label[0]]:
+                        body.append(f"setattr({par}, 'offload', False)\n")
+                    body.append(_gen_save_tensors_hooks_context(offload_input=False))
+
+            if within_offload_region:
                 emit_node_func(node, body)
                 body[-1] = '    ' + body[-1]
                 delete_unused_value_func(node, body)
@@ -363,6 +459,10 @@ def emit_code_with_nested_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,
             else:
                 emit_node_func(node, body)
                 delete_unused_value_func(node, body)
+
+            if node_idx in offload_ends:
+                within_offload_region = False
+
             node_idx += 1
 
 
@@ -375,6 +475,14 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,
     output_vars = []
     within_ckpt_region = False
 
+    # find the offload regions
+    offload_regions, offload_labels = _find_offload_regions(nodes)
+    offload_starts = [item[0] for item in offload_regions]
+    offload_ends = [item[1] for item in offload_regions]
+    offload_inputs = []
+    offload_outputs = []
+    within_offload_region = False
+
     node_list = list(nodes)
 
     # use this variable to avoid inserting hook functions
@@ -388,6 +496,13 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,
         input_vars.append(inputs)
         output_vars.append(outputs)
 
+    # find the input and output var names for each offload region
+    for idx, (start, end) in enumerate(offload_regions):
+        offload_node_list = node_list[start:end + 1]
+        inputs, outputs = _find_input_and_output_nodes(offload_node_list)
+        offload_inputs.append(inputs)
+        offload_outputs.append(outputs)
+
     # append code text to body
     for idx, node in enumerate(node_list):
         # if this is the first node of the ckpt region
@@ -398,6 +513,30 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,
             ckpt_func.append(f'{ckpt_fn_def}\n')
             within_ckpt_region = True
 
+        if idx in offload_starts:
+            offload_label = offload_labels[offload_starts.index(idx)]
+            _, offload_input, offload_bar = offload_label
+            within_offload_region = True
+
+            # insert hook functions if needed
+            if not is_hook_inserted:
+                pack_hook, unpack_hook = _gen_saved_tensors_hooks()
+                ckpt_func.insert(0, "\n".join([pack_hook, unpack_hook]) + "\n")
+                is_hook_inserted = True
+
+            if offload_input and offload_bar:
+                body.append(_gen_save_on_cpu_context())
+
+            elif offload_input:
+                for par in offload_inputs[offload_label[0]]:
+                    body.append(f"setattr({par}, 'offload', True)\n")
+                body.append(_gen_save_tensors_hooks_context(offload_input=True))
+
+            else:
+                for par in offload_inputs[offload_label[0]]:
+                    body.append(f"setattr({par}, 'offload', False)\n")
+                body.append(_gen_save_tensors_hooks_context(offload_input=False))
+
         # NOTE: emit_node does not emit a string with newline. It depends
         # on delete_unused_values to append one
         # NOTE: currently we separate body and ckpt_func definition
@@ -405,27 +544,15 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,
             emit_node_func(node, ckpt_func)
             ckpt_func[-1] = '    ' + ckpt_func[-1]
             delete_unused_value_func(node, ckpt_func)
+
+        elif within_offload_region:
+            emit_node_func(node, body)
+            body[-1] = '    ' + body[-1]
+            delete_unused_value_func(node, body)
+
         else:
-            # if a node is outside of checkpoint region wants to offload
-            # it's input activation, we will use torch.saved_tensors_hooks
-            # to complete the offload process.
-            if getattr(node, "activation_offload", False):
-                if not is_hook_inserted:
-                    pack_hook, unpack_hook = _gen_saved_tensors_hooks()
-                    ckpt_func.insert(0, "\n".join([pack_hook, unpack_hook]) + "\n")
-
-                for par in node.all_input_nodes:
-                    # annotate the input tensor for pack hook
-                    body.append(f"setattr({repr(par)}, 'offload', True)\n")
-
-                body.append(_gen_save_tensors_hooks_context())
-                emit_node_func(node, body)
-                body[-1] = '    ' + body[-1]
-                delete_unused_value_func(node, body)
-
-            else:
-                emit_node_func(node, body)
-                delete_unused_value_func(node, body)
+            emit_node_func(node, body)
+            delete_unused_value_func(node, body)
 
         if idx in end_idx:
             # if this is the last node of the ckpt region
@@ -470,6 +597,9 @@ def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func,
             body.append(usage)
             within_ckpt_region = False
 
+        if idx in offload_ends:
+            within_offload_region = False
+
 
 if CODEGEN_AVAILABLE:
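Taken together, for the annotations used in the updated test below, the emitted forward has roughly the following shape (an illustrative sketch based on the substrings the test asserts, not verbatim codegen output):

def forward(self, x):
    # region (0, True, False): only the region's input is offloaded
    setattr(x, 'offload', True)
    with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_input, self.unpack_hook):
        linear0 = self.linear0(x)
        linear1 = self.linear1(linear0)
    # region (1, True, True): the input and all intermediate x_bars go to CPU
    with torch.autograd.graph.save_on_cpu(pin_memory=True):
        linear2 = self.linear2(linear1)
    linear3 = self.linear3(linear2)
    # region (2, False, True): intermediates are offloaded, the input stays on device
    setattr(linear3, 'offload', False)
    with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_no_input, self.unpack_hook):
        linear4 = self.linear4(linear3)
    # linear5 sits in checkpoint region 0; linear6 follows as plain code
    ...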
diff --git a/tests/test_fx/test_codegen/test_offload_codegen.py b/tests/test_fx/test_codegen/test_offload_codegen.py
index a9d1455b8..271d86a80 100644
--- a/tests/test_fx/test_codegen/test_offload_codegen.py
+++ b/tests/test_fx/test_codegen/test_offload_codegen.py
@@ -23,18 +23,22 @@ class MyNet(torch.nn.Module):
 
     def __init__(self) -> None:
         super().__init__()
+        self.linear0 = torch.nn.Linear(4, 4)
         self.linear1 = torch.nn.Linear(4, 4)
         self.linear2 = torch.nn.Linear(4, 4)
         self.linear3 = torch.nn.Linear(4, 4)
         self.linear4 = torch.nn.Linear(4, 4)
         self.linear5 = torch.nn.Linear(4, 4)
+        self.linear6 = torch.nn.Linear(4, 4)
 
     def forward(self, x):
+        x = self.linear0(x)
         x = self.linear1(x)
         x = self.linear2(x)
         x = self.linear3(x)
         x = self.linear4(x)
         x = self.linear5(x)
+        x = self.linear6(x)
         return x
 
 
@@ -78,25 +82,32 @@ def _run_offload_codegen(rank):
     # also annotate the activation_checkpoint so we could test both types
     # of input offload
     for node in graph.nodes:
+        if node.name == "linear0":
+            setattr(node, "activation_offload", (0, True, False))
+        if node.name == "linear1":
+            setattr(node, "activation_offload", (0, True, False))
         if node.name == "linear2":
-            setattr(node, "activation_offload", True)
-        if node.name == "linear3":
-            setattr(node, "activation_offload", True)
-            setattr(node, "activation_checkpoint", [0])
+            setattr(node, "activation_offload", (1, True, True))
         if node.name == "linear4":
+            setattr(node, "activation_offload", (2, False, True))
+        if node.name == "linear5":
             setattr(node, "activation_checkpoint", [0])
+            setattr(node, "activation_offload", True)
 
     gm = ColoGraphModule(copy.deepcopy(model), graph)
     gm.recompile()
-    print(gm)
 
     # assert we have all the components
     code = graph.python_code("self").src
-    assert "def pack_hook(self, x):" in code and \
+    assert "def pack_hook_input(self, x):" in code and \
     "def unpack_hook(self, packed):" in code and \
-    "setattr(linear1, 'offload', True)" in code and \
-    "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook, self.unpack_hook):" in code and \
-    "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, True, linear2, use_reentrant=False)" in code
+    "def pack_hook_no_input(self, x):" in code and \
+    "setattr(x, 'offload', True)" in code and \
+    "setattr(linear3, 'offload', False)" in code and \
+    "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_input, self.unpack_hook):" in code and \
+    "with torch.autograd.graph.save_on_cpu(pin_memory=True):" in code and \
+    "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_no_input, self.unpack_hook):" in code and \
+    "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, True, linear4, use_reentrant=False)" in code
 
     _test_fwd_and_bwd(model, gm, data)
     gpc.destroy()
@@ -126,25 +137,32 @@ def _run_offload_codegen_torch11(rank):
     # also annotate the activation_checkpoint so we could test both types
     # of input offload
     for node in graph.nodes:
+        if node.name == "linear0":
+            setattr(node, "activation_offload", (0, True, False))
+        if node.name == "linear1":
+            setattr(node, "activation_offload", (0, True, False))
         if node.name == "linear2":
-            setattr(node, "activation_offload", True)
-        if node.name == "linear3":
-            setattr(node, "activation_offload", True)
-            setattr(node, "activation_checkpoint", [0])
+            setattr(node, "activation_offload", (1, True, True))
         if node.name == "linear4":
+            setattr(node, "activation_offload", (2, False, True))
+        if node.name == "linear5":
             setattr(node, "activation_checkpoint", [0])
+            setattr(node, "activation_offload", True)
 
     gm = ColoGraphModule(copy.deepcopy(model), graph)
     gm.recompile()
-    print(gm)
 
     # assert we have all the components
    code = graph.python_code("self").src
-    assert "def pack_hook(self, x):" in code and \
+    assert "def pack_hook_input(self, x):" in code and \
     "def unpack_hook(self, packed):" in code and \
-    "setattr(linear1, 'offload', True)" in code and \
-    "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook, self.unpack_hook):" in code and \
-    "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, True, linear2, use_reentrant=False)" in code
+    "def pack_hook_no_input(self, x):" in code and \
+    "setattr(x, 'offload', True)" in code and \
+    "setattr(linear3, 'offload', False)" in code and \
+    "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_input, self.unpack_hook):" in code and \
+    "with torch.autograd.graph.save_on_cpu(pin_memory=True):" in code and \
+    "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_no_input, self.unpack_hook):" in code and \
+    "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, True, linear4, use_reentrant=False)" in code
 
     _test_fwd_and_bwd(model, gm, data)
     gpc.destroy()
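For quick reference, a summary of how each pofo annotation exercised by the tests above maps to the context the codegen emits (derived from the branches added in this patch):

# (idx, offload_input=True,  offload_bar=True)  -> save_on_cpu(pin_memory=True):
#     every tensor saved inside the region (input and x_bars) is offloaded
# (idx, offload_input=True,  offload_bar=False) -> saved_tensors_hooks(self.pack_hook_input, self.unpack_hook):
#     inputs are tagged setattr(<input>, 'offload', True), so only they are offloaded
# (idx, offload_input=False, offload_bar=True)  -> saved_tensors_hooks(self.pack_hook_no_input, self.unpack_hook):
#     inputs are tagged setattr(<input>, 'offload', False), so everything except them is offloaded
# activation_offload=True (a bool) on a node inside a checkpoint region is handled by the
# checkpoint codegen path instead; it surfaces as the True argument asserted above in
# colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, True, linear4, use_reentrant=False)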