From 19cc64b1d39529bde502f9507d20770430f6e3af Mon Sep 17 00:00:00 2001 From: oahzxl Date: Mon, 9 Jan 2023 16:06:58 +0800 Subject: [PATCH] remove autochunk_available --- colossalai/autochunk/autochunk_codegen.py | 490 +++++++++++----------- 1 file changed, 239 insertions(+), 251 deletions(-) diff --git a/colossalai/autochunk/autochunk_codegen.py b/colossalai/autochunk/autochunk_codegen.py index 0db2e5908..9ec59477b 100644 --- a/colossalai/autochunk/autochunk_codegen.py +++ b/colossalai/autochunk/autochunk_codegen.py @@ -16,13 +16,9 @@ from torch.fx.graph import ( from torch.fx.node import Argument, Node, _get_qualified_name, _type_repr, map_arg import colossalai - from .search_chunk import SearchChunk from .utils import delete_free_var_from_last_use, find_idx_by_name, get_node_shape -CODEGEN_AVAILABLE = True -__all__ = ["AutoChunkCodeGen"] - def _gen_chunk_slice_dim(chunk_dim, chunk_idx_name, shape): new_shape = "[" @@ -222,287 +218,279 @@ def emit_code_with_chunk( node_idx += 1 -if CODEGEN_AVAILABLE: - - class AutoChunkCodeGen(CodeGen): - def __init__(self, meta_graph, max_memory=None, print_mem=False): - super().__init__() - self.meta_graph = meta_graph - self.max_memory = max_memory - self.meta_node = list(meta_graph.graph.nodes) - # find the chunk regions - self.search_chunk = SearchChunk(meta_graph, max_memory, print_mem) - self.chunk_infos = self.search_chunk.search_region() +class AutoChunkCodeGen(CodeGen): + def __init__(self, meta_graph, max_memory=None, print_mem=False): + super().__init__() + self.meta_graph = meta_graph + self.max_memory = max_memory + self.meta_node = list(meta_graph.graph.nodes) + # find the chunk regions + self.search_chunk = SearchChunk(meta_graph, max_memory, print_mem) + self.chunk_infos = self.search_chunk.search_region() - def _gen_python_code( - self, nodes, root_module: str, namespace: _Namespace - ) -> PythonCode: - free_vars: List[str] = [] - body: List[str] = [] - globals_: Dict[str, Any] = {} - wrapped_fns: Dict[str, None] = {} + def _gen_python_code( + self, nodes, root_module: str, namespace: _Namespace + ) -> PythonCode: + free_vars: List[str] = [] + body: List[str] = [] + globals_: Dict[str, Any] = {} + wrapped_fns: Dict[str, None] = {} - # Wrap string in list to pass by reference - maybe_return_annotation: List[str] = [""] + # Wrap string in list to pass by reference + maybe_return_annotation: List[str] = [""] - def add_global(name_hint: str, obj: Any): - """Add an obj to be tracked as a global. + def add_global(name_hint: str, obj: Any): + """Add an obj to be tracked as a global. - We call this for names that reference objects external to the - Graph, like functions or types. + We call this for names that reference objects external to the + Graph, like functions or types. - Returns: the global name that should be used to reference 'obj' in generated source. - """ - if ( - _is_from_torch(obj) and obj != torch.device - ): # to support registering torch.device - # HACK: workaround for how torch custom ops are registered. We - # can't import them like normal modules so they must retain their - # fully qualified name. - return _get_qualified_name(obj) - - # normalize the name hint to get a proper identifier - global_name = namespace.create_name(name_hint, obj) - - if global_name in globals_: - assert globals_[global_name] is obj - return global_name - globals_[global_name] = obj + Returns: the global name that should be used to reference 'obj' in generated source. + """ + if ( + _is_from_torch(obj) and obj != torch.device + ): # to support registering torch.device + # HACK: workaround for how torch custom ops are registered. We + # can't import them like normal modules so they must retain their + # fully qualified name. + return _get_qualified_name(obj) + + # normalize the name hint to get a proper identifier + global_name = namespace.create_name(name_hint, obj) + + if global_name in globals_: + assert globals_[global_name] is obj return global_name + globals_[global_name] = obj + return global_name - # set _custom_builtins here so that we needn't import colossalai in forward - _custom_builtins["colossalai"] = _CustomBuiltin( - "import colossalai", colossalai - ) - - # Pre-fill the globals table with registered builtins. - for name, (_, obj) in _custom_builtins.items(): - add_global(name, obj) + # set _custom_builtins here so that we needn't import colossalai in forward + _custom_builtins["colossalai"] = _CustomBuiltin("import colossalai", colossalai) - def type_repr(o: Any): - if o == (): - # Empty tuple is used for empty tuple type annotation Tuple[()] - return "()" + # Pre-fill the globals table with registered builtins. + for name, (_, obj) in _custom_builtins.items(): + add_global(name, obj) - typename = _type_repr(o) + def type_repr(o: Any): + if o == (): + # Empty tuple is used for empty tuple type annotation Tuple[()] + return "()" - if hasattr(o, "__origin__"): - # This is a generic type, e.g. typing.List[torch.Tensor] - origin_type = _origin_type_map.get(o.__origin__, o.__origin__) - origin_typename = add_global(_type_repr(origin_type), origin_type) + typename = _type_repr(o) - if hasattr(o, "__args__"): - # Assign global names for each of the inner type variables. - args = [type_repr(arg) for arg in o.__args__] + if hasattr(o, "__origin__"): + # This is a generic type, e.g. typing.List[torch.Tensor] + origin_type = _origin_type_map.get(o.__origin__, o.__origin__) + origin_typename = add_global(_type_repr(origin_type), origin_type) - if len(args) == 0: - # Bare type, such as `typing.Tuple` with no subscript - # This code-path used in Python < 3.9 - return origin_typename + if hasattr(o, "__args__"): + # Assign global names for each of the inner type variables. + args = [type_repr(arg) for arg in o.__args__] - return f'{origin_typename}[{",".join(args)}]' - else: + if len(args) == 0: # Bare type, such as `typing.Tuple` with no subscript - # This code-path used in Python 3.9+ + # This code-path used in Python < 3.9 return origin_typename - # Common case: this is a regular module name like 'foo.bar.baz' - return add_global(typename, o) - - def _format_args( - args: Tuple[Argument, ...], kwargs: Dict[str, Argument] - ) -> str: - def _get_repr(arg): - # Handle NamedTuples (if it has `_fields`) via add_global. - if isinstance(arg, tuple) and hasattr(arg, "_fields"): - qualified_name = _get_qualified_name(type(arg)) - global_name = add_global(qualified_name, type(arg)) - return f"{global_name}{repr(tuple(arg))}" - return repr(arg) - - args_s = ", ".join(_get_repr(a) for a in args) - kwargs_s = ", ".join(f"{k} = {_get_repr(v)}" for k, v in kwargs.items()) - if args_s and kwargs_s: - return f"{args_s}, {kwargs_s}" - return args_s or kwargs_s - - # Run through reverse nodes and record the first instance of a use - # of a given node. This represents the *last* use of the node in the - # execution order of the program, which we will use to free unused - # values - node_to_last_use: Dict[Node, Node] = {} - user_to_last_uses: Dict[Node, List[Node]] = {} - - def register_last_uses(n: Node, user: Node): - if n not in node_to_last_use: - node_to_last_use[n] = user - user_to_last_uses.setdefault(user, []).append(n) - - for node in reversed(nodes): - map_arg(node.args, lambda n: register_last_uses(n, node)) - map_arg(node.kwargs, lambda n: register_last_uses(n, node)) - - delete_free_var_from_last_use(user_to_last_uses) - - # NOTE: we add a variable to distinguish body and ckpt_func - def delete_unused_values(user: Node, body, to_keep=[]): - """ - Delete values after their last use. This ensures that values that are - not used in the remainder of the code are freed and the memory usage - of the code is optimal. - """ - if user.op == "placeholder": - return - if user.op == "output": - body.append("\n") - return - nodes_to_delete = user_to_last_uses.get(user, []) - nodes_to_delete = [i for i in nodes_to_delete if i.name not in to_keep] - if len(nodes_to_delete): - to_delete_str = " = ".join( - [repr(n) for n in nodes_to_delete] + ["None"] - ) - body.append(f"; {to_delete_str}\n") + return f'{origin_typename}[{",".join(args)}]' else: - body.append("\n") + # Bare type, such as `typing.Tuple` with no subscript + # This code-path used in Python 3.9+ + return origin_typename + + # Common case: this is a regular module name like 'foo.bar.baz' + return add_global(typename, o) + + def _format_args( + args: Tuple[Argument, ...], kwargs: Dict[str, Argument] + ) -> str: + def _get_repr(arg): + # Handle NamedTuples (if it has `_fields`) via add_global. + if isinstance(arg, tuple) and hasattr(arg, "_fields"): + qualified_name = _get_qualified_name(type(arg)) + global_name = add_global(qualified_name, type(arg)) + return f"{global_name}{repr(tuple(arg))}" + return repr(arg) + + args_s = ", ".join(_get_repr(a) for a in args) + kwargs_s = ", ".join(f"{k} = {_get_repr(v)}" for k, v in kwargs.items()) + if args_s and kwargs_s: + return f"{args_s}, {kwargs_s}" + return args_s or kwargs_s + + # Run through reverse nodes and record the first instance of a use + # of a given node. This represents the *last* use of the node in the + # execution order of the program, which we will use to free unused + # values + node_to_last_use: Dict[Node, Node] = {} + user_to_last_uses: Dict[Node, List[Node]] = {} + + def register_last_uses(n: Node, user: Node): + if n not in node_to_last_use: + node_to_last_use[n] = user + user_to_last_uses.setdefault(user, []).append(n) + + for node in reversed(nodes): + map_arg(node.args, lambda n: register_last_uses(n, node)) + map_arg(node.kwargs, lambda n: register_last_uses(n, node)) + + delete_free_var_from_last_use(user_to_last_uses) + + # NOTE: we add a variable to distinguish body and ckpt_func + def delete_unused_values(user: Node, body, to_keep=[]): + """ + Delete values after their last use. This ensures that values that are + not used in the remainder of the code are freed and the memory usage + of the code is optimal. + """ + if user.op == "placeholder": + return + if user.op == "output": + body.append("\n") + return + nodes_to_delete = user_to_last_uses.get(user, []) + nodes_to_delete = [i for i in nodes_to_delete if i.name not in to_keep] + if len(nodes_to_delete): + to_delete_str = " = ".join( + [repr(n) for n in nodes_to_delete] + ["None"] + ) + body.append(f"; {to_delete_str}\n") + else: + body.append("\n") - # NOTE: we add a variable to distinguish body and ckpt_func - def emit_node(node: Node, body): - maybe_type_annotation = ( - "" if node.type is None else f" : {type_repr(node.type)}" + # NOTE: we add a variable to distinguish body and ckpt_func + def emit_node(node: Node, body): + maybe_type_annotation = ( + "" if node.type is None else f" : {type_repr(node.type)}" + ) + if node.op == "placeholder": + assert isinstance(node.target, str) + maybe_default_arg = "" if not node.args else f" = {repr(node.args[0])}" + free_vars.append( + f"{node.target}{maybe_type_annotation}{maybe_default_arg}" ) - if node.op == "placeholder": - assert isinstance(node.target, str) - maybe_default_arg = ( - "" if not node.args else f" = {repr(node.args[0])}" - ) - free_vars.append( - f"{node.target}{maybe_type_annotation}{maybe_default_arg}" - ) - raw_name = node.target.replace("*", "") - if raw_name != repr(node): - body.append(f"{repr(node)} = {raw_name}\n") - return - elif node.op == "call_method": - assert isinstance(node.target, str) - body.append( - f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.target)}" - f"({_format_args(node.args[1:], node.kwargs)})" - ) - return - elif node.op == "call_function": - assert callable(node.target) - # pretty print operators - if ( - node.target.__module__ == "_operator" - and node.target.__name__ in magic_methods - ): - assert isinstance(node.args, tuple) - body.append( - f"{repr(node)}{maybe_type_annotation} = " - f"{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}" - ) - return - - # pretty print inplace operators; required for jit.script to work properly - # not currently supported in normal FX graphs, but generated by torchdynamo - if ( - node.target.__module__ == "_operator" - and node.target.__name__ in inplace_methods - ): - body.append( - f"{inplace_methods[node.target.__name__].format(*(repr(a) for a in node.args))}; " - f"{repr(node)}{maybe_type_annotation} = {repr(node.args[0])}" - ) - return - - qualified_name = _get_qualified_name(node.target) - global_name = add_global(qualified_name, node.target) - # special case for getattr: node.args could be 2-argument or 3-argument - # 2-argument: attribute access; 3-argument: fall through to attrib function call with default value - if ( - global_name == "getattr" - and isinstance(node.args, tuple) - and isinstance(node.args[1], str) - and node.args[1].isidentifier() - and len(node.args) == 2 - ): - body.append( - f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}" - ) - return + raw_name = node.target.replace("*", "") + if raw_name != repr(node): + body.append(f"{repr(node)} = {raw_name}\n") + return + elif node.op == "call_method": + assert isinstance(node.target, str) + body.append( + f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.target)}" + f"({_format_args(node.args[1:], node.kwargs)})" + ) + return + elif node.op == "call_function": + assert callable(node.target) + # pretty print operators + if ( + node.target.__module__ == "_operator" + and node.target.__name__ in magic_methods + ): + assert isinstance(node.args, tuple) body.append( - f"{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})" + f"{repr(node)}{maybe_type_annotation} = " + f"{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}" ) - if node.meta.get("is_wrapped", False): - wrapped_fns.setdefault(global_name) return - elif node.op == "call_module": - assert isinstance(node.target, str) + + # pretty print inplace operators; required for jit.script to work properly + # not currently supported in normal FX graphs, but generated by torchdynamo + if ( + node.target.__module__ == "_operator" + and node.target.__name__ in inplace_methods + ): body.append( - f"{repr(node)}{maybe_type_annotation} = " - f"{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})" + f"{inplace_methods[node.target.__name__].format(*(repr(a) for a in node.args))}; " + f"{repr(node)}{maybe_type_annotation} = {repr(node.args[0])}" ) return - elif node.op == "get_attr": - assert isinstance(node.target, str) + + qualified_name = _get_qualified_name(node.target) + global_name = add_global(qualified_name, node.target) + # special case for getattr: node.args could be 2-argument or 3-argument + # 2-argument: attribute access; 3-argument: fall through to attrib function call with default value + if ( + global_name == "getattr" + and isinstance(node.args, tuple) + and isinstance(node.args[1], str) + and node.args[1].isidentifier() + and len(node.args) == 2 + ): body.append( - f"{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}" + f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}" ) return - elif node.op == "output": - if node.type is not None: - maybe_return_annotation[0] = f" -> {type_repr(node.type)}" - body.append(self.generate_output(node.args[0])) - return - raise NotImplementedError(f"node: {node.op} {node.target}") - - # Modified for activation checkpointing - ckpt_func = [] - - # if any node has a list of labels for activation_checkpoint, we - # will use nested type of activation checkpoint codegen - emit_code_with_chunk( - body, - nodes, - emit_node, - delete_unused_values, - self.search_chunk, - self.chunk_infos, - ) - - if len(body) == 0: - # If the Graph has no non-placeholder nodes, no lines for the body - # have been emitted. To continue to have valid Python code, emit a - # single pass statement - body.append("pass\n") - - if len(wrapped_fns) > 0: - wrap_name = add_global("wrap", torch.fx.wrap) - wrap_stmts = "\n".join( - [f'{wrap_name}("{name}")' for name in wrapped_fns] + body.append( + f"{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})" ) - else: - wrap_stmts = "" + if node.meta.get("is_wrapped", False): + wrapped_fns.setdefault(global_name) + return + elif node.op == "call_module": + assert isinstance(node.target, str) + body.append( + f"{repr(node)}{maybe_type_annotation} = " + f"{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})" + ) + return + elif node.op == "get_attr": + assert isinstance(node.target, str) + body.append( + f"{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}" + ) + return + elif node.op == "output": + if node.type is not None: + maybe_return_annotation[0] = f" -> {type_repr(node.type)}" + body.append(self.generate_output(node.args[0])) + return + raise NotImplementedError(f"node: {node.op} {node.target}") + + # Modified for activation checkpointing + ckpt_func = [] + + # if any node has a list of labels for activation_checkpoint, we + # will use nested type of activation checkpoint codegen + emit_code_with_chunk( + body, + nodes, + emit_node, + delete_unused_values, + self.search_chunk, + self.chunk_infos, + ) + + if len(body) == 0: + # If the Graph has no non-placeholder nodes, no lines for the body + # have been emitted. To continue to have valid Python code, emit a + # single pass statement + body.append("pass\n") + + if len(wrapped_fns) > 0: + wrap_name = add_global("wrap", torch.fx.wrap) + wrap_stmts = "\n".join([f'{wrap_name}("{name}")' for name in wrapped_fns]) + else: + wrap_stmts = "" - if self._body_transformer: - body = self._body_transformer(body) + if self._body_transformer: + body = self._body_transformer(body) - for name, value in self.additional_globals(): - add_global(name, value) + for name, value in self.additional_globals(): + add_global(name, value) - # as we need colossalai.utils.checkpoint, we need to import colossalai - # in forward function - prologue = self.gen_fn_def(free_vars, maybe_return_annotation[0]) - prologue = "".join(ckpt_func) + prologue - prologue = prologue + # as we need colossalai.utils.checkpoint, we need to import colossalai + # in forward function + prologue = self.gen_fn_def(free_vars, maybe_return_annotation[0]) + prologue = "".join(ckpt_func) + prologue + prologue = prologue - code = "".join(body) - code = "\n".join(" " + line for line in code.split("\n")) - fn_code = f""" + code = "".join(body) + code = "\n".join(" " + line for line in code.split("\n")) + fn_code = f""" {wrap_stmts} {prologue} {code}""" - # print(fn_code) - return PythonCode(fn_code, globals_) + # print(fn_code) + return PythonCode(fn_code, globals_)