add available

pull/2364/head
oahzxl 2023-01-10 10:56:39 +08:00
parent 615e7e68d9
commit a591d45b29
1 changed file with 266 additions and 254 deletions


@@ -1,21 +1,25 @@
 from typing import Any, Dict, Iterable, List, Tuple

 import torch
-from torch.fx.graph import (
-    CodeGen,
-    PythonCode,
-    _custom_builtins,
-    _CustomBuiltin,
-    _format_target,
-    _is_from_torch,
-    _Namespace,
-    _origin_type_map,
-    inplace_methods,
-    magic_methods,
-)
-from torch.fx.node import Argument, Node, _get_qualified_name, _type_repr, map_arg

 import colossalai
+from colossalai.fx.codegen.activation_checkpoint_codegen import CODEGEN_AVAILABLE
+
+if CODEGEN_AVAILABLE:
+    from torch.fx.graph import (
+        CodeGen,
+        PythonCode,
+        _custom_builtins,
+        _CustomBuiltin,
+        _format_target,
+        _is_from_torch,
+        _Namespace,
+        _origin_type_map,
+        inplace_methods,
+        magic_methods,
+    )
+    from torch.fx.node import Argument, Node, _get_qualified_name, _type_repr, map_arg

 from .search_chunk import SearchChunk
 from .utils import delete_free_var_from_last_use, find_idx_by_name, get_node_shape
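
The hunk above defers every CodeGen-related torch.fx import until CODEGEN_AVAILABLE confirms that the running torch build actually ships torch.fx.graph.CodeGen. As a rough sketch of what such a flag usually reduces to (an assumption for illustration; the real definition lives in colossalai.fx.codegen.activation_checkpoint_codegen and may differ):

    # Hypothetical probe, not the actual ColossalAI source.
    try:
        from torch.fx.graph import CodeGen  # introduced in torch 1.12
        CODEGEN_AVAILABLE = True
    except ImportError:
        CODEGEN_AVAILABLE = False

With this guard, importing the module on an older torch no longer raises ImportError at import time; only the code paths that need AutoChunkCodeGen require the newer torch.
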
@@ -96,7 +100,7 @@ def _gen_loop_end(
     Returns:
         context (str): generated str
     """
     chunk_outputs_name = chunk_outputs.name
     chunk_outputs_idx = find_idx_by_name(chunk_outputs_name, node_list)
     chunk_output_shape = chunk_outputs.meta["tensor_meta"].shape
@@ -302,279 +306,287 @@ def emit_code_with_chunk(
         node_idx += 1


-class AutoChunkCodeGen(CodeGen):
-    def __init__(self, meta_graph, max_memory=None, print_mem=False):
-        super().__init__()
-        self.meta_graph = meta_graph
-        self.max_memory = max_memory
-        self.meta_node = list(meta_graph.graph.nodes)
-        # find the chunk regions
-        self.search_chunk = SearchChunk(meta_graph, max_memory, print_mem)
-        self.chunk_infos = self.search_chunk.search_region()
-
-    def _gen_python_code(
-        self, nodes, root_module: str, namespace: _Namespace
-    ) -> PythonCode:
-        free_vars: List[str] = []
-        body: List[str] = []
-        globals_: Dict[str, Any] = {}
-        wrapped_fns: Dict[str, None] = {}
-
-        # Wrap string in list to pass by reference
-        maybe_return_annotation: List[str] = [""]
-
-        def add_global(name_hint: str, obj: Any):
-            """Add an obj to be tracked as a global.
-
-            We call this for names that reference objects external to the
-            Graph, like functions or types.
-
-            Returns: the global name that should be used to reference 'obj' in generated source.
-            """
-            if (
-                _is_from_torch(obj) and obj != torch.device
-            ):  # to support registering torch.device
-                # HACK: workaround for how torch custom ops are registered. We
-                # can't import them like normal modules so they must retain their
-                # fully qualified name.
-                return _get_qualified_name(obj)
-
-            # normalize the name hint to get a proper identifier
-            global_name = namespace.create_name(name_hint, obj)
-
-            if global_name in globals_:
-                assert globals_[global_name] is obj
-                return global_name
-            globals_[global_name] = obj
-            return global_name
-
-        # set _custom_builtins here so that we needn't import colossalai in forward
-        _custom_builtins["colossalai"] = _CustomBuiltin("import colossalai", colossalai)
-
-        # Pre-fill the globals table with registered builtins.
-        for name, (_, obj) in _custom_builtins.items():
-            add_global(name, obj)
-
-        def type_repr(o: Any):
-            if o == ():
-                # Empty tuple is used for empty tuple type annotation Tuple[()]
-                return "()"
-
-            typename = _type_repr(o)
-
-            if hasattr(o, "__origin__"):
-                # This is a generic type, e.g. typing.List[torch.Tensor]
-                origin_type = _origin_type_map.get(o.__origin__, o.__origin__)
-                origin_typename = add_global(_type_repr(origin_type), origin_type)
-
-                if hasattr(o, "__args__"):
-                    # Assign global names for each of the inner type variables.
-                    args = [type_repr(arg) for arg in o.__args__]
-                    if len(args) == 0:
-                        # Bare type, such as `typing.Tuple` with no subscript
-                        # This code-path used in Python < 3.9
-                        return origin_typename
-                    return f'{origin_typename}[{",".join(args)}]'
-                else:
-                    # Bare type, such as `typing.Tuple` with no subscript
-                    # This code-path used in Python 3.9+
-                    return origin_typename
-
-            # Common case: this is a regular module name like 'foo.bar.baz'
-            return add_global(typename, o)
-
-        def _format_args(
-            args: Tuple[Argument, ...], kwargs: Dict[str, Argument]
-        ) -> str:
-            def _get_repr(arg):
-                # Handle NamedTuples (if it has `_fields`) via add_global.
-                if isinstance(arg, tuple) and hasattr(arg, "_fields"):
-                    qualified_name = _get_qualified_name(type(arg))
-                    global_name = add_global(qualified_name, type(arg))
-                    return f"{global_name}{repr(tuple(arg))}"
-                return repr(arg)
-
-            args_s = ", ".join(_get_repr(a) for a in args)
-            kwargs_s = ", ".join(f"{k} = {_get_repr(v)}" for k, v in kwargs.items())
-            if args_s and kwargs_s:
-                return f"{args_s}, {kwargs_s}"
-            return args_s or kwargs_s
-
-        # Run through reverse nodes and record the first instance of a use
-        # of a given node. This represents the *last* use of the node in the
-        # execution order of the program, which we will use to free unused
-        # values
-        node_to_last_use: Dict[Node, Node] = {}
-        user_to_last_uses: Dict[Node, List[Node]] = {}
-
-        def register_last_uses(n: Node, user: Node):
-            if n not in node_to_last_use:
-                node_to_last_use[n] = user
-                user_to_last_uses.setdefault(user, []).append(n)
-
-        for node in reversed(nodes):
-            map_arg(node.args, lambda n: register_last_uses(n, node))
-            map_arg(node.kwargs, lambda n: register_last_uses(n, node))
-
-        delete_free_var_from_last_use(user_to_last_uses)
-
-        # NOTE: we add a variable to distinguish body and ckpt_func
-        def delete_unused_values(user: Node, body, to_keep=[]):
-            """
-            Delete values after their last use. This ensures that values that are
-            not used in the remainder of the code are freed and the memory usage
-            of the code is optimal.
-            """
-            if user.op == "placeholder":
-                return
-            if user.op == "output":
-                body.append("\n")
-                return
-            nodes_to_delete = user_to_last_uses.get(user, [])
-            nodes_to_delete = [i for i in nodes_to_delete if i.name not in to_keep]
-            if len(nodes_to_delete):
-                to_delete_str = " = ".join(
-                    [repr(n) for n in nodes_to_delete] + ["None"]
-                )
-                body.append(f"; {to_delete_str}\n")
-            else:
-                body.append("\n")
-
-        # NOTE: we add a variable to distinguish body and ckpt_func
-        def emit_node(node: Node, body):
-            maybe_type_annotation = (
-                "" if node.type is None else f" : {type_repr(node.type)}"
-            )
-            if node.op == "placeholder":
-                assert isinstance(node.target, str)
-                maybe_default_arg = "" if not node.args else f" = {repr(node.args[0])}"
-                free_vars.append(
-                    f"{node.target}{maybe_type_annotation}{maybe_default_arg}"
-                )
-                raw_name = node.target.replace("*", "")
-                if raw_name != repr(node):
-                    body.append(f"{repr(node)} = {raw_name}\n")
-                return
-            elif node.op == "call_method":
-                assert isinstance(node.target, str)
-                body.append(
-                    f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.target)}"
-                    f"({_format_args(node.args[1:], node.kwargs)})"
-                )
-                return
-            elif node.op == "call_function":
-                assert callable(node.target)
-                # pretty print operators
-                if (
-                    node.target.__module__ == "_operator"
-                    and node.target.__name__ in magic_methods
-                ):
-                    assert isinstance(node.args, tuple)
-                    body.append(
-                        f"{repr(node)}{maybe_type_annotation} = "
-                        f"{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}"
-                    )
-                    return
-
-                # pretty print inplace operators; required for jit.script to work properly
-                # not currently supported in normal FX graphs, but generated by torchdynamo
-                if (
-                    node.target.__module__ == "_operator"
-                    and node.target.__name__ in inplace_methods
-                ):
-                    body.append(
-                        f"{inplace_methods[node.target.__name__].format(*(repr(a) for a in node.args))}; "
-                        f"{repr(node)}{maybe_type_annotation} = {repr(node.args[0])}"
-                    )
-                    return
-
-                qualified_name = _get_qualified_name(node.target)
-                global_name = add_global(qualified_name, node.target)
-                # special case for getattr: node.args could be 2-argument or 3-argument
-                # 2-argument: attribute access; 3-argument: fall through to attrib function call with default value
-                if (
-                    global_name == "getattr"
-                    and isinstance(node.args, tuple)
-                    and isinstance(node.args[1], str)
-                    and node.args[1].isidentifier()
-                    and len(node.args) == 2
-                ):
-                    body.append(
-                        f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}"
-                    )
-                    return
-                body.append(
-                    f"{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})"
-                )
-                if node.meta.get("is_wrapped", False):
-                    wrapped_fns.setdefault(global_name)
-                return
-            elif node.op == "call_module":
-                assert isinstance(node.target, str)
-                body.append(
-                    f"{repr(node)}{maybe_type_annotation} = "
-                    f"{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})"
-                )
-                return
-            elif node.op == "get_attr":
-                assert isinstance(node.target, str)
-                body.append(
-                    f"{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}"
-                )
-                return
-            elif node.op == "output":
-                if node.type is not None:
-                    maybe_return_annotation[0] = f" -> {type_repr(node.type)}"
-                body.append(self.generate_output(node.args[0]))
-                return
-            raise NotImplementedError(f"node: {node.op} {node.target}")
-
-        # Modified for activation checkpointing
-        ckpt_func = []
-
-        # if any node has a list of labels for activation_checkpoint, we
-        # will use nested type of activation checkpoint codegen
-        emit_code_with_chunk(
-            body,
-            nodes,
-            emit_node,
-            delete_unused_values,
-            self.search_chunk,
-            self.chunk_infos,
-        )
-
-        if len(body) == 0:
-            # If the Graph has no non-placeholder nodes, no lines for the body
-            # have been emitted. To continue to have valid Python code, emit a
-            # single pass statement
-            body.append("pass\n")
-
-        if len(wrapped_fns) > 0:
-            wrap_name = add_global("wrap", torch.fx.wrap)
-            wrap_stmts = "\n".join([f'{wrap_name}("{name}")' for name in wrapped_fns])
-        else:
-            wrap_stmts = ""
-
-        if self._body_transformer:
-            body = self._body_transformer(body)
-
-        for name, value in self.additional_globals():
-            add_global(name, value)
-
-        # as we need colossalai.utils.checkpoint, we need to import colossalai
-        # in forward function
-        prologue = self.gen_fn_def(free_vars, maybe_return_annotation[0])
-        prologue = "".join(ckpt_func) + prologue
-        prologue = prologue
-
-        code = "".join(body)
-        code = "\n".join("    " + line for line in code.split("\n"))
-        fn_code = f"""
-{wrap_stmts}
-
-{prologue}
-{code}"""
-        # print(fn_code)
-        return PythonCode(fn_code, globals_)
+if CODEGEN_AVAILABLE:
+
+    class AutoChunkCodeGen(CodeGen):
+        def __init__(self, meta_graph, max_memory=None, print_mem=False):
+            super().__init__()
+            self.meta_graph = meta_graph
+            self.max_memory = max_memory
+            self.meta_node = list(meta_graph.graph.nodes)
+            # find the chunk regions
+            self.search_chunk = SearchChunk(meta_graph, max_memory, print_mem)
+            self.chunk_infos = self.search_chunk.search_region()
+
+        def _gen_python_code(
+            self, nodes, root_module: str, namespace: _Namespace
+        ) -> PythonCode:
+            free_vars: List[str] = []
+            body: List[str] = []
+            globals_: Dict[str, Any] = {}
+            wrapped_fns: Dict[str, None] = {}
+
+            # Wrap string in list to pass by reference
+            maybe_return_annotation: List[str] = [""]
+
+            def add_global(name_hint: str, obj: Any):
+                """Add an obj to be tracked as a global.
+
+                We call this for names that reference objects external to the
+                Graph, like functions or types.
+
+                Returns: the global name that should be used to reference 'obj' in generated source.
+                """
+                if (
+                    _is_from_torch(obj) and obj != torch.device
+                ):  # to support registering torch.device
+                    # HACK: workaround for how torch custom ops are registered. We
+                    # can't import them like normal modules so they must retain their
+                    # fully qualified name.
+                    return _get_qualified_name(obj)
+
+                # normalize the name hint to get a proper identifier
+                global_name = namespace.create_name(name_hint, obj)
+
+                if global_name in globals_:
+                    assert globals_[global_name] is obj
+                    return global_name
+                globals_[global_name] = obj
+                return global_name
+
+            # set _custom_builtins here so that we needn't import colossalai in forward
+            _custom_builtins["colossalai"] = _CustomBuiltin(
+                "import colossalai", colossalai
+            )
+
+            # Pre-fill the globals table with registered builtins.
+            for name, (_, obj) in _custom_builtins.items():
+                add_global(name, obj)
+
+            def type_repr(o: Any):
+                if o == ():
+                    # Empty tuple is used for empty tuple type annotation Tuple[()]
+                    return "()"
+
+                typename = _type_repr(o)
+
+                if hasattr(o, "__origin__"):
+                    # This is a generic type, e.g. typing.List[torch.Tensor]
+                    origin_type = _origin_type_map.get(o.__origin__, o.__origin__)
+                    origin_typename = add_global(_type_repr(origin_type), origin_type)
+
+                    if hasattr(o, "__args__"):
+                        # Assign global names for each of the inner type variables.
+                        args = [type_repr(arg) for arg in o.__args__]
+                        if len(args) == 0:
+                            # Bare type, such as `typing.Tuple` with no subscript
+                            # This code-path used in Python < 3.9
+                            return origin_typename
+                        return f'{origin_typename}[{",".join(args)}]'
+                    else:
+                        # Bare type, such as `typing.Tuple` with no subscript
+                        # This code-path used in Python 3.9+
+                        return origin_typename
+
+                # Common case: this is a regular module name like 'foo.bar.baz'
+                return add_global(typename, o)
+
+            def _format_args(
+                args: Tuple[Argument, ...], kwargs: Dict[str, Argument]
+            ) -> str:
+                def _get_repr(arg):
+                    # Handle NamedTuples (if it has `_fields`) via add_global.
+                    if isinstance(arg, tuple) and hasattr(arg, "_fields"):
+                        qualified_name = _get_qualified_name(type(arg))
+                        global_name = add_global(qualified_name, type(arg))
+                        return f"{global_name}{repr(tuple(arg))}"
+                    return repr(arg)
+
+                args_s = ", ".join(_get_repr(a) for a in args)
+                kwargs_s = ", ".join(f"{k} = {_get_repr(v)}" for k, v in kwargs.items())
+                if args_s and kwargs_s:
+                    return f"{args_s}, {kwargs_s}"
+                return args_s or kwargs_s
+
+            # Run through reverse nodes and record the first instance of a use
+            # of a given node. This represents the *last* use of the node in the
+            # execution order of the program, which we will use to free unused
+            # values
+            node_to_last_use: Dict[Node, Node] = {}
+            user_to_last_uses: Dict[Node, List[Node]] = {}
+
+            def register_last_uses(n: Node, user: Node):
+                if n not in node_to_last_use:
+                    node_to_last_use[n] = user
+                    user_to_last_uses.setdefault(user, []).append(n)
+
+            for node in reversed(nodes):
+                map_arg(node.args, lambda n: register_last_uses(n, node))
+                map_arg(node.kwargs, lambda n: register_last_uses(n, node))
+
+            delete_free_var_from_last_use(user_to_last_uses)
+
+            # NOTE: we add a variable to distinguish body and ckpt_func
+            def delete_unused_values(user: Node, body, to_keep=[]):
+                """
+                Delete values after their last use. This ensures that values that are
+                not used in the remainder of the code are freed and the memory usage
+                of the code is optimal.
+                """
+                if user.op == "placeholder":
+                    return
+                if user.op == "output":
+                    body.append("\n")
+                    return
+                nodes_to_delete = user_to_last_uses.get(user, [])
+                nodes_to_delete = [i for i in nodes_to_delete if i.name not in to_keep]
+                if len(nodes_to_delete):
+                    to_delete_str = " = ".join(
+                        [repr(n) for n in nodes_to_delete] + ["None"]
+                    )
+                    body.append(f"; {to_delete_str}\n")
+                else:
+                    body.append("\n")
+
+            # NOTE: we add a variable to distinguish body and ckpt_func
+            def emit_node(node: Node, body):
+                maybe_type_annotation = (
+                    "" if node.type is None else f" : {type_repr(node.type)}"
+                )
+                if node.op == "placeholder":
+                    assert isinstance(node.target, str)
+                    maybe_default_arg = (
+                        "" if not node.args else f" = {repr(node.args[0])}"
+                    )
+                    free_vars.append(
+                        f"{node.target}{maybe_type_annotation}{maybe_default_arg}"
+                    )
+                    raw_name = node.target.replace("*", "")
+                    if raw_name != repr(node):
+                        body.append(f"{repr(node)} = {raw_name}\n")
+                    return
+                elif node.op == "call_method":
+                    assert isinstance(node.target, str)
+                    body.append(
+                        f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.target)}"
+                        f"({_format_args(node.args[1:], node.kwargs)})"
+                    )
+                    return
+                elif node.op == "call_function":
+                    assert callable(node.target)
+                    # pretty print operators
+                    if (
+                        node.target.__module__ == "_operator"
+                        and node.target.__name__ in magic_methods
+                    ):
+                        assert isinstance(node.args, tuple)
+                        body.append(
+                            f"{repr(node)}{maybe_type_annotation} = "
+                            f"{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}"
+                        )
+                        return
+
+                    # pretty print inplace operators; required for jit.script to work properly
+                    # not currently supported in normal FX graphs, but generated by torchdynamo
+                    if (
+                        node.target.__module__ == "_operator"
+                        and node.target.__name__ in inplace_methods
+                    ):
+                        body.append(
+                            f"{inplace_methods[node.target.__name__].format(*(repr(a) for a in node.args))}; "
+                            f"{repr(node)}{maybe_type_annotation} = {repr(node.args[0])}"
+                        )
+                        return
+
+                    qualified_name = _get_qualified_name(node.target)
+                    global_name = add_global(qualified_name, node.target)
+                    # special case for getattr: node.args could be 2-argument or 3-argument
+                    # 2-argument: attribute access; 3-argument: fall through to attrib function call with default value
+                    if (
+                        global_name == "getattr"
+                        and isinstance(node.args, tuple)
+                        and isinstance(node.args[1], str)
+                        and node.args[1].isidentifier()
+                        and len(node.args) == 2
+                    ):
+                        body.append(
+                            f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}"
+                        )
+                        return
+                    body.append(
+                        f"{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})"
+                    )
+                    if node.meta.get("is_wrapped", False):
+                        wrapped_fns.setdefault(global_name)
+                    return
+                elif node.op == "call_module":
+                    assert isinstance(node.target, str)
+                    body.append(
+                        f"{repr(node)}{maybe_type_annotation} = "
+                        f"{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})"
+                    )
+                    return
+                elif node.op == "get_attr":
+                    assert isinstance(node.target, str)
+                    body.append(
+                        f"{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}"
+                    )
+                    return
+                elif node.op == "output":
+                    if node.type is not None:
+                        maybe_return_annotation[0] = f" -> {type_repr(node.type)}"
+                    body.append(self.generate_output(node.args[0]))
+                    return
+                raise NotImplementedError(f"node: {node.op} {node.target}")
+
+            # Modified for activation checkpointing
+            ckpt_func = []
+
+            # if any node has a list of labels for activation_checkpoint, we
+            # will use nested type of activation checkpoint codegen
+            emit_code_with_chunk(
+                body,
+                nodes,
+                emit_node,
+                delete_unused_values,
+                self.search_chunk,
+                self.chunk_infos,
+            )
+
+            if len(body) == 0:
+                # If the Graph has no non-placeholder nodes, no lines for the body
+                # have been emitted. To continue to have valid Python code, emit a
+                # single pass statement
+                body.append("pass\n")
+
+            if len(wrapped_fns) > 0:
+                wrap_name = add_global("wrap", torch.fx.wrap)
+                wrap_stmts = "\n".join(
+                    [f'{wrap_name}("{name}")' for name in wrapped_fns]
+                )
+            else:
+                wrap_stmts = ""
+
+            if self._body_transformer:
+                body = self._body_transformer(body)
+
+            for name, value in self.additional_globals():
+                add_global(name, value)
+
+            # as we need colossalai.utils.checkpoint, we need to import colossalai
+            # in forward function
+            prologue = self.gen_fn_def(free_vars, maybe_return_annotation[0])
+            prologue = "".join(ckpt_func) + prologue
+            prologue = prologue
+
+            code = "".join(body)
+            code = "\n".join("    " + line for line in code.split("\n"))
+            fn_code = f"""
+{wrap_stmts}
+
+{prologue}
+{code}"""
+            # print(fn_code)
+            return PythonCode(fn_code, globals_)