diff --git a/colossalai/autochunk/autochunk_codegen.py b/colossalai/autochunk/autochunk_codegen.py index 73b6bf524..1ee1d818a 100644 --- a/colossalai/autochunk/autochunk_codegen.py +++ b/colossalai/autochunk/autochunk_codegen.py @@ -1,21 +1,25 @@ from typing import Any, Dict, Iterable, List, Tuple import torch -from torch.fx.graph import ( - CodeGen, - PythonCode, - _custom_builtins, - _CustomBuiltin, - _format_target, - _is_from_torch, - _Namespace, - _origin_type_map, - inplace_methods, - magic_methods, -) -from torch.fx.node import Argument, Node, _get_qualified_name, _type_repr, map_arg import colossalai +from colossalai.fx.codegen.activation_checkpoint_codegen import CODEGEN_AVAILABLE + +if CODEGEN_AVAILABLE: + from torch.fx.graph import ( + CodeGen, + PythonCode, + _custom_builtins, + _CustomBuiltin, + _format_target, + _is_from_torch, + _Namespace, + _origin_type_map, + inplace_methods, + magic_methods, + ) + from torch.fx.node import Argument, Node, _get_qualified_name, _type_repr, map_arg + from .search_chunk import SearchChunk from .utils import delete_free_var_from_last_use, find_idx_by_name, get_node_shape @@ -96,7 +100,7 @@ def _gen_loop_end( Returns: context (str): generated str - """ + """ chunk_outputs_name = chunk_outputs.name chunk_outputs_idx = find_idx_by_name(chunk_outputs_name, node_list) chunk_output_shape = chunk_outputs.meta["tensor_meta"].shape @@ -302,279 +306,287 @@ def emit_code_with_chunk( node_idx += 1 -class AutoChunkCodeGen(CodeGen): - def __init__(self, meta_graph, max_memory=None, print_mem=False): - super().__init__() - self.meta_graph = meta_graph - self.max_memory = max_memory - self.meta_node = list(meta_graph.graph.nodes) - # find the chunk regions - self.search_chunk = SearchChunk(meta_graph, max_memory, print_mem) - self.chunk_infos = self.search_chunk.search_region() +if CODEGEN_AVAILABLE: - def _gen_python_code( - self, nodes, root_module: str, namespace: _Namespace - ) -> PythonCode: - free_vars: List[str] = [] - body: List[str] = [] - globals_: Dict[str, Any] = {} - wrapped_fns: Dict[str, None] = {} + class AutoChunkCodeGen(CodeGen): + def __init__(self, meta_graph, max_memory=None, print_mem=False): + super().__init__() + self.meta_graph = meta_graph + self.max_memory = max_memory + self.meta_node = list(meta_graph.graph.nodes) + # find the chunk regions + self.search_chunk = SearchChunk(meta_graph, max_memory, print_mem) + self.chunk_infos = self.search_chunk.search_region() - # Wrap string in list to pass by reference - maybe_return_annotation: List[str] = [""] + def _gen_python_code( + self, nodes, root_module: str, namespace: _Namespace + ) -> PythonCode: + free_vars: List[str] = [] + body: List[str] = [] + globals_: Dict[str, Any] = {} + wrapped_fns: Dict[str, None] = {} - def add_global(name_hint: str, obj: Any): - """Add an obj to be tracked as a global. + # Wrap string in list to pass by reference + maybe_return_annotation: List[str] = [""] - We call this for names that reference objects external to the - Graph, like functions or types. + def add_global(name_hint: str, obj: Any): + """Add an obj to be tracked as a global. - Returns: the global name that should be used to reference 'obj' in generated source. - """ - if ( - _is_from_torch(obj) and obj != torch.device - ): # to support registering torch.device - # HACK: workaround for how torch custom ops are registered. We - # can't import them like normal modules so they must retain their - # fully qualified name. - return _get_qualified_name(obj) - - # normalize the name hint to get a proper identifier - global_name = namespace.create_name(name_hint, obj) - - if global_name in globals_: - assert globals_[global_name] is obj + We call this for names that reference objects external to the + Graph, like functions or types. + + Returns: the global name that should be used to reference 'obj' in generated source. + """ + if ( + _is_from_torch(obj) and obj != torch.device + ): # to support registering torch.device + # HACK: workaround for how torch custom ops are registered. We + # can't import them like normal modules so they must retain their + # fully qualified name. + return _get_qualified_name(obj) + + # normalize the name hint to get a proper identifier + global_name = namespace.create_name(name_hint, obj) + + if global_name in globals_: + assert globals_[global_name] is obj + return global_name + globals_[global_name] = obj return global_name - globals_[global_name] = obj - return global_name - # set _custom_builtins here so that we needn't import colossalai in forward - _custom_builtins["colossalai"] = _CustomBuiltin("import colossalai", colossalai) + # set _custom_builtins here so that we needn't import colossalai in forward + _custom_builtins["colossalai"] = _CustomBuiltin( + "import colossalai", colossalai + ) + + # Pre-fill the globals table with registered builtins. + for name, (_, obj) in _custom_builtins.items(): + add_global(name, obj) - # Pre-fill the globals table with registered builtins. - for name, (_, obj) in _custom_builtins.items(): - add_global(name, obj) + def type_repr(o: Any): + if o == (): + # Empty tuple is used for empty tuple type annotation Tuple[()] + return "()" - def type_repr(o: Any): - if o == (): - # Empty tuple is used for empty tuple type annotation Tuple[()] - return "()" + typename = _type_repr(o) - typename = _type_repr(o) + if hasattr(o, "__origin__"): + # This is a generic type, e.g. typing.List[torch.Tensor] + origin_type = _origin_type_map.get(o.__origin__, o.__origin__) + origin_typename = add_global(_type_repr(origin_type), origin_type) - if hasattr(o, "__origin__"): - # This is a generic type, e.g. typing.List[torch.Tensor] - origin_type = _origin_type_map.get(o.__origin__, o.__origin__) - origin_typename = add_global(_type_repr(origin_type), origin_type) + if hasattr(o, "__args__"): + # Assign global names for each of the inner type variables. + args = [type_repr(arg) for arg in o.__args__] - if hasattr(o, "__args__"): - # Assign global names for each of the inner type variables. - args = [type_repr(arg) for arg in o.__args__] + if len(args) == 0: + # Bare type, such as `typing.Tuple` with no subscript + # This code-path used in Python < 3.9 + return origin_typename - if len(args) == 0: + return f'{origin_typename}[{",".join(args)}]' + else: # Bare type, such as `typing.Tuple` with no subscript - # This code-path used in Python < 3.9 + # This code-path used in Python 3.9+ return origin_typename - return f'{origin_typename}[{",".join(args)}]' + # Common case: this is a regular module name like 'foo.bar.baz' + return add_global(typename, o) + + def _format_args( + args: Tuple[Argument, ...], kwargs: Dict[str, Argument] + ) -> str: + def _get_repr(arg): + # Handle NamedTuples (if it has `_fields`) via add_global. + if isinstance(arg, tuple) and hasattr(arg, "_fields"): + qualified_name = _get_qualified_name(type(arg)) + global_name = add_global(qualified_name, type(arg)) + return f"{global_name}{repr(tuple(arg))}" + return repr(arg) + + args_s = ", ".join(_get_repr(a) for a in args) + kwargs_s = ", ".join(f"{k} = {_get_repr(v)}" for k, v in kwargs.items()) + if args_s and kwargs_s: + return f"{args_s}, {kwargs_s}" + return args_s or kwargs_s + + # Run through reverse nodes and record the first instance of a use + # of a given node. This represents the *last* use of the node in the + # execution order of the program, which we will use to free unused + # values + node_to_last_use: Dict[Node, Node] = {} + user_to_last_uses: Dict[Node, List[Node]] = {} + + def register_last_uses(n: Node, user: Node): + if n not in node_to_last_use: + node_to_last_use[n] = user + user_to_last_uses.setdefault(user, []).append(n) + + for node in reversed(nodes): + map_arg(node.args, lambda n: register_last_uses(n, node)) + map_arg(node.kwargs, lambda n: register_last_uses(n, node)) + + delete_free_var_from_last_use(user_to_last_uses) + + # NOTE: we add a variable to distinguish body and ckpt_func + def delete_unused_values(user: Node, body, to_keep=[]): + """ + Delete values after their last use. This ensures that values that are + not used in the remainder of the code are freed and the memory usage + of the code is optimal. + """ + if user.op == "placeholder": + return + if user.op == "output": + body.append("\n") + return + nodes_to_delete = user_to_last_uses.get(user, []) + nodes_to_delete = [i for i in nodes_to_delete if i.name not in to_keep] + if len(nodes_to_delete): + to_delete_str = " = ".join( + [repr(n) for n in nodes_to_delete] + ["None"] + ) + body.append(f"; {to_delete_str}\n") else: - # Bare type, such as `typing.Tuple` with no subscript - # This code-path used in Python 3.9+ - return origin_typename - - # Common case: this is a regular module name like 'foo.bar.baz' - return add_global(typename, o) - - def _format_args( - args: Tuple[Argument, ...], kwargs: Dict[str, Argument] - ) -> str: - def _get_repr(arg): - # Handle NamedTuples (if it has `_fields`) via add_global. - if isinstance(arg, tuple) and hasattr(arg, "_fields"): - qualified_name = _get_qualified_name(type(arg)) - global_name = add_global(qualified_name, type(arg)) - return f"{global_name}{repr(tuple(arg))}" - return repr(arg) - - args_s = ", ".join(_get_repr(a) for a in args) - kwargs_s = ", ".join(f"{k} = {_get_repr(v)}" for k, v in kwargs.items()) - if args_s and kwargs_s: - return f"{args_s}, {kwargs_s}" - return args_s or kwargs_s - - # Run through reverse nodes and record the first instance of a use - # of a given node. This represents the *last* use of the node in the - # execution order of the program, which we will use to free unused - # values - node_to_last_use: Dict[Node, Node] = {} - user_to_last_uses: Dict[Node, List[Node]] = {} - - def register_last_uses(n: Node, user: Node): - if n not in node_to_last_use: - node_to_last_use[n] = user - user_to_last_uses.setdefault(user, []).append(n) - - for node in reversed(nodes): - map_arg(node.args, lambda n: register_last_uses(n, node)) - map_arg(node.kwargs, lambda n: register_last_uses(n, node)) - - delete_free_var_from_last_use(user_to_last_uses) - - # NOTE: we add a variable to distinguish body and ckpt_func - def delete_unused_values(user: Node, body, to_keep=[]): - """ - Delete values after their last use. This ensures that values that are - not used in the remainder of the code are freed and the memory usage - of the code is optimal. - """ - if user.op == "placeholder": - return - if user.op == "output": - body.append("\n") - return - nodes_to_delete = user_to_last_uses.get(user, []) - nodes_to_delete = [i for i in nodes_to_delete if i.name not in to_keep] - if len(nodes_to_delete): - to_delete_str = " = ".join( - [repr(n) for n in nodes_to_delete] + ["None"] - ) - body.append(f"; {to_delete_str}\n") - else: - body.append("\n") + body.append("\n") - # NOTE: we add a variable to distinguish body and ckpt_func - def emit_node(node: Node, body): - maybe_type_annotation = ( - "" if node.type is None else f" : {type_repr(node.type)}" - ) - if node.op == "placeholder": - assert isinstance(node.target, str) - maybe_default_arg = "" if not node.args else f" = {repr(node.args[0])}" - free_vars.append( - f"{node.target}{maybe_type_annotation}{maybe_default_arg}" + # NOTE: we add a variable to distinguish body and ckpt_func + def emit_node(node: Node, body): + maybe_type_annotation = ( + "" if node.type is None else f" : {type_repr(node.type)}" ) - raw_name = node.target.replace("*", "") - if raw_name != repr(node): - body.append(f"{repr(node)} = {raw_name}\n") - return - elif node.op == "call_method": - assert isinstance(node.target, str) - body.append( - f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.target)}" - f"({_format_args(node.args[1:], node.kwargs)})" - ) - return - elif node.op == "call_function": - assert callable(node.target) - # pretty print operators - if ( - node.target.__module__ == "_operator" - and node.target.__name__ in magic_methods - ): - assert isinstance(node.args, tuple) + if node.op == "placeholder": + assert isinstance(node.target, str) + maybe_default_arg = ( + "" if not node.args else f" = {repr(node.args[0])}" + ) + free_vars.append( + f"{node.target}{maybe_type_annotation}{maybe_default_arg}" + ) + raw_name = node.target.replace("*", "") + if raw_name != repr(node): + body.append(f"{repr(node)} = {raw_name}\n") + return + elif node.op == "call_method": + assert isinstance(node.target, str) body.append( - f"{repr(node)}{maybe_type_annotation} = " - f"{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}" + f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.target)}" + f"({_format_args(node.args[1:], node.kwargs)})" ) return - - # pretty print inplace operators; required for jit.script to work properly - # not currently supported in normal FX graphs, but generated by torchdynamo - if ( - node.target.__module__ == "_operator" - and node.target.__name__ in inplace_methods - ): + elif node.op == "call_function": + assert callable(node.target) + # pretty print operators + if ( + node.target.__module__ == "_operator" + and node.target.__name__ in magic_methods + ): + assert isinstance(node.args, tuple) + body.append( + f"{repr(node)}{maybe_type_annotation} = " + f"{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}" + ) + return + + # pretty print inplace operators; required for jit.script to work properly + # not currently supported in normal FX graphs, but generated by torchdynamo + if ( + node.target.__module__ == "_operator" + and node.target.__name__ in inplace_methods + ): + body.append( + f"{inplace_methods[node.target.__name__].format(*(repr(a) for a in node.args))}; " + f"{repr(node)}{maybe_type_annotation} = {repr(node.args[0])}" + ) + return + + qualified_name = _get_qualified_name(node.target) + global_name = add_global(qualified_name, node.target) + # special case for getattr: node.args could be 2-argument or 3-argument + # 2-argument: attribute access; 3-argument: fall through to attrib function call with default value + if ( + global_name == "getattr" + and isinstance(node.args, tuple) + and isinstance(node.args[1], str) + and node.args[1].isidentifier() + and len(node.args) == 2 + ): + body.append( + f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}" + ) + return body.append( - f"{inplace_methods[node.target.__name__].format(*(repr(a) for a in node.args))}; " - f"{repr(node)}{maybe_type_annotation} = {repr(node.args[0])}" + f"{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})" ) + if node.meta.get("is_wrapped", False): + wrapped_fns.setdefault(global_name) return - - qualified_name = _get_qualified_name(node.target) - global_name = add_global(qualified_name, node.target) - # special case for getattr: node.args could be 2-argument or 3-argument - # 2-argument: attribute access; 3-argument: fall through to attrib function call with default value - if ( - global_name == "getattr" - and isinstance(node.args, tuple) - and isinstance(node.args[1], str) - and node.args[1].isidentifier() - and len(node.args) == 2 - ): + elif node.op == "call_module": + assert isinstance(node.target, str) body.append( - f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}" + f"{repr(node)}{maybe_type_annotation} = " + f"{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})" ) return - body.append( - f"{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})" - ) - if node.meta.get("is_wrapped", False): - wrapped_fns.setdefault(global_name) - return - elif node.op == "call_module": - assert isinstance(node.target, str) - body.append( - f"{repr(node)}{maybe_type_annotation} = " - f"{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})" - ) - return - elif node.op == "get_attr": - assert isinstance(node.target, str) - body.append( - f"{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}" + elif node.op == "get_attr": + assert isinstance(node.target, str) + body.append( + f"{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}" + ) + return + elif node.op == "output": + if node.type is not None: + maybe_return_annotation[0] = f" -> {type_repr(node.type)}" + body.append(self.generate_output(node.args[0])) + return + raise NotImplementedError(f"node: {node.op} {node.target}") + + # Modified for activation checkpointing + ckpt_func = [] + + # if any node has a list of labels for activation_checkpoint, we + # will use nested type of activation checkpoint codegen + emit_code_with_chunk( + body, + nodes, + emit_node, + delete_unused_values, + self.search_chunk, + self.chunk_infos, + ) + + if len(body) == 0: + # If the Graph has no non-placeholder nodes, no lines for the body + # have been emitted. To continue to have valid Python code, emit a + # single pass statement + body.append("pass\n") + + if len(wrapped_fns) > 0: + wrap_name = add_global("wrap", torch.fx.wrap) + wrap_stmts = "\n".join( + [f'{wrap_name}("{name}")' for name in wrapped_fns] ) - return - elif node.op == "output": - if node.type is not None: - maybe_return_annotation[0] = f" -> {type_repr(node.type)}" - body.append(self.generate_output(node.args[0])) - return - raise NotImplementedError(f"node: {node.op} {node.target}") - - # Modified for activation checkpointing - ckpt_func = [] - - # if any node has a list of labels for activation_checkpoint, we - # will use nested type of activation checkpoint codegen - emit_code_with_chunk( - body, - nodes, - emit_node, - delete_unused_values, - self.search_chunk, - self.chunk_infos, - ) - - if len(body) == 0: - # If the Graph has no non-placeholder nodes, no lines for the body - # have been emitted. To continue to have valid Python code, emit a - # single pass statement - body.append("pass\n") - - if len(wrapped_fns) > 0: - wrap_name = add_global("wrap", torch.fx.wrap) - wrap_stmts = "\n".join([f'{wrap_name}("{name}")' for name in wrapped_fns]) - else: - wrap_stmts = "" + else: + wrap_stmts = "" - if self._body_transformer: - body = self._body_transformer(body) + if self._body_transformer: + body = self._body_transformer(body) - for name, value in self.additional_globals(): - add_global(name, value) + for name, value in self.additional_globals(): + add_global(name, value) - # as we need colossalai.utils.checkpoint, we need to import colossalai - # in forward function - prologue = self.gen_fn_def(free_vars, maybe_return_annotation[0]) - prologue = "".join(ckpt_func) + prologue - prologue = prologue + # as we need colossalai.utils.checkpoint, we need to import colossalai + # in forward function + prologue = self.gen_fn_def(free_vars, maybe_return_annotation[0]) + prologue = "".join(ckpt_func) + prologue + prologue = prologue - code = "".join(body) - code = "\n".join(" " + line for line in code.split("\n")) - fn_code = f""" -{wrap_stmts} + code = "".join(body) + code = "\n".join(" " + line for line in code.split("\n")) + fn_code = f""" + {wrap_stmts} -{prologue} -{code}""" - # print(fn_code) - return PythonCode(fn_code, globals_) + {prologue} + {code}""" + # print(fn_code) + return PythonCode(fn_code, globals_)