diff --git a/chunk_codegen.py b/chunk_codegen.py
new file mode 100644
index 000000000..684028c01
--- /dev/null
+++ b/chunk_codegen.py
@@ -0,0 +1,1047 @@
+import colossalai
+import torch
+from typing import List, Callable, Any, Tuple, Dict, Iterable
+
+try:
+    from torch.fx.node import Node, Argument, map_arg, _type_repr, _get_qualified_name
+    from torch.fx.graph import _Namespace, PythonCode, _custom_builtins, _is_from_torch, _format_target, magic_methods, CodeGen, _origin_type_map, inplace_methods, _CustomBuiltin
+    CODEGEN_AVAILABLE = True
+except ImportError:
+    # fall back to the pre-1.12 torch.fx API, which has no CodeGen class
+    from torch.fx.graph import _Namespace, PythonCode, _custom_builtins, _is_from_torch, _format_target, magic_methods, _origin_type_map, _format_args, _CustomBuiltin
+    from torch.fx.node import Node, Argument, map_arg, _type_repr, _get_qualified_name
+    CODEGEN_AVAILABLE = False
+
+if CODEGEN_AVAILABLE:
+    __all__ = ['ActivationCheckpointCodeGen']
+else:
+    __all__ = ['python_code_with_activation_checkpoint']
+
+
+def _gen_saved_tensors_hooks():
+    """
+    Generate the source code of the saved-tensors pack/unpack hooks.
+    """
+
+    pack_hook = """def pack_hook_input(self, x):
+    if getattr(x, "offload", False):
+        return (x.device, x.cpu())
+    else:
+        return x
+
+def pack_hook_no_input(self, x):
+    if getattr(x, "offload", True):
+        return (x.device, x.cpu())
+    else:
+        return x
+"""
+
+    unpack_hook = """def unpack_hook(self, packed):
+    if isinstance(packed, tuple):
+        device, tensor = packed
+        return tensor.to(device)
+    else:
+        return packed
+"""
+
+    return pack_hook, unpack_hook
+
+
+def _gen_save_tensors_hooks_context(offload_input=True) -> str:
+    """Generate a customized saved_tensors_hooks context manager.
+
+    Args:
+        offload_input (bool, optional): whether the input should be offloaded;
+            if offload_input=False, self.pack_hook_no_input is used instead.
+            Defaults to True.
+
+    Returns:
+        str: generated context
+    """
+
+    if offload_input:
+        context = "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_input, self.unpack_hook):\n"
+    else:
+        context = "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_no_input, self.unpack_hook):\n"
+    return context
+
+
+def _gen_save_on_cpu_context():
+    """
+    Generate a save_on_cpu context manager.
+    """
+
+    context = "with torch.autograd.graph.save_on_cpu(pin_memory=True):\n"
+    return context
+
+
+def _find_input_and_output_nodes(nodes: List[Node]):
+    """
+    Find the input and output node names which are not found in the given list of nodes.
+    """
+    input_nodes = []
+    output_nodes = []
+
+    # if a node has an input node which is not in the node list
+    # we treat that input node as the input of the checkpoint function
+    for node in nodes:
+        for input_node in node._input_nodes.keys():
+            node_repr = repr(input_node)
+            if input_node not in nodes and node_repr not in input_nodes:
+                input_nodes.append(node_repr)
+
+    # if a node has a user node which is not in the node list
+    # we treat that user node as the node receiving the current node output
+    for node in nodes:
+        for output_node in node.users.keys():
+            node_repr = repr(node)
+            if output_node not in nodes and node_repr not in output_nodes:
+                output_nodes.append(node_repr)
+
+    return input_nodes, output_nodes
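+
+
+# An illustrative note (node names hypothetical): for a traced chain
+# x -> linear1 -> linear2 -> linear3, calling
+#
+#     _find_input_and_output_nodes([linear1_node, linear2_node])
+#
+# returns (['x'], ['linear2']): 'x' is produced outside the region but consumed
+# inside it, and 'linear2' is produced inside the region and read by linear3
+# outside it, so they become the arguments and return values of the generated
+# checkpoint function.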
+
+
+def _find_ckpt_regions(nodes: List[Node]):
+    """
+    Find the checkpoint regions given a list of consecutive nodes. The output will be a list
+    of tuples, each tuple is in the form of (start_index, end_index).
+    """
+    ckpt_regions = []
+    start = -1
+    end = -1
+    current_region = None
+
+    for idx, node in enumerate(nodes):
+        if hasattr(node, 'activation_checkpoint'):
+            act_ckpt_label = node.activation_checkpoint
+
+            # this activation checkpoint label is not set yet
+            # meaning this is the first node of the activation ckpt region
+            if current_region is None:
+                current_region = act_ckpt_label
+                start = idx
+
+            # if the activation checkpoint label has changed
+            # we restart the tracking
+            # e.g. node ckpt states = [ckpt1, ckpt2, ckpt2, ckpt2]
+            if act_ckpt_label != current_region:
+                assert start != -1
+                ckpt_regions.append((start, idx - 1))
+                current_region = act_ckpt_label
+                start = idx
+                end = -1
+        elif current_region is not None and not hasattr(node, 'activation_checkpoint'):
+            # used to check the case below
+            # node ckpt states = [ckpt, ckpt, non-ckpt]
+            end = idx - 1
+            assert start != -1 and end != -1
+            ckpt_regions.append((start, end))
+            start = end = -1
+            current_region = None
+    return ckpt_regions
+
+
+def _find_offload_regions(nodes: List[Node]):
+    """Find the offload regions.
+
+    In the pofo algorithm, an offload region is annotated with a list of the form
+    [idx, offload_input, offload_bar]: idx is the offload region's index,
+    offload_input is a bool indicating whether the input should be offloaded, and
+    offload_bar is a bool indicating whether all the intermediate x_bars of this
+    region should be offloaded.
+    """
+    offload_regions = []
+    offload_labels = []
+    start = -1
+    end = -1
+    current_region = None
+
+    for idx, node in enumerate(nodes):
+        if hasattr(node, 'activation_offload') and isinstance(getattr(node, 'activation_offload', None), Iterable):
+            act_offload_label = node.activation_offload
+
+            if current_region is None:
+                current_region = act_offload_label
+                start = idx
+                offload_labels.append(act_offload_label)
+
+            if act_offload_label != current_region:
+                assert start != -1
+                offload_regions.append((start, idx - 1))
+                offload_labels.append(act_offload_label)
+                current_region = act_offload_label
+                start = idx
+                end = -1
+
+        else:
+            if current_region is not None:
+                end = idx - 1
+                assert start != -1 and end != -1
+                offload_regions.append((start, end))
+                start = end = -1
+                current_region = None
+
+    return offload_regions, offload_labels
+
+
+def _gen_ckpt_fn_def(label, free_vars: List[str]) -> str:
+    """
+    Generate the checkpoint function definition
+    """
+    return f"def checkpoint_{label}({', '.join(['self'] + free_vars)}):"
+
+
+def _gen_ckpt_output(output_vars: List[str]) -> str:
+    """
+    Generate the return statement for the checkpoint region
+    """
+    return f"return {', '.join(output_vars)}"
+
+
+def _gen_ckpt_usage(label, activation_offload, input_vars, output_vars, use_reentrant=True):
+    """
+    Generate the checkpoint function call code text
+    """
+    outputs = ', '.join(output_vars)
+    inputs = ', '.join(input_vars)
+    return f'{outputs} = colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_{label}, {activation_offload}, {inputs}, use_reentrant={use_reentrant})'
+
+
+def _end_of_ckpt(node: Node, check_idx: int) -> bool:
+    """Check whether the node could end the ckpt region.
+
+    Args:
+        node (Node): torch.fx.Node
+        check_idx (int): the index of the checkpoint level, for
+            nested checkpointing
+
+    Returns:
+        bool
+    """
+    if hasattr(node, "activation_checkpoint"):
+        if isinstance(node.activation_checkpoint, list):
+            return node.activation_checkpoint[check_idx] is None
+        else:
+            return False
+    else:
+        return True
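+
+
+# A note on nested labels (values hypothetical): a node inside nested checkpoint
+# regions carries a list label such as [0, 1], while a node that belongs only to
+# the outer region is padded with None, e.g. [0, None]. _end_of_ckpt(node, 1) is
+# therefore True for a [0, None] node: at nesting level 1 such a node closes the
+# inner region, while at level 0 it keeps region 0 open.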
+
+
+def _find_nested_ckpt_regions(nodes, check_idx=0):
+    """
+    Find the nested checkpoint regions given a list of consecutive nodes. The output
+    will be a list of tuples, each tuple is in the form of (start_index, end_index).
+    """
+    ckpt_regions = []
+    start = -1
+    end = -1
+    current_region = None
+
+    for idx, node in enumerate(nodes):
+        if hasattr(node, 'activation_checkpoint'):
+            if isinstance(getattr(node, 'activation_checkpoint'), int):
+                act_ckpt_label = node.activation_checkpoint
+            else:
+                act_ckpt_label = node.activation_checkpoint[check_idx]
+
+            # this activation checkpoint label is not set yet
+            # meaning this is the first node of the activation ckpt region
+            if current_region is None:
+                current_region = act_ckpt_label
+                start = idx
+
+            # if the activation checkpoint label has changed
+            # we restart the tracking
+            # e.g. node ckpt states = [ckpt1, ckpt2, ckpt2, ckpt2]
+            if act_ckpt_label != current_region:
+                assert start != -1
+                ckpt_regions.append((start, idx - 1))
+                current_region = act_ckpt_label
+                start = idx
+                end = -1
+        elif current_region is not None and _end_of_ckpt(node, check_idx):
+            # used to check the case below
+            # node ckpt states = [ckpt, ckpt, non-ckpt]
+            end = idx - 1
+            assert start != -1 and end != -1
+            ckpt_regions.append((start, end))
+            start = end = -1
+            current_region = None
+
+    if current_region is not None:
+        end = len(nodes) - 1
+        ckpt_regions.append((start, end))
+    return ckpt_regions
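+
+
+# Minimal sketch of the region detection on stand-in nodes (labels hypothetical;
+# any object with an ``activation_checkpoint`` attribute works):
+#
+#     from types import SimpleNamespace
+#     ns = SimpleNamespace
+#     nodes = [ns(activation_checkpoint=[0, None]),
+#              ns(activation_checkpoint=[0, None]),
+#              ns(activation_checkpoint=[0, 1]),
+#              ns(activation_checkpoint=[0, 1]),
+#              ns()]
+#     _find_nested_ckpt_regions(nodes, 0)       # -> [(0, 3)]: outer region 0
+#     _find_nested_ckpt_regions(nodes[:4], 1)   # -> [(2, 3)]: inner region 1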
+
+
+def emit_ckpt_func(body,
+                   ckpt_func,
+                   node_list: List[Node],
+                   emit_node_func,
+                   delete_unused_value_func,
+                   level=0,
+                   in_ckpt=False):
+    """Emit a checkpoint function in a nested way.
+
+    Args:
+        body: forward code; in recursive calls, this part will be the checkpoint
+            functions code
+        ckpt_func: checkpoint functions code; in recursive calls, this part
+            will be a buffer
+        node_list (List[Node]): list of torch.fx.Node
+        emit_node_func: function to emit a node
+        delete_unused_value_func: function to delete unused values
+        level (int, optional): checkpoint level. Defaults to 0.
+        in_ckpt (bool, optional): indicates whether the function is in a recursive
+            call. Defaults to False.
+    """
+    inputs, outputs = _find_input_and_output_nodes(node_list)
+
+    # if the current checkpoint function uses an int as its label, use the old generation method
+    if isinstance(node_list[0].activation_checkpoint, int):
+        label = node_list[0].activation_checkpoint
+        ckpt_fn_def = _gen_ckpt_fn_def(label, inputs)
+        ckpt_func.append(f'{ckpt_fn_def}\n')
+        for node in node_list:
+            emit_node_func(node, ckpt_func)
+            ckpt_func[-1] = '    ' + ckpt_func[-1]
+            delete_unused_value_func(node, ckpt_func)
+
+        ckpt_func.append('    ' + _gen_ckpt_output(outputs) + '\n\n')
+        activation_offload = getattr(node_list[0], "activation_offload", False)
+        usage = _gen_ckpt_usage(label, activation_offload, inputs, outputs, False)
+        usage += "\n"
+        body.append(usage)
+
+    # use the nested ckpt function codegen
+    else:
+        # label given by each layer, e.g. if you are currently at level [0, 1, 1]
+        # the label will be '0_1_1'
+        label = "_".join([str(idx) for idx in node_list[0].activation_checkpoint[:level + 1]])
+        ckpt_fn_def = _gen_ckpt_fn_def(label, inputs)
+        ckpt_func.append(f'{ckpt_fn_def}\n')
+
+        # if there are more levels to fetch
+        if level + 1 < len(node_list[0].activation_checkpoint):
+            ckpt_regions = _find_nested_ckpt_regions(node_list, level + 1)
+            start_idx = [item[0] for item in ckpt_regions]
+            end_idx = [item[1] for item in ckpt_regions]
+
+            # use ckpt_func_buffer to store nested checkpoint functions
+            ckpt_func_buffer = []
+            node_idx = 0
+            while node_idx < len(node_list):
+                if node_idx in start_idx:
+                    ckpt_node_list = node_list[node_idx:end_idx[start_idx.index(node_idx)] + 1]
+                    emit_ckpt_func(ckpt_func, ckpt_func_buffer, ckpt_node_list, emit_node_func,
+                                   delete_unused_value_func, level + 1, True)
+                    node_idx += len(ckpt_node_list)
+
+                else:
+                    node = node_list[node_idx]
+                    emit_node_func(node, ckpt_func)
+                    ckpt_func[-1] = '    ' + ckpt_func[-1]
+                    delete_unused_value_func(node, ckpt_func)
+                    node_idx += 1
+
+            ckpt_func.append('    ' + _gen_ckpt_output(outputs) + '\n\n')
+            ckpt_func += ckpt_func_buffer
+            activation_offload = getattr(node_list[0], "activation_offload", False)
+            usage = _gen_ckpt_usage(label, activation_offload, inputs, outputs, False) + '\n'
+            if in_ckpt:
+                usage = '    ' + usage
+            body.append(usage)
+
+        # last level
+        else:
+            for node in node_list:
+                emit_node_func(node, ckpt_func)
+                ckpt_func[-1] = '    ' + ckpt_func[-1]
+                delete_unused_value_func(node, ckpt_func)
+
+            ckpt_func.append('    ' + _gen_ckpt_output(outputs) + '\n\n')
+            activation_offload = getattr(node_list[0], "activation_offload", False)
+            usage = _gen_ckpt_usage(label, activation_offload, inputs, outputs, False) + '\n'
+            if in_ckpt:
+                usage = '    ' + usage
+            body.append(usage)
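+
+
+# For intuition, the source emitted for a two-level region labelled [0, 1] looks
+# roughly like the following (variable names hypothetical):
+#
+#     def checkpoint_0_1(self, linear2):
+#         linear3 = self.linear3(linear2); linear2 = None
+#         return linear3
+#
+#     def checkpoint_0(self, linear1):
+#         linear2 = self.linear2(linear1); linear1 = None
+#         linear3 = colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0_1, False, linear2, use_reentrant=False)
+#         return linear3
+#
+# and the forward body then calls checkpoint_0 through the same
+# colossalai.utils.activation_checkpoint.checkpoint wrapper.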
+
+
+def emit_code_with_nested_activation_checkpoint(body, ckpt_func, nodes, emit_node_func, delete_unused_value_func):
+    """Emit code with nested activation checkpointing.
+
+    When we detect that some node.activation_checkpoint is a List, we use
+    this function to emit the activation checkpoint code.
+
+    Args:
+        body: forward code
+        ckpt_func: checkpoint functions code
+        nodes: graph.nodes
+        emit_node_func: function to emit a node
+        delete_unused_value_func: function to remove unused values
+    """
+    ckpt_regions = _find_nested_ckpt_regions(nodes, 0)
+    start_idx = [item[0] for item in ckpt_regions]
+    end_idx = [item[1] for item in ckpt_regions]
+
+    # find the offload regions
+    offload_regions, offload_labels = _find_offload_regions(nodes)
+    offload_starts = [item[0] for item in offload_regions]
+    offload_ends = [item[1] for item in offload_regions]
+    offload_inputs = []
+    offload_outputs = []
+    within_offload_region = False
+
+    node_list = list(nodes)
+
+    # find the input and output var names for each offload region
+    for idx, (start, end) in enumerate(offload_regions):
+        offload_node_list = node_list[start:end + 1]
+        inputs, outputs = _find_input_and_output_nodes(offload_node_list)
+        offload_inputs.append(inputs)
+        offload_outputs.append(outputs)
+
+    # this flag prevents repeated insertion of the saved-tensors
+    # hooks definition into ckpt_func
+    is_hook_inserted = False
+    node_idx = 0
+    while node_idx < len(node_list):
+        # process ckpt_regions
+        if node_idx in start_idx:
+            ckpt_node_list = node_list[node_idx:end_idx[start_idx.index(node_idx)] + 1]
+            emit_ckpt_func(body, ckpt_func, ckpt_node_list, emit_node_func, delete_unused_value_func)
+            node_idx += len(ckpt_node_list)
+
+        # process nodes in the forward function
+        else:
+            node = node_list[node_idx]
+
+            if node_idx in offload_starts:
+                offload_label = offload_labels[offload_starts.index(node_idx)]
+                _, offload_input, offload_bar = offload_label
+                within_offload_region = True
+
+                # insert hook functions if needed
+                if not is_hook_inserted:
+                    pack_hook, unpack_hook = _gen_saved_tensors_hooks()
+                    ckpt_func.insert(0, "\n".join([pack_hook, unpack_hook]) + "\n")
+                    is_hook_inserted = True
+
+                if offload_input and offload_bar:
+                    body.append(_gen_save_on_cpu_context())
+
+                elif offload_input:
+                    for par in offload_inputs[offload_label[0]]:
+                        body.append(f"setattr({par}, 'offload', True)\n")
+                    body.append(_gen_save_tensors_hooks_context(offload_input=True))
+
+                else:
+                    for par in offload_inputs[offload_label[0]]:
+                        body.append(f"setattr({par}, 'offload', False)\n")
+                    body.append(_gen_save_tensors_hooks_context(offload_input=False))
+
+            if within_offload_region:
+                emit_node_func(node, body)
+                body[-1] = '    ' + body[-1]
+                delete_unused_value_func(node, body)
+
+            else:
+                emit_node_func(node, body)
+                delete_unused_value_func(node, body)
+
+            if node_idx in offload_ends:
+                within_offload_region = False
+
+            node_idx += 1
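+
+
+# The three offload variants above emit forward code along these lines
+# (variable names hypothetical):
+#
+#     # label [i, True, True]: offload everything saved in the region
+#     with torch.autograd.graph.save_on_cpu(pin_memory=True):
+#         linear2 = self.linear2(linear1); linear1 = None
+#
+#     # label [i, True, False]: offload only the region's inputs
+#     setattr(linear1, 'offload', True)
+#     with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_input, self.unpack_hook):
+#         linear2 = self.linear2(linear1); linear1 = None
+#
+#     # label [i, False, True]: offload saved activations but not the inputs
+#     setattr(linear1, 'offload', False)
+#     with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_no_input, self.unpack_hook):
+#         linear2 = self.linear2(linear1); linear1 = None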
+
+
+def emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node_func, delete_unused_value_func):
+    # find the activation checkpoint regions
+    ckpt_regions = _find_ckpt_regions(nodes)
+    start_idx = [item[0] for item in ckpt_regions]
+    end_idx = [item[1] for item in ckpt_regions]
+    input_vars = []
+    output_vars = []
+    within_ckpt_region = False
+
+    # find the offload regions
+    offload_regions, offload_labels = _find_offload_regions(nodes)
+    offload_starts = [item[0] for item in offload_regions]
+    offload_ends = [item[1] for item in offload_regions]
+    offload_inputs = []
+    offload_outputs = []
+    within_offload_region = False
+
+    node_list = list(nodes)
+
+    # use this variable to avoid inserting hook functions
+    # into ckpt_func repeatedly
+    is_hook_inserted = False
+
+    # find the input and output var names for each region
+    for idx, (start, end) in enumerate(ckpt_regions):
+        ckpt_node_list = node_list[start:end + 1]
+        inputs, outputs = _find_input_and_output_nodes(ckpt_node_list)
+        input_vars.append(inputs)
+        output_vars.append(outputs)
+
+    # find the input and output var names for each offload region
+    for idx, (start, end) in enumerate(offload_regions):
+        offload_node_list = node_list[start:end + 1]
+        inputs, outputs = _find_input_and_output_nodes(offload_node_list)
+        offload_inputs.append(inputs)
+        offload_outputs.append(outputs)
+
+    # append code text to body
+    for idx, node in enumerate(node_list):
+        # if this is the first node of the ckpt region
+        # append the ckpt function definition
+        if idx in start_idx:
+            label = start_idx.index(idx)
+            ckpt_fn_def = _gen_ckpt_fn_def(label, input_vars[label])
+            ckpt_func.append(f'{ckpt_fn_def}\n')
+            within_ckpt_region = True
+
+        if idx in offload_starts:
+            offload_label = offload_labels[offload_starts.index(idx)]
+            _, offload_input, offload_bar = offload_label
+            within_offload_region = True
+
+            # insert hook functions if needed
+            if not is_hook_inserted:
+                pack_hook, unpack_hook = _gen_saved_tensors_hooks()
+                ckpt_func.insert(0, "\n".join([pack_hook, unpack_hook]) + "\n")
+                is_hook_inserted = True
+
+            if offload_input and offload_bar:
+                body.append(_gen_save_on_cpu_context())
+
+            elif offload_input:
+                for par in offload_inputs[offload_label[0]]:
+                    body.append(f"setattr({par}, 'offload', True)\n")
+                body.append(_gen_save_tensors_hooks_context(offload_input=True))
+
+            else:
+                for par in offload_inputs[offload_label[0]]:
+                    body.append(f"setattr({par}, 'offload', False)\n")
+                body.append(_gen_save_tensors_hooks_context(offload_input=False))
+
+        # NOTE: emit_node does not emit a string with a newline. It depends
+        # on delete_unused_values to append one
+        # NOTE: currently we separate body and ckpt_func definitions
+        if within_ckpt_region:
+            emit_node_func(node, ckpt_func)
+            ckpt_func[-1] = '    ' + ckpt_func[-1]
+            delete_unused_value_func(node, ckpt_func)
+
+        elif within_offload_region:
+            emit_node_func(node, body)
+            body[-1] = '    ' + body[-1]
+            delete_unused_value_func(node, body)
+
+        else:
+            emit_node_func(node, body)
+            delete_unused_value_func(node, body)
+
+        if idx in end_idx:
+            # if this is the last node of the ckpt region
+            # generate the return statement
+            label = end_idx.index(idx)
+            return_statement = _gen_ckpt_output(output_vars[label])
+            return_statement = f'    {return_statement}\n\n'
+            ckpt_func.append(return_statement)
+
+            # we need to check whether the checkpoint needs to offload the input
+            start_node_idx = start_idx[label]
+            if hasattr(node_list[start_node_idx], 'activation_offload'):
+                activation_offload = node_list[start_node_idx].activation_offload
+            else:
+                activation_offload = False
+
+            # we need to check whether the checkpoint needs use_reentrant=False
+            use_reentrant = True
+            non_leaf_input = 0
+            for var in input_vars[label]:
+                input_node = next(item for item in node_list if item.name == var)
+                if input_node.op != "placeholder":
+                    non_leaf_input = 1
+                    for user in input_node.users:
+                        if hasattr(user, "activation_checkpoint"):
+                            if user.activation_checkpoint == label:
+                                if user.op == "call_module":
+                                    if hasattr(user.graph.owning_module.get_submodule(user.target), "inplace"):
+                                        use_reentrant = not user.graph.owning_module.get_submodule(user.target).inplace
+
+                                elif user.op == "call_function":
+                                    if "inplace" in user.kwargs:
+                                        use_reentrant = not user.kwargs["inplace"]
+
+            # if all the inputs are leaf nodes, we need to set use_reentrant = False
+            if not non_leaf_input:
+                use_reentrant = False
+
+            # generate the checkpoint function call on a new line
+            usage = _gen_ckpt_usage(label, activation_offload, input_vars[label], output_vars[label], use_reentrant)
+            usage += '\n'
+            body.append(usage)
+            within_ckpt_region = False
+
+        if idx in offload_ends:
+            within_offload_region = False
+
+
+if CODEGEN_AVAILABLE:
+
+    class ActivationCheckpointCodeGen(CodeGen):
+
+        def _gen_python_code(self, nodes, root_module: str, namespace: _Namespace) -> PythonCode:
+            free_vars: List[str] = []
+            body: List[str] = []
+            globals_: Dict[str, Any] = {}
+            wrapped_fns: Dict[str, None] = {}
+
+            # Wrap string in list to pass by reference
+            maybe_return_annotation: List[str] = ['']
+
+            def add_global(name_hint: str, obj: Any):
+                """Add an obj to be tracked as a global.
+
+                We call this for names that reference objects external to the
+                Graph, like functions or types.
+
+                Returns: the global name that should be used to reference 'obj' in generated source.
+                """
+                if _is_from_torch(obj) and obj != torch.device:    # to support registering torch.device
+                    # HACK: workaround for how torch custom ops are registered. We
+                    # can't import them like normal modules so they must retain their
+                    # fully qualified name.
+                    return _get_qualified_name(obj)
+
+                # normalize the name hint to get a proper identifier
+                global_name = namespace.create_name(name_hint, obj)
+
+                if global_name in globals_:
+                    assert globals_[global_name] is obj
+                    return global_name
+                globals_[global_name] = obj
+                return global_name
+
+            # set _custom_builtins here so that we needn't import colossalai in forward
+            _custom_builtins["colossalai"] = _CustomBuiltin("import colossalai", colossalai)
+
+            # Pre-fill the globals table with registered builtins.
+            for name, (_, obj) in _custom_builtins.items():
+                add_global(name, obj)
+
+            def type_repr(o: Any):
+                if o == ():
+                    # Empty tuple is used for empty tuple type annotation Tuple[()]
+                    return '()'
+
+                typename = _type_repr(o)
+
+                if hasattr(o, '__origin__'):
+                    # This is a generic type, e.g. typing.List[torch.Tensor]
+                    origin_type = _origin_type_map.get(o.__origin__, o.__origin__)
+                    origin_typename = add_global(_type_repr(origin_type), origin_type)
+
+                    if hasattr(o, '__args__'):
+                        # Assign global names for each of the inner type variables.
+                        args = [type_repr(arg) for arg in o.__args__]
+
+                        if len(args) == 0:
+                            # Bare type, such as `typing.Tuple` with no subscript
+                            # This code-path used in Python < 3.9
+                            return origin_typename
+
+                        return f'{origin_typename}[{",".join(args)}]'
+                    else:
+                        # Bare type, such as `typing.Tuple` with no subscript
+                        # This code-path used in Python 3.9+
+                        return origin_typename
+
+                # Common case: this is a regular module name like 'foo.bar.baz'
+                return add_global(typename, o)
+
+            def _format_args(args: Tuple[Argument, ...], kwargs: Dict[str, Argument]) -> str:
+
+                def _get_repr(arg):
+                    # Handle NamedTuples (if it has `_fields`) via add_global.
+                    if isinstance(arg, tuple) and hasattr(arg, '_fields'):
+                        qualified_name = _get_qualified_name(type(arg))
+                        global_name = add_global(qualified_name, type(arg))
+                        return f"{global_name}{repr(tuple(arg))}"
+                    return repr(arg)
+
+                args_s = ', '.join(_get_repr(a) for a in args)
+                kwargs_s = ', '.join(f'{k} = {_get_repr(v)}' for k, v in kwargs.items())
+                if args_s and kwargs_s:
+                    return f'{args_s}, {kwargs_s}'
+                return args_s or kwargs_s
+
+            # Run through reverse nodes and record the first instance of a use
+            # of a given node. This represents the *last* use of the node in the
+            # execution order of the program, which we will use to free unused
+            # values
+            node_to_last_use: Dict[Node, Node] = {}
+            user_to_last_uses: Dict[Node, List[Node]] = {}
+
+            def register_last_uses(n: Node, user: Node):
+                if n not in node_to_last_use:
+                    node_to_last_use[n] = user
+                    user_to_last_uses.setdefault(user, []).append(n)
+
+            for node in reversed(nodes):
+                map_arg(node.args, lambda n: register_last_uses(n, node))
+                map_arg(node.kwargs, lambda n: register_last_uses(n, node))
+
+            # NOTE: we add a variable to distinguish body and ckpt_func
+            def delete_unused_values(user: Node, body):
+                """
+                Delete values after their last use. This ensures that values that are
+                not used in the remainder of the code are freed and the memory usage
+                of the code is optimal.
+                """
+                if user.op == 'placeholder':
+                    return
+                if user.op == 'output':
+                    body.append('\n')
+                    return
+                nodes_to_delete = user_to_last_uses.get(user, [])
+                if len(nodes_to_delete):
+                    to_delete_str = ' = '.join([repr(n) for n in nodes_to_delete] + ['None'])
+                    body.append(f'; {to_delete_str}\n')
+                else:
+                    body.append('\n')
+
+            # NOTE: we add a variable to distinguish body and ckpt_func
+            def emit_node(node: Node, body):
+                maybe_type_annotation = '' if node.type is None else f' : {type_repr(node.type)}'
+                if node.op == 'placeholder':
+                    assert isinstance(node.target, str)
+                    maybe_default_arg = '' if not node.args else f' = {repr(node.args[0])}'
+                    free_vars.append(f'{node.target}{maybe_type_annotation}{maybe_default_arg}')
+                    raw_name = node.target.replace('*', '')
+                    if raw_name != repr(node):
+                        body.append(f'{repr(node)} = {raw_name}\n')
+                    return
+                elif node.op == 'call_method':
+                    assert isinstance(node.target, str)
+                    body.append(
+                        f'{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.target)}'
+                        f'({_format_args(node.args[1:], node.kwargs)})')
+                    return
+                elif node.op == 'call_function':
+                    assert callable(node.target)
+                    # pretty print operators
+                    if node.target.__module__ == '_operator' and node.target.__name__ in magic_methods:
+                        assert isinstance(node.args, tuple)
+                        body.append(f'{repr(node)}{maybe_type_annotation} = '
+                                    f'{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}')
+                        return
+
+                    # pretty print inplace operators; required for jit.script to work properly
+                    # not currently supported in normal FX graphs, but generated by torchdynamo
+                    if node.target.__module__ == '_operator' and node.target.__name__ in inplace_methods:
+                        body.append(f'{inplace_methods[node.target.__name__].format(*(repr(a) for a in node.args))}; '
+                                    f'{repr(node)}{maybe_type_annotation} = {repr(node.args[0])}')
+                        return
+
+                    qualified_name = _get_qualified_name(node.target)
+                    global_name = add_global(qualified_name, node.target)
+                    # special case for getattr: node.args could be 2-argument or 3-argument
+                    # 2-argument: attribute access; 3-argument: fall through to attrib function call with default value
+                    if global_name == 'getattr' and \
+                            isinstance(node.args, tuple) and \
+                            isinstance(node.args[1], str) and \
+                            node.args[1].isidentifier() and \
+                            len(node.args) == 2:
+                        body.append(
+                            f'{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}')
+                        return
+                    body.append(
+                        f'{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})')
+                    if node.meta.get('is_wrapped', False):
+                        wrapped_fns.setdefault(global_name)
+                    return
+                elif node.op == 'call_module':
+                    assert isinstance(node.target, str)
+                    body.append(f'{repr(node)}{maybe_type_annotation} = '
+                                f'{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})')
+                    return
+                elif node.op == 'get_attr':
+                    assert isinstance(node.target, str)
+                    body.append(f'{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}')
+                    return
+                elif node.op == 'output':
+                    if node.type is not None:
+                        maybe_return_annotation[0] = f" -> {type_repr(node.type)}"
+                    body.append(self.generate_output(node.args[0]))
+                    return
+                raise NotImplementedError(f'node: {node.op} {node.target}')
+
+            # Modified for activation checkpointing
+            ckpt_func = []
+
+            # if any node has a list of labels for activation_checkpoint, we
+            # will use the nested type of activation checkpoint codegen
+            if any(isinstance(getattr(node, "activation_checkpoint", None), Iterable) for node in nodes):
+                emit_code_with_nested_activation_checkpoint(body, ckpt_func, nodes, emit_node, delete_unused_values)
+            else:
+                emit_code_with_activation_checkpoint(body, ckpt_func, nodes, emit_node, delete_unused_values)
+
+            if len(body) == 0:
+                # If the Graph has no non-placeholder nodes, no lines for the body
+                # have been emitted. To continue to have valid Python code, emit a
+                # single pass statement
+                body.append('pass\n')
+
+            if len(wrapped_fns) > 0:
+                wrap_name = add_global('wrap', torch.fx.wrap)
+                wrap_stmts = '\n'.join([f'{wrap_name}("{name}")' for name in wrapped_fns])
+            else:
+                wrap_stmts = ''
+
+            if self._body_transformer:
+                body = self._body_transformer(body)
+
+            for name, value in self.additional_globals():
+                add_global(name, value)
+
+            # as we need colossalai.utils.checkpoint, we need to import colossalai
+            # in the forward function
+            prologue = self.gen_fn_def(free_vars, maybe_return_annotation[0])
+            prologue = ''.join(ckpt_func) + prologue
+
+            code = ''.join(body)
+            code = '\n'.join('    ' + line for line in code.split('\n'))
+            fn_code = f"""
+{wrap_stmts}
+
+{prologue}
+{code}"""
+            return PythonCode(fn_code, globals_)
+
+else:
+
+    def python_code_with_activation_checkpoint(self, root_module: str, namespace: _Namespace) -> PythonCode:
+        """
+        This method is copied from the _python_code of torch.fx.graph.Graph. Modifications are made so that it can
+        generate code for activation checkpointing.
+        """
+        free_vars: List[str] = []
+        body: List[str] = []
+        globals_: Dict[str, Any] = {}
+        wrapped_fns: Dict[str, None] = {}
+
+        # Wrap string in list to pass by reference
+        maybe_return_annotation: List[str] = ['']
+
+        def add_global(name_hint: str, obj: Any):
+            """Add an obj to be tracked as a global.
+
+            We call this for names that reference objects external to the
+            Graph, like functions or types.
+
+            Returns: the global name that should be used to reference 'obj' in generated source.
+            """
+            if _is_from_torch(obj) and obj != torch.device:    # to support registering torch.device
+                # HACK: workaround for how torch custom ops are registered. We
+                # can't import them like normal modules so they must retain their
+                # fully qualified name.
+                return _get_qualified_name(obj)
+
+            # normalize the name hint to get a proper identifier
+            global_name = namespace.create_name(name_hint, obj)
+
+            if global_name in globals_:
+                assert globals_[global_name] is obj
+                return global_name
+            globals_[global_name] = obj
+            return global_name
+
+        # set _custom_builtins here so that we needn't import colossalai in forward
+        _custom_builtins["colossalai"] = _CustomBuiltin("import colossalai", colossalai)
+
+        # Pre-fill the globals table with registered builtins.
+        for name, (_, obj) in _custom_builtins.items():
+            add_global(name, obj)
+
+        def type_repr(o: Any):
+            if o == ():
+                # Empty tuple is used for empty tuple type annotation Tuple[()]
+                return '()'
+
+            typename = _type_repr(o)
+
+            # This is a generic type, e.g. typing.List[torch.Tensor]
+            if hasattr(o, '__origin__'):
+                origin_type = _origin_type_map.get(o.__origin__, o.__origin__)
+                origin_typename = add_global(_type_repr(origin_type), origin_type)
+
+                # Assign global names for each of the inner type variables.
+                args = [type_repr(arg) for arg in o.__args__]
+
+                return f'{origin_typename}[{",".join(args)}]'
+
+            # Common case: this is a regular module name like 'foo.bar.baz'
+            return add_global(typename, o)
+
+        # Run through reverse nodes and record the first instance of a use
+        # of a given node. This represents the *last* use of the node in the
+        # execution order of the program, which we will use to free unused
+        # values
+        node_to_last_use: Dict[Node, Node] = {}
+        user_to_last_uses: Dict[Node, List[Node]] = {}
+
+        def register_last_uses(n: Node, user: Node):
+            if n not in node_to_last_use:
+                node_to_last_use[n] = user
+                user_to_last_uses.setdefault(user, []).append(n)
+
+        for node in reversed(self.nodes):
+            map_arg(node.args, lambda n: register_last_uses(n, node))
+            map_arg(node.kwargs, lambda n: register_last_uses(n, node))
+
+        # NOTE: we add a variable to distinguish body and ckpt_func
+        def delete_unused_values(user: Node, body):
+            """
+            Delete values after their last use. This ensures that values that are
+            not used in the remainder of the code are freed and the memory usage
+            of the code is optimal.
+            """
+            if user.op == 'placeholder':
+                return
+            if user.op == 'output':
+                body.append('\n')
+                return
+            nodes_to_delete = user_to_last_uses.get(user, [])
+            if len(nodes_to_delete):
+                to_delete_str = ' = '.join([repr(n) for n in nodes_to_delete] + ['None'])
+                body.append(f'; {to_delete_str}\n')
+            else:
+                body.append('\n')
+
+        # NOTE: we add a variable to distinguish body and ckpt_func
+        def emit_node(node: Node, body):
+            maybe_type_annotation = '' if node.type is None else f' : {type_repr(node.type)}'
+            if node.op == 'placeholder':
+                assert isinstance(node.target, str)
+                maybe_default_arg = '' if not node.args else f' = {repr(node.args[0])}'
+                free_vars.append(f'{node.target}{maybe_type_annotation}{maybe_default_arg}')
+                raw_name = node.target.replace('*', '')
+                if raw_name != repr(node):
+                    body.append(f'{repr(node)} = {raw_name}\n')
+                return
+            elif node.op == 'call_method':
+                assert isinstance(node.target, str)
+                body.append(f'{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.target)}'
+                            f'({_format_args(node.args[1:], node.kwargs)})')
+                return
+            elif node.op == 'call_function':
+                assert callable(node.target)
+                # pretty print operators
+                if node.target.__module__ == '_operator' and node.target.__name__ in magic_methods:
+                    assert isinstance(node.args, tuple)
+                    body.append(f'{repr(node)}{maybe_type_annotation} = '
+                                f'{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}')
+                    return
+                qualified_name = _get_qualified_name(node.target)
+                global_name = add_global(qualified_name, node.target)
+                # special case for getattr: node.args could be 2-argument or 3-argument
+                # 2-argument: attribute access; 3-argument: fall through to attrib function call with default value
+                if global_name == 'getattr' and \
+                        isinstance(node.args, tuple) and \
+                        isinstance(node.args[1], str) and \
+                        node.args[1].isidentifier() and \
+                        len(node.args) == 2:
+                    body.append(
+                        f'{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}')
+                    return
+                body.append(
+                    f'{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})')
+                if node.meta.get('is_wrapped', False):
+                    wrapped_fns.setdefault(global_name)
+                return
+            elif node.op == 'call_module':
+                assert isinstance(node.target, str)
+                body.append(f'{repr(node)}{maybe_type_annotation} = '
+                            f'{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})')
+                return
+            elif node.op == 'get_attr':
+                assert isinstance(node.target, str)
+                body.append(f'{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}')
+                return
+            elif node.op == 'output':
+                if node.type is not None:
+                    maybe_return_annotation[0] = f" -> {type_repr(node.type)}"
+                if self._pytree_info is None:
+                    body.append(f'return {repr(node.args[0])}')
+                else:
+                    body.append(f'return pytree.tree_unflatten({repr(node.args[0])}, self._out_spec)')
+                return
+            raise NotImplementedError(f'node: {node.op} {node.target}')
+
+        # Modified for activation checkpointing
+        ckpt_func = []
+
+        # if any node has a list of labels for activation_checkpoint, we
+        # will use the nested type of activation checkpoint codegen
+        if any(isinstance(getattr(node, "activation_checkpoint", None), Iterable) for node in self.nodes):
+            emit_code_with_nested_activation_checkpoint(body, ckpt_func, self.nodes, emit_node, delete_unused_values)
+        else:
+            emit_code_with_activation_checkpoint(body, ckpt_func, self.nodes, emit_node, delete_unused_values)
+
+        if len(body) == 0:
+            # If the Graph has no non-placeholder nodes, no lines for the body
+            # have been emitted. To continue to have valid Python code, emit a
+            # single pass statement
+            body.append('pass\n')
+        if self._pytree_info is not None:
+            orig_args = self._pytree_info.orig_args
+            has_orig_self = (orig_args[0] == 'self')
+            if has_orig_self:
+                free_vars.insert(0, 'self')
+            if len(free_vars) > 0:    # pytree has placeholders in it
+                body.insert(
+                    0,
+                    f"{', '.join(free_vars)}, = fx_pytree.tree_flatten_spec([{', '.join(orig_args)}], self._in_spec)\n")
+        else:
+            orig_args = free_vars
+
+        if len(wrapped_fns) > 0:
+            wrap_name = add_global('wrap', torch.fx.wrap)
+            wrap_stmts = '\n'.join([f'{wrap_name}("{name}")' for name in wrapped_fns])
+        else:
+            wrap_stmts = ''
+
+        ckpt_func = ''.join(ckpt_func)
+
+        # If the original function didn't have self as its first argument, we
+        # would have added it.
+        if len(orig_args) == 0 or orig_args[0] != 'self':
+            orig_args.insert(0, 'self')
+        code = ''.join(body)
+        code = '\n'.join('    ' + line for line in code.split('\n'))
+
+        # as we need colossalai.utils.checkpoint, we need to import colossalai
+        # in the forward function
+        fn_code = f"""
+{wrap_stmts}
+
+{ckpt_func}
+def forward({', '.join(orig_args)}){maybe_return_annotation[0]}:
+{code}"""
+        return PythonCode(fn_code, globals_)
diff --git a/chunk_codegen_run.py b/chunk_codegen_run.py
new file mode 100644
index 000000000..9ac399a29
--- /dev/null
+++ b/chunk_codegen_run.py
@@ -0,0 +1,177 @@
+import copy
+import torch
+import torch.nn.functional as F
+import pytest
+import torch.multiprocessing as mp
+from torch.fx import GraphModule
+from colossalai.fx import ColoTracer
+import colossalai
+from colossalai.utils import free_port
+from colossalai.core import global_context as gpc
+from colossalai.fx.graph_module import ColoGraphModule
+
+try:
+    from chunk_codegen import ActivationCheckpointCodeGen
+    with_codegen = True
+except ImportError:
+    # fall back to the older pytorch version
+    from chunk_codegen import python_code_with_activation_checkpoint
+    with_codegen = False
+
+
+class MyNet(torch.nn.Module):
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.linear0 = torch.nn.Linear(4, 4)
+        self.linear1 = torch.nn.Linear(4, 4)
+        self.linear2 = torch.nn.Linear(4, 4)
+        self.linear3 = torch.nn.Linear(4, 4)
+        self.linear4 = torch.nn.Linear(4, 4)
+        self.linear5 = torch.nn.Linear(4, 4)
+        self.linear6 = torch.nn.Linear(4, 4)
+
+    def forward(self, x):
+        x = self.linear0(x)
+        x = self.linear1(x)
+        x = self.linear2(x)
+        x = self.linear3(x)
+        x = self.linear4(x)
+        x = self.linear5(x)
+        x = self.linear6(x)
+        return x
+
+
+def _is_all_gradient_close(m: torch.nn.Module, gm: GraphModule) -> bool:
+    for m_p, gm_p in zip(m.parameters(), gm.parameters()):
+        if not torch.allclose(m_p.grad, gm_p.grad):
+            return False
+    return True
+
+
+def _test_fwd_and_bwd(model: torch.nn.Module, gm: ColoGraphModule, data: torch.Tensor):
+
+    # test forward
+    non_fx_out = model(data)
+    fx_out = gm(data)
+    assert torch.equal(non_fx_out, fx_out), "fx_out doesn't match the original output"
+
+    # test backward
+    loss0 = non_fx_out.sum()
+    loss0.backward()
+    loss1 = fx_out.sum()
+    loss1.backward()
+    assert _is_all_gradient_close(model, gm), "gm doesn't have the same gradient as the original one"
+
+
+def _run_offload_codegen(rank):
+    # launch colossalai to make sure we can execute colossalai.utils.checkpoint correctly
+    colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl')
+
+    # build model and input
+    model = MyNet().cuda()
+    data = torch.rand(4, 4).cuda()
+
+    # trace the module and replace codegen
+    tracer = ColoTracer(trace_act_ckpt=True)
+    graph = tracer.trace(model)
+    codegen = ActivationCheckpointCodeGen()
+    graph.set_codegen(codegen)
+
+    # annotate the activation offload part
+    # also annotate activation_checkpoint so we can test both types
+    # of input offload
+    for node in graph.nodes:
+        if node.name == "linear0":
+            setattr(node, "activation_offload", [0, True, False])
+        if node.name == "linear1":
+            setattr(node, "activation_offload", [0, True, False])
+        if node.name == "linear2":
+            setattr(node, "activation_offload", [1, True, True])
+        if node.name == "linear4":
+            setattr(node, "activation_offload", [2, False, True])
+        if node.name == "linear5":
+            setattr(node, "activation_checkpoint", [0])
+            setattr(node, "activation_offload", True)
+
+    gm = ColoGraphModule(copy.deepcopy(model), graph)
+    gm.recompile()
+
+    # assert we have all the components
+    code = graph.python_code("self").src
+    assert "def pack_hook_input(self, x):" in code and \
+        "def unpack_hook(self, packed):" in code and \
+        "def pack_hook_no_input(self, x):" in code and \
+        "setattr(x, 'offload', True)" in code and \
+        "setattr(linear3, 'offload', False)" in code and \
+        "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_input, self.unpack_hook):" in code and \
+        "with torch.autograd.graph.save_on_cpu(pin_memory=True):" in code and \
+        "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_no_input, self.unpack_hook):" in code and \
+        "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, True, linear4, use_reentrant=False)" in code
+
+    _test_fwd_and_bwd(model, gm, data)
+    gpc.destroy()
+
+
+@pytest.mark.skipif(not with_codegen, reason='torch version is lower than 1.12.0')
+def test_act_ckpt_codegen():
+    mp.spawn(_run_offload_codegen, nprocs=1)
+
+
+def _run_offload_codegen_torch11(rank):
+    # launch colossalai to make sure we can execute colossalai.utils.checkpoint correctly
+    colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl')
+
+    # build model and input
+    model = MyNet().cuda()
+    data = torch.rand(4, 4).cuda()
+
+    # trace the module and replace codegen
+    tracer = ColoTracer(trace_act_ckpt=True)
+    graph = tracer.trace(model)
+
+    # replace a bound method of an object
+    graph._python_code = python_code_with_activation_checkpoint.__get__(graph)
+
+    # annotate the activation offload part
+    # also annotate activation_checkpoint so we can test both types
+    # of input offload
+    for node in graph.nodes:
+        if node.name == "linear0":
+            setattr(node, "activation_offload", [0, True, False])
+        if node.name == "linear1":
+            setattr(node, "activation_offload", [0, True, False])
+        if node.name == "linear2":
+            setattr(node, "activation_offload", [1, True, True])
+        if node.name == "linear4":
+            setattr(node, "activation_offload", [2, False, True])
+        if node.name == "linear5":
+            setattr(node, "activation_checkpoint", [0])
+            setattr(node, "activation_offload", True)
+
+    gm = ColoGraphModule(copy.deepcopy(model), graph)
+    gm.recompile()
+
+    # assert we have all the components
+    code = graph.python_code("self").src
+    assert "def pack_hook_input(self, x):" in code and \
+        "def unpack_hook(self, packed):" in code and \
+        "def pack_hook_no_input(self, x):" in code and \
+        "setattr(x, 'offload', True)" in code and \
+        "setattr(linear3, 'offload', False)" in code and \
+        "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_input, self.unpack_hook):" in code and \
+        "with torch.autograd.graph.save_on_cpu(pin_memory=True):" in code and \
+        "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_no_input, self.unpack_hook):" in code and \
+        "colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, True, linear4, use_reentrant=False)" in code
+
+    _test_fwd_and_bwd(model, gm, data)
+    gpc.destroy()
+
+
+@pytest.mark.skip(reason="currently torch11 ColoGraphModule is not implemented")
+def test_act_ckpt_python_code_torch11():
+    mp.spawn(_run_offload_codegen_torch11, nprocs=1)
+
+
+if __name__ == "__main__":
+    _run_offload_codegen(0)
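+
+
+# For reference, the forward generated for MyNet above looks roughly like the
+# following (abridged; exact variable names depend on the tracer):
+#
+#     def forward(self, x):
+#         setattr(x, 'offload', True)
+#         with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_input, self.unpack_hook):
+#             linear0 = self.linear0(x); x = None
+#             linear1 = self.linear1(linear0); linear0 = None
+#         with torch.autograd.graph.save_on_cpu(pin_memory=True):
+#             linear2 = self.linear2(linear1); linear1 = None
+#         linear3 = self.linear3(linear2); linear2 = None
+#         setattr(linear3, 'offload', False)
+#         with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_no_input, self.unpack_hook):
+#             linear4 = self.linear4(linear3); linear3 = None
+#         linear5 = colossalai.utils.activation_checkpoint.checkpoint(self.checkpoint_0, True, linear4, use_reentrant=False)
+#         linear6 = self.linear6(linear5); linear5 = None
+#         return linear6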