mirror of https://github.com/hpcaitech/ColossalAI
[autochunk] add autochunk feature
commit
93f62dd152
@@ -0,0 +1,593 @@
from typing import Any, Dict, Iterable, List, Tuple

import torch

import colossalai
from colossalai.fx.codegen.activation_checkpoint_codegen import CODEGEN_AVAILABLE

if CODEGEN_AVAILABLE:
    from torch.fx.graph import (
        CodeGen,
        PythonCode,
        _custom_builtins,
        _CustomBuiltin,
        _format_target,
        _is_from_torch,
        _Namespace,
        _origin_type_map,
        inplace_methods,
        magic_methods,
    )

from torch.fx.node import Argument, Node, _get_qualified_name, _type_repr, map_arg

from .search_chunk import SearchChunk
from .utils import delete_free_var_from_last_use, find_idx_by_name, get_node_shape


def _gen_chunk_slice_dim(chunk_dim: int, chunk_indice_name: str, shape: List) -> str:
    """
    Generate a chunk slice string, e.g. [:, :, chunk_idx_name:chunk_idx_name + chunk_size, :]

    Args:
        chunk_dim (int)
        chunk_indice_name (str): name of the chunk index variable
        shape (List): node shape

    Returns:
        new_shape (str): the slice string
    """
    new_shape = "["
    for idx, _ in enumerate(shape):
        if idx == chunk_dim:
            new_shape += "%s:%s + chunk_size" % (chunk_indice_name, chunk_indice_name)
        else:
            new_shape += ":"
        new_shape += ", "
    new_shape = new_shape[:-2] + "]"
    return new_shape
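
# Illustrative sketch (not part of the original commit): for a rank-4 tensor
# chunked along dim 2, the helper above builds the slice string
#     _gen_chunk_slice_dim(2, "chunk_idx", [1, 8, 100, 64])
#     -> "[:, :, chunk_idx:chunk_idx + chunk_size, :]"
# which is later appended to a tensor name so that each loop iteration reads or
# writes only one chunk of that dimension.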


def _gen_loop_start(
    chunk_input: List[Node], chunk_output: Node, chunk_ouput_dim: int, chunk_size=2
) -> str:
    """
    Generate the chunk loop start.

    e.g. chunk_result = torch.empty([100, 100], dtype=input_node.dtype, device=input_node.device)
         chunk_size = 32
         for chunk_idx in range(0, 100, 32):
             ......

    Args:
        chunk_input (List[Node]): chunk input node
        chunk_output (Node): chunk output node
        chunk_ouput_dim (int): chunk dim of the chunk output node
        chunk_size (int): chunk size. Defaults to 2.

    Returns:
        context (str): generated str
    """
    input_node = chunk_input[0]
    out_shape = get_node_shape(chunk_output)
    out_str = str(list(out_shape))
    context = (
        "chunk_result = torch.empty(%s, dtype=%s.dtype, device=%s.device); chunk_size = %d\nfor chunk_idx in range"
        % (out_str, input_node.name, input_node.name, chunk_size)
    )
    context += "(0, %d, chunk_size):\n" % (out_shape[chunk_ouput_dim])
    return context


def _gen_loop_end(
    chunk_inputs: List[Node],
    chunk_non_compute_inputs: List[Node],
    chunk_outputs: Node,
    chunk_outputs_dim: int,
    node_list: List[Node],
) -> str:
    """
    Generate the chunk loop end.

    e.g. chunk_result[chunk_idx:chunk_idx + chunk_size] = output_node
         output_node = chunk_result; xx = None; xx = None

    Args:
        chunk_inputs (List[Node]): chunk input node
        chunk_non_compute_inputs (List[Node]): input node without chunk
        chunk_outputs (Node): chunk output node
        chunk_outputs_dim (int): chunk dim of the chunk output node
        node_list (List)

    Returns:
        context (str): generated str
    """
    chunk_outputs_name = chunk_outputs.name
    chunk_outputs_idx = find_idx_by_name(chunk_outputs_name, node_list)
    chunk_output_shape = chunk_outputs.meta["tensor_meta"].shape
    chunk_slice = _gen_chunk_slice_dim(
        chunk_outputs_dim, "chunk_idx", chunk_output_shape
    )
    context = " chunk_result%s = %s; %s = None\n" % (
        chunk_slice,
        chunk_outputs_name,
        chunk_outputs_name,
    )
    context += (
        chunk_outputs_name + " = chunk_result; chunk_result = None; chunk_size = None"
    )

    # free a chunk input here if this is its last use
    for chunk_input in chunk_inputs + chunk_non_compute_inputs:
        if all(
            [
                find_idx_by_name(user.name, node_list) <= chunk_outputs_idx
                for user in chunk_input.users.keys()
            ]
        ):
            context += "; %s = None" % chunk_input.name

    context += "\n"
    return context
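
# Illustrative sketch (not part of the original commit): for an output of shape
# [100, 100] chunked along dim 0 with chunk_size 32, the start/end helpers above
# wrap the region body roughly like this (the names x and out are hypothetical):
#
#     chunk_result = torch.empty([100, 100], dtype=x.dtype, device=x.device); chunk_size = 32
#     for chunk_idx in range(0, 100, chunk_size):
#         out = <region body computed on x[chunk_idx:chunk_idx + chunk_size, :]>
#         chunk_result[chunk_idx:chunk_idx + chunk_size, :] = out; out = None
#     out = chunk_result; chunk_result = None; chunk_size = None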


def _replace_name(context: str, name_from: str, name_to: str) -> str:
    """
    Replace a node name in generated code.
    """
    patterns = [(" ", " "), (" ", "."), (" ", ","), ("(", ")"), ("(", ","), (" ", ")")]
    for p in patterns:
        source = p[0] + name_from + p[1]
        target = p[0] + name_to + p[1]
        if source in context:
            context = context.replace(source, target)
    return context


def _replace_reshape_size(context: str, node_name: str, reshape_size_dict: Dict) -> str:
    """
    Replace reshape sizes; some may have changed due to chunking.
    """
    if node_name not in reshape_size_dict:
        return context
    for size_name, size_value in reshape_size_dict[node_name].items():
        context = context.replace(size_name, size_value)
    return context
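
# Illustrative sketch (not part of the original commit): _replace_name only swaps
# a name when it is wrapped by one of the delimiter pairs above, so substrings of
# longer identifiers are left untouched. A hypothetical body line:
#     _replace_name("out = linear(x_1, weight)", "x_1", "x_1[chunk_idx:chunk_idx + chunk_size, :]")
#     -> "out = linear(x_1[chunk_idx:chunk_idx + chunk_size, :], weight)"
# while a name such as "x_10" in the same line would not match any pattern.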


def _replace_ones_like(
    search_chunk: SearchChunk,
    chunk_infos: List[Dict],
    region_idx: int,
    node_idx: int,
    node: Node,
    body: List[str],
) -> List[str]:
    """
    Add a chunk slice for new-tensor ops such as ones_like.
    """
    if "ones_like" in node.name:
        meta_node = search_chunk.trace_indice.node_list[node_idx]
        chunk_dim = chunk_infos[region_idx]["node_chunk_dim"][meta_node]["chunk_dim"]
        if get_node_shape(meta_node)[chunk_dim] != 1:
            source_node = meta_node.args[0].args[0]
            if (
                source_node not in chunk_infos[region_idx]["node_chunk_dim"]
                or chunk_infos[region_idx]["node_chunk_dim"][source_node]["chunk_dim"]
                is None
            ):
                chunk_slice = _gen_chunk_slice_dim(
                    chunk_dim, "chunk_idx", get_node_shape(node)
                )
                body[-1] = _replace_name(
                    body[-1], node.args[0].name, node.args[0].name + chunk_slice
                )
    return body


def _replace_input_node(
    chunk_inputs: List[Node],
    region_idx: int,
    chunk_inputs_dim: Dict,
    node_idx: int,
    body: List[str],
) -> List[str]:
    """
    Add chunk slices for input nodes.
    """
    for input_node_idx, input_node in enumerate(chunk_inputs[region_idx]):
        for idx, dim in chunk_inputs_dim[region_idx][input_node_idx].items():
            if idx == node_idx:
                chunk_slice = _gen_chunk_slice_dim(
                    dim[0], "chunk_idx", get_node_shape(input_node)
                )
                body[-1] = _replace_name(
                    body[-1], input_node.name, input_node.name + chunk_slice
                )
    return body


def emit_code_with_chunk(
    body: List[str],
    nodes: Iterable[Node],
    emit_node_func,
    delete_unused_value_func,
    search_chunk: SearchChunk,
    chunk_infos: List,
):
    """
    Emit code with chunk according to chunk_infos.

    It will generate a for loop for each chunk region, and
    replace the inputs and outputs of the region with chunked variables.

    Args:
        body: forward code
        nodes: graph.nodes
        emit_node_func: function to emit a node
        delete_unused_value_func: function to remove unused values
        search_chunk: the class used to search all chunks
        chunk_infos: stores information about all chunks
    """
    node_list = list(nodes)

    # chunk region
    chunk_starts = [i["region"][0] for i in chunk_infos]
    chunk_ends = [i["region"][1] for i in chunk_infos]

    # chunk inputs
    chunk_inputs = [i["inputs"] for i in chunk_infos]  # input with chunk
    chunk_inputs_non_chunk = [
        i["inputs_non_chunk"] for i in chunk_infos
    ]  # input without chunk
    chunk_inputs_dim = [i["inputs_dim"] for i in chunk_infos]  # input chunk dim
    chunk_inputs_names = [j.name for i in chunk_inputs for j in i] + [
        j.name for i in chunk_inputs_non_chunk for j in i
    ]

    # chunk outputs
    chunk_outputs = [i["outputs"][0] for i in chunk_infos]
    chunk_outputs_dim = [i["outputs_dim"] for i in chunk_infos]

    node_list = search_chunk.reorder_graph.reorder_node_list(node_list)
    node_idx = 0
    region_idx = 0
    within_chunk_region = False

    while node_idx < len(node_list):
        node = node_list[node_idx]

        # if this is a chunk start, generate the for-loop start
        if node_idx in chunk_starts:
            within_chunk_region = True
            region_idx = chunk_starts.index(node_idx)
            body.append(
                _gen_loop_start(
                    chunk_inputs[region_idx],
                    chunk_outputs[region_idx],
                    chunk_outputs_dim[region_idx],
                    chunk_infos[region_idx]["chunk_size"],
                )
            )

        if within_chunk_region:
            emit_node_func(node, body)
            # replace input var with chunk var
            body = _replace_input_node(
                chunk_inputs, region_idx, chunk_inputs_dim, node_idx, body
            )
            # ones_like
            body = _replace_ones_like(
                search_chunk, chunk_infos, region_idx, node_idx, node, body
            )
            # reassign reshape size
            body[-1] = _replace_reshape_size(
                body[-1], node.name, chunk_infos[region_idx]["reshape_size"]
            )
            body[-1] = " " + body[-1]
            delete_unused_value_func(node, body, chunk_inputs_names)
        else:
            emit_node_func(node, body)
            if node_idx not in chunk_inputs:
                delete_unused_value_func(node, body, chunk_inputs_names)

        # generate chunk region end
        if node_idx in chunk_ends:
            body.append(
                _gen_loop_end(
                    chunk_inputs[region_idx],
                    chunk_inputs_non_chunk[region_idx],
                    chunk_outputs[region_idx],
                    chunk_outputs_dim[region_idx],
                    node_list,
                )
            )
            within_chunk_region = False

        node_idx += 1


if CODEGEN_AVAILABLE:

    class AutoChunkCodeGen(CodeGen):
        def __init__(self, meta_graph, max_memory=None, print_mem=False):
            super().__init__()
            self.meta_graph = meta_graph
            self.max_memory = max_memory
            self.meta_node = list(meta_graph.graph.nodes)
            # find the chunk regions
            self.search_chunk = SearchChunk(meta_graph, max_memory, print_mem)
            self.chunk_infos = self.search_chunk.search_region()

def _gen_python_code(
|
||||
self, nodes, root_module: str, namespace: _Namespace
|
||||
) -> PythonCode:
|
||||
free_vars: List[str] = []
|
||||
body: List[str] = []
|
||||
globals_: Dict[str, Any] = {}
|
||||
wrapped_fns: Dict[str, None] = {}
|
||||
|
||||
# Wrap string in list to pass by reference
|
||||
maybe_return_annotation: List[str] = [""]
|
||||
|
||||
def add_global(name_hint: str, obj: Any):
|
||||
"""Add an obj to be tracked as a global.
|
||||
|
||||
We call this for names that reference objects external to the
|
||||
Graph, like functions or types.
|
||||
|
||||
Returns: the global name that should be used to reference 'obj' in generated source.
|
||||
"""
|
||||
if (
|
||||
_is_from_torch(obj) and obj != torch.device
|
||||
): # to support registering torch.device
|
||||
# HACK: workaround for how torch custom ops are registered. We
|
||||
# can't import them like normal modules so they must retain their
|
||||
# fully qualified name.
|
||||
return _get_qualified_name(obj)
|
||||
|
||||
# normalize the name hint to get a proper identifier
|
||||
global_name = namespace.create_name(name_hint, obj)
|
||||
|
||||
if global_name in globals_:
|
||||
assert globals_[global_name] is obj
|
||||
return global_name
|
||||
globals_[global_name] = obj
|
||||
return global_name
|
||||
|
||||
# set _custom_builtins here so that we needn't import colossalai in forward
|
||||
_custom_builtins["colossalai"] = _CustomBuiltin(
|
||||
"import colossalai", colossalai
|
||||
)
|
||||
|
||||
# Pre-fill the globals table with registered builtins.
|
||||
for name, (_, obj) in _custom_builtins.items():
|
||||
add_global(name, obj)
|
||||
|
||||
def type_repr(o: Any):
|
||||
if o == ():
|
||||
# Empty tuple is used for empty tuple type annotation Tuple[()]
|
||||
return "()"
|
||||
|
||||
typename = _type_repr(o)
|
||||
|
||||
if hasattr(o, "__origin__"):
|
||||
# This is a generic type, e.g. typing.List[torch.Tensor]
|
||||
origin_type = _origin_type_map.get(o.__origin__, o.__origin__)
|
||||
origin_typename = add_global(_type_repr(origin_type), origin_type)
|
||||
|
||||
if hasattr(o, "__args__"):
|
||||
# Assign global names for each of the inner type variables.
|
||||
args = [type_repr(arg) for arg in o.__args__]
|
||||
|
||||
if len(args) == 0:
|
||||
# Bare type, such as `typing.Tuple` with no subscript
|
||||
# This code-path used in Python < 3.9
|
||||
return origin_typename
|
||||
|
||||
return f'{origin_typename}[{",".join(args)}]'
|
||||
else:
|
||||
# Bare type, such as `typing.Tuple` with no subscript
|
||||
# This code-path used in Python 3.9+
|
||||
return origin_typename
|
||||
|
||||
# Common case: this is a regular module name like 'foo.bar.baz'
|
||||
return add_global(typename, o)
|
||||
|
||||
def _format_args(
|
||||
args: Tuple[Argument, ...], kwargs: Dict[str, Argument]
|
||||
) -> str:
|
||||
def _get_repr(arg):
|
||||
# Handle NamedTuples (if it has `_fields`) via add_global.
|
||||
if isinstance(arg, tuple) and hasattr(arg, "_fields"):
|
||||
qualified_name = _get_qualified_name(type(arg))
|
||||
global_name = add_global(qualified_name, type(arg))
|
||||
return f"{global_name}{repr(tuple(arg))}"
|
||||
return repr(arg)
|
||||
|
||||
args_s = ", ".join(_get_repr(a) for a in args)
|
||||
kwargs_s = ", ".join(f"{k} = {_get_repr(v)}" for k, v in kwargs.items())
|
||||
if args_s and kwargs_s:
|
||||
return f"{args_s}, {kwargs_s}"
|
||||
return args_s or kwargs_s
|
||||
|
||||
# Run through reverse nodes and record the first instance of a use
|
||||
# of a given node. This represents the *last* use of the node in the
|
||||
# execution order of the program, which we will use to free unused
|
||||
# values
|
||||
node_to_last_use: Dict[Node, Node] = {}
|
||||
user_to_last_uses: Dict[Node, List[Node]] = {}
|
||||
|
||||
def register_last_uses(n: Node, user: Node):
|
||||
if n not in node_to_last_use:
|
||||
node_to_last_use[n] = user
|
||||
user_to_last_uses.setdefault(user, []).append(n)
|
||||
|
||||
for node in reversed(nodes):
|
||||
map_arg(node.args, lambda n: register_last_uses(n, node))
|
||||
map_arg(node.kwargs, lambda n: register_last_uses(n, node))
|
||||
|
||||
delete_free_var_from_last_use(user_to_last_uses)
|
||||
|
||||
# NOTE: we add a variable to distinguish body and ckpt_func
|
||||
def delete_unused_values(user: Node, body, to_keep=[]):
|
||||
"""
|
||||
Delete values after their last use. This ensures that values that are
|
||||
not used in the remainder of the code are freed and the memory usage
|
||||
of the code is optimal.
|
||||
"""
|
||||
if user.op == "placeholder":
|
||||
return
|
||||
if user.op == "output":
|
||||
body.append("\n")
|
||||
return
|
||||
nodes_to_delete = user_to_last_uses.get(user, [])
|
||||
nodes_to_delete = [i for i in nodes_to_delete if i.name not in to_keep]
|
||||
if len(nodes_to_delete):
|
||||
to_delete_str = " = ".join(
|
||||
[repr(n) for n in nodes_to_delete] + ["None"]
|
||||
)
|
||||
body.append(f"; {to_delete_str}\n")
|
||||
else:
|
||||
body.append("\n")
|
||||
|
||||
# NOTE: we add a variable to distinguish body and ckpt_func
|
||||
def emit_node(node: Node, body):
|
||||
maybe_type_annotation = (
|
||||
"" if node.type is None else f" : {type_repr(node.type)}"
|
||||
)
|
||||
if node.op == "placeholder":
|
||||
assert isinstance(node.target, str)
|
||||
maybe_default_arg = (
|
||||
"" if not node.args else f" = {repr(node.args[0])}"
|
||||
)
|
||||
free_vars.append(
|
||||
f"{node.target}{maybe_type_annotation}{maybe_default_arg}"
|
||||
)
|
||||
raw_name = node.target.replace("*", "")
|
||||
if raw_name != repr(node):
|
||||
body.append(f"{repr(node)} = {raw_name}\n")
|
||||
return
|
||||
elif node.op == "call_method":
|
||||
assert isinstance(node.target, str)
|
||||
body.append(
|
||||
f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.target)}"
|
||||
f"({_format_args(node.args[1:], node.kwargs)})"
|
||||
)
|
||||
return
|
||||
elif node.op == "call_function":
|
||||
assert callable(node.target)
|
||||
# pretty print operators
|
||||
if (
|
||||
node.target.__module__ == "_operator"
|
||||
and node.target.__name__ in magic_methods
|
||||
):
|
||||
assert isinstance(node.args, tuple)
|
||||
body.append(
|
||||
f"{repr(node)}{maybe_type_annotation} = "
|
||||
f"{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}"
|
||||
)
|
||||
return
|
||||
|
||||
# pretty print inplace operators; required for jit.script to work properly
|
||||
# not currently supported in normal FX graphs, but generated by torchdynamo
|
||||
if (
|
||||
node.target.__module__ == "_operator"
|
||||
and node.target.__name__ in inplace_methods
|
||||
):
|
||||
body.append(
|
||||
f"{inplace_methods[node.target.__name__].format(*(repr(a) for a in node.args))}; "
|
||||
f"{repr(node)}{maybe_type_annotation} = {repr(node.args[0])}"
|
||||
)
|
||||
return
|
||||
|
||||
qualified_name = _get_qualified_name(node.target)
|
||||
global_name = add_global(qualified_name, node.target)
|
||||
# special case for getattr: node.args could be 2-argument or 3-argument
|
||||
# 2-argument: attribute access; 3-argument: fall through to attrib function call with default value
|
||||
if (
|
||||
global_name == "getattr"
|
||||
and isinstance(node.args, tuple)
|
||||
and isinstance(node.args[1], str)
|
||||
and node.args[1].isidentifier()
|
||||
and len(node.args) == 2
|
||||
):
|
||||
body.append(
|
||||
f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}"
|
||||
)
|
||||
return
|
||||
body.append(
|
||||
f"{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})"
|
||||
)
|
||||
if node.meta.get("is_wrapped", False):
|
||||
wrapped_fns.setdefault(global_name)
|
||||
return
|
||||
elif node.op == "call_module":
|
||||
assert isinstance(node.target, str)
|
||||
body.append(
|
||||
f"{repr(node)}{maybe_type_annotation} = "
|
||||
f"{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})"
|
||||
)
|
||||
return
|
||||
elif node.op == "get_attr":
|
||||
assert isinstance(node.target, str)
|
||||
body.append(
|
||||
f"{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}"
|
||||
)
|
||||
return
|
||||
elif node.op == "output":
|
||||
if node.type is not None:
|
||||
maybe_return_annotation[0] = f" -> {type_repr(node.type)}"
|
||||
body.append(self.generate_output(node.args[0]))
|
||||
return
|
||||
raise NotImplementedError(f"node: {node.op} {node.target}")
|
||||
|
||||
# Modified for activation checkpointing
|
||||
ckpt_func = []
|
||||
|
||||
# if any node has a list of labels for activation_checkpoint, we
|
||||
# will use nested type of activation checkpoint codegen
|
||||
emit_code_with_chunk(
|
||||
body,
|
||||
nodes,
|
||||
emit_node,
|
||||
delete_unused_values,
|
||||
self.search_chunk,
|
||||
self.chunk_infos,
|
||||
)
|
||||
|
||||
if len(body) == 0:
|
||||
# If the Graph has no non-placeholder nodes, no lines for the body
|
||||
# have been emitted. To continue to have valid Python code, emit a
|
||||
# single pass statement
|
||||
body.append("pass\n")
|
||||
|
||||
if len(wrapped_fns) > 0:
|
||||
wrap_name = add_global("wrap", torch.fx.wrap)
|
||||
wrap_stmts = "\n".join(
|
||||
[f'{wrap_name}("{name}")' for name in wrapped_fns]
|
||||
)
|
||||
else:
|
||||
wrap_stmts = ""
|
||||
|
||||
if self._body_transformer:
|
||||
body = self._body_transformer(body)
|
||||
|
||||
for name, value in self.additional_globals():
|
||||
add_global(name, value)
|
||||
|
||||
# as we need colossalai.utils.checkpoint, we need to import colossalai
|
||||
# in forward function
|
||||
prologue = self.gen_fn_def(free_vars, maybe_return_annotation[0])
|
||||
prologue = "".join(ckpt_func) + prologue
|
||||
prologue = prologue
|
||||
|
||||
code = "".join(body)
|
||||
code = "\n".join(" " + line for line in code.split("\n"))
|
||||
fn_code = f"""
|
||||
{wrap_stmts}
|
||||
|
||||
{prologue}
|
||||
{code}"""
|
||||
# print(fn_code)
|
||||
return PythonCode(fn_code, globals_)
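
# Minimal usage sketch (not part of the original commit; the tracing helper,
# argument names and the availability of Graph.set_codegen in the targeted
# torch.fx version are assumptions):
#
#     meta_graph = symbolic_trace(model, meta_args={"x": meta_x})   # ColossalAI FX tracer
#     codegen = AutoChunkCodeGen(meta_graph, max_memory=max_mem_mb, print_mem=False)
#     graph.set_codegen(codegen)          # graph traced from the same model
#     gm = torch.fx.GraphModule(model, graph)
#     gm.recompile()                      # forward now contains the chunked for-loops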
@@ -0,0 +1,328 @@
import copy
from typing import Any, Callable, Dict, Iterable, List, Tuple

import torch
from torch.fx.node import Node, map_arg

from colossalai.fx.profiler import activation_size, parameter_size

from .utils import (
    delete_free_var_from_last_use,
    find_idx_by_name,
    get_node_shape,
    is_non_compute_node_except_placeholder,
)


class EstimateMemory(object):
    """
    Estimate memory with chunk
    """

    def __init__(self) -> None:
        pass
|
||||
|
||||
def _get_meta_node_size(self, x):
|
||||
x = x.meta["tensor_meta"]
|
||||
x = x.numel * torch.tensor([], dtype=x.dtype).element_size()
|
||||
return x
|
||||
|
||||
def _get_output_node(self, n):
|
||||
out_size = activation_size(n.meta["fwd_out"])
|
||||
out_node = [n.name] if out_size > 0 else []
|
||||
return out_size, out_node
|
||||
|
||||
def _get_output_node_size(self, n):
|
||||
return self._get_output_node(n)[0]
|
||||
|
||||
def _add_active_node(self, n, active_list):
|
||||
new_active = self._get_output_node(n)[1]
|
||||
if n.op == "placeholder":
|
||||
new_active.append(n.name)
|
||||
for i in new_active:
|
||||
if i not in active_list:
|
||||
active_list.append(i)
|
||||
|
||||
def _get_delete_node(self, user, user_to_last_uses, to_keep=None):
|
||||
delete_size = 0
|
||||
delete_node = []
|
||||
if user.op not in ("output",):
|
||||
nodes_to_delete = user_to_last_uses.get(user, [])
|
||||
if to_keep is not None:
|
||||
keep_list = []
|
||||
for n in nodes_to_delete:
|
||||
if n.name in to_keep:
|
||||
keep_list.append(n)
|
||||
for n in keep_list:
|
||||
if n in nodes_to_delete:
|
||||
nodes_to_delete.remove(n)
|
||||
if len(nodes_to_delete):
|
||||
out_node = [self._get_output_node(i) for i in nodes_to_delete]
|
||||
delete_size = sum([i[0] for i in out_node])
|
||||
for i in range(len(out_node)):
|
||||
if out_node[i][0] > 0:
|
||||
delete_node.append(out_node[i][1][0])
|
||||
elif nodes_to_delete[i].op == "placeholder":
|
||||
delete_node.append(nodes_to_delete[i].name)
|
||||
# elif any(j in nodes_to_delete[i].name for j in ['transpose', 'permute', 'view']):
|
||||
# delete_node.append(nodes_to_delete[i].name)
|
||||
return delete_size, delete_node
|
||||
|
||||
def _get_delete_node_size(self, user, user_to_last_uses, to_keep):
|
||||
return self._get_delete_node(user, user_to_last_uses, to_keep)[0]
|
||||
|
||||
def _remove_deactive_node(self, user, user_to_last_uses, active_list):
|
||||
delete_node = self._get_delete_node(user, user_to_last_uses)[1]
|
||||
for i in delete_node:
|
||||
if i in active_list:
|
||||
active_list.remove(i)
|
||||
|
||||
def _get_chunk_inputs_size(
|
||||
self, chunk_inputs, chunk_inputs_non_chunk, node_list, chunk_end_idx
|
||||
):
|
||||
nodes_to_delete = []
|
||||
for chunk_input in chunk_inputs + chunk_inputs_non_chunk:
|
||||
chunk_input_users = chunk_input.users.keys()
|
||||
chunk_input_users_idx = [
|
||||
find_idx_by_name(i.name, node_list) for i in chunk_input_users
|
||||
]
|
||||
if all(i <= chunk_end_idx for i in chunk_input_users_idx):
|
||||
if chunk_input not in nodes_to_delete:
|
||||
nodes_to_delete.append(chunk_input)
|
||||
out_node = [self._get_output_node(i) for i in nodes_to_delete]
|
||||
delete_size = sum([i[0] for i in out_node])
|
||||
return delete_size
|
||||
|
||||
def _get_last_usr(self, nodes):
|
||||
node_to_last_use: Dict[Node, Node] = {}
|
||||
user_to_last_uses: Dict[Node, List[Node]] = {}
|
||||
|
||||
def register_last_uses(n: Node, user: Node):
|
||||
if n not in node_to_last_use:
|
||||
node_to_last_use[n] = user
|
||||
user_to_last_uses.setdefault(user, []).append(n)
|
||||
|
||||
for node in reversed(nodes):
|
||||
map_arg(node.args, lambda n: register_last_uses(n, node))
|
||||
map_arg(node.kwargs, lambda n: register_last_uses(n, node))
|
||||
return user_to_last_uses
|
||||
|
||||
def _get_contiguous_memory(self, node, not_contiguous_list, delete=False):
|
||||
mem = 0
|
||||
not_contiguous_ops = ["permute"]
|
||||
inherit_contiguous_ops = ["transpose", "view"]
|
||||
|
||||
if node.op == "call_function" and any(
|
||||
n in node.name for n in ["matmul", "reshape"]
|
||||
):
|
||||
for n in node.args:
|
||||
if n in not_contiguous_list:
|
||||
# matmul won't change origin tensor, but create a tmp copy
|
||||
mem += self._get_output_node_size(n)
|
||||
elif node.op == "call_module":
|
||||
for n in node.args:
|
||||
if n in not_contiguous_list:
|
||||
# module will just make origin tensor to contiguous
|
||||
if delete:
|
||||
not_contiguous_list.remove(n)
|
||||
elif node.op == "call_method" and any(
|
||||
i in node.name for i in not_contiguous_ops
|
||||
):
|
||||
if node not in not_contiguous_list:
|
||||
not_contiguous_list.append(node)
|
||||
return mem
|
||||
|
||||
def _get_chunk_ratio(self, node, chunk_node_dim, chunk_size):
|
||||
if node not in chunk_node_dim:
|
||||
return 1.0
|
||||
node_shape = get_node_shape(node)
|
||||
chunk_dim = chunk_node_dim[node]["chunk_dim"]
|
||||
if chunk_dim is None:
|
||||
return 1.0
|
||||
else:
|
||||
return float(chunk_size) / node_shape[chunk_dim]
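
    # Illustrative note (not part of the original commit): if a node of shape
    # [2, 128, 64] is chunked along dim 1 with chunk_size = 4, the ratio returned
    # above is 4 / 128 = 0.03125, i.e. only ~3% of that activation is counted as
    # alive during each step of the chunk loop.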
|
||||
|
||||
def _get_chunk_delete_node_size(
|
||||
self, user, user_to_last_uses, chunk_ratio, chunk_inputs_names
|
||||
):
|
||||
# if any(j in user.name for j in ['transpose', 'permute', 'view']):
|
||||
# return 0
|
||||
if user.op in ("placeholder", "output"):
|
||||
return 0
|
||||
nodes_to_delete = user_to_last_uses.get(user, [])
|
||||
delete_size = 0
|
||||
for n in nodes_to_delete:
|
||||
if n.name in chunk_inputs_names:
|
||||
continue
|
||||
delete_size += self._get_output_node_size(n) * chunk_ratio
|
||||
return delete_size
|
||||
|
||||
def _print_mem_log(self, log, nodes, title=None):
|
||||
if title:
|
||||
print(title)
|
||||
for idx, (l, n) in enumerate(zip(log, nodes)):
|
||||
print("%s:%.2f \t" % (n.name, l), end="")
|
||||
if (idx + 1) % 3 == 0:
|
||||
print("")
|
||||
print("\n")
|
||||
|
||||
def _print_compute_op_mem_log(self, log, nodes, title=None):
|
||||
if title:
|
||||
print(title)
|
||||
for idx, (l, n) in enumerate(zip(log, nodes)):
|
||||
if n.op in ["placeholder", "get_attr", "output"]:
|
||||
continue
|
||||
if any(i in n.name for i in ["getitem", "getattr"]):
|
||||
continue
|
||||
print("%s:%.2f \t" % (n.name, l), end="")
|
||||
if (idx + 1) % 3 == 0:
|
||||
print("")
|
||||
print("\n")
|
||||
|
||||
def estimate_chunk_inference_mem(
|
||||
self,
|
||||
node_list: List,
|
||||
chunk_infos=None,
|
||||
print_mem=False,
|
||||
):
|
||||
"""
|
||||
Estimate inference memory with chunk
|
||||
|
||||
Args:
|
||||
node_list (List): _description_
|
||||
chunk_infos (Dict): Chunk information. Defaults to None.
|
||||
print_mem (bool): Wether to print peak memory of every node. Defaults to False.
|
||||
|
||||
Returns:
|
||||
act_memory_peak_log (List): peak memory of every node
|
||||
act_memory_after_node_log (List): memory after excuting every node
|
||||
active_node_list_log (List): active nodes of every node. active nodes refer to
|
||||
nodes generated but not deleted.
|
||||
"""
|
||||
act_memory = 0.0
|
||||
act_memory_peak_log = []
|
||||
act_memory_after_node_log = []
|
||||
active_node_list = []
|
||||
active_node_list_log = []
|
||||
not_contiguous_list = []
|
||||
user_to_last_uses = self._get_last_usr(node_list)
|
||||
user_to_last_uses_no_free_var = self._get_last_usr(node_list)
|
||||
delete_free_var_from_last_use(user_to_last_uses_no_free_var)
|
||||
|
||||
use_chunk = True if chunk_infos is not None else False
|
||||
chunk_within = False
|
||||
chunk_region_idx = None
|
||||
chunk_ratio = 1 # use it to estimate chunk mem
|
||||
chunk_inputs_names = []
|
||||
|
||||
if use_chunk:
|
||||
chunk_regions = [i["region"] for i in chunk_infos]
|
||||
chunk_starts = [i[0] for i in chunk_regions]
|
||||
chunk_ends = [i[1] for i in chunk_regions]
|
||||
chunk_inputs = [i["inputs"] for i in chunk_infos]
|
||||
chunk_inputs_non_chunk = [i["inputs_non_chunk"] for i in chunk_infos]
|
||||
chunk_inputs_names = [j.name for i in chunk_inputs for j in i] + [
|
||||
j.name for i in chunk_inputs_non_chunk for j in i
|
||||
]
|
||||
chunk_outputs = [i["outputs"][0] for i in chunk_infos]
|
||||
chunk_node_dim = [i["node_chunk_dim"] for i in chunk_infos]
|
||||
chunk_sizes = [
|
||||
i["chunk_size"] if "chunk_size" in i else 1 for i in chunk_infos
|
||||
]
|
||||
|
||||
for idx, node in enumerate(node_list):
|
||||
# if node in chunk start nodes, change chunk ratio and add chunk_tensor
|
||||
if use_chunk and idx in chunk_starts:
|
||||
chunk_within = True
|
||||
chunk_region_idx = chunk_starts.index(idx)
|
||||
act_memory += self._get_output_node_size(
|
||||
chunk_outputs[chunk_region_idx]
|
||||
) / (1024**2)
|
||||
|
||||
# determine chunk ratio for current node
|
||||
if chunk_within:
|
||||
chunk_ratio = self._get_chunk_ratio(
|
||||
node,
|
||||
chunk_node_dim[chunk_region_idx],
|
||||
chunk_sizes[chunk_region_idx],
|
||||
)
|
||||
|
||||
# if node is placeholder, just add the size of the node
|
||||
if node.op == "placeholder":
|
||||
act_memory += self._get_meta_node_size(node) * chunk_ratio / (1024**2)
|
||||
act_memory_peak_log.append(act_memory)
|
||||
# skip output
|
||||
elif node.op == "output":
|
||||
continue
|
||||
# no change for non compute node
|
||||
elif is_non_compute_node_except_placeholder(node):
|
||||
act_memory_peak_log.append(act_memory)
|
||||
# node is a compute op
|
||||
# calculate tmp, output node and delete node memory
|
||||
else:
|
||||
# forward memory
|
||||
# TODO: contiguous_memory still not accurate for matmul, view, reshape and transpose
|
||||
act_memory += (
|
||||
self._get_contiguous_memory(node, not_contiguous_list)
|
||||
* chunk_ratio
|
||||
/ (1024**2)
|
||||
)
|
||||
act_memory += (
|
||||
self._get_output_node_size(node) * chunk_ratio / (1024**2)
|
||||
)
|
||||
# record max act memory
|
||||
act_memory_peak_log.append(act_memory)
|
||||
# delete useless memory
|
||||
act_memory -= (
|
||||
self._get_contiguous_memory(node, not_contiguous_list, delete=True)
|
||||
* chunk_ratio
|
||||
/ (1024**2)
|
||||
)
|
||||
# delete unused vars not in chunk_input_list
|
||||
# we can't delete input nodes until chunk ends
|
||||
if chunk_within:
|
||||
act_memory -= self._get_chunk_delete_node_size(
|
||||
node,
|
||||
user_to_last_uses_no_free_var,
|
||||
chunk_ratio,
|
||||
chunk_inputs_names,
|
||||
) / (1024**2)
|
||||
else:
|
||||
act_memory -= self._get_delete_node_size(
|
||||
node, user_to_last_uses_no_free_var, chunk_inputs_names
|
||||
) / (1024**2)
|
||||
|
||||
# log active node, only effective without chunk
|
||||
self._add_active_node(node, active_node_list)
|
||||
self._remove_deactive_node(node, user_to_last_uses, active_node_list)
|
||||
|
||||
# if node in chunk end nodes, restore chunk settings
|
||||
if use_chunk and idx in chunk_ends:
|
||||
act_memory -= (
|
||||
self._get_output_node_size(node) * chunk_ratio / (1024**2)
|
||||
)
|
||||
act_memory -= self._get_chunk_inputs_size(
|
||||
chunk_inputs[chunk_region_idx],
|
||||
chunk_inputs_non_chunk[chunk_region_idx],
|
||||
node_list,
|
||||
chunk_regions[chunk_region_idx][1],
|
||||
) / (1024**2)
|
||||
chunk_within = False
|
||||
chunk_ratio = 1
|
||||
chunk_region_idx = None
|
||||
|
||||
act_memory_after_node_log.append(act_memory)
|
||||
active_node_list_log.append(copy.deepcopy(active_node_list))
|
||||
|
||||
if print_mem:
|
||||
print("with chunk" if use_chunk else "without chunk")
|
||||
# self._print_mem_log(act_memory_peak_log, node_list, "peak")
|
||||
# self._print_mem_log(act_memory_after_node_log, node_list, "after")
|
||||
self._print_compute_op_mem_log(act_memory_peak_log, node_list, "peak")
|
||||
# self._print_compute_op_mem_log(
|
||||
# act_memory_after_node_log, node_list, "after"
|
||||
# )
|
||||
|
||||
# param_memory = parameter_size(gm)
|
||||
# all_memory = act_memory + param_memory
|
||||
return act_memory_peak_log, act_memory_after_node_log, active_node_list_log
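
# Illustrative sketch (not part of the original commit): the three logs are
# index-aligned with node_list, so a caller can locate the peak like this
# (variable names are hypothetical):
#     peak_log, after_log, active_log = EstimateMemory().estimate_chunk_inference_mem(node_list)
#     peak_idx = peak_log.index(max(peak_log))   # node where activation memory peaks
#     alive_at_peak = active_log[peak_idx]       # names of tensors still alive there
# SearchChunk uses exactly this pair (peak position + active-node lists) to pick chunk regions.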
@@ -0,0 +1,117 @@
from .trace_indice import TraceIndice
from .utils import find_idx_by_name


class ReorderGraph(object):
    """
    Reorder node list and indice trace list
    """

    def __init__(self, trace_indice: TraceIndice) -> None:
        self.trace_indice = trace_indice
        self.all_reorder_map = {
            i: i for i in range(len(self.trace_indice.indice_trace_list))
        }
|
||||
|
||||
def _get_reorder_map(self, chunk_info):
|
||||
reorder_map = {i: i for i in range(len(self.trace_indice.node_list))}
|
||||
|
||||
chunk_region_start = chunk_info["region"][0]
|
||||
chunk_region_end = chunk_info["region"][1]
|
||||
chunk_prepose_nodes = chunk_info["args"]["prepose_nodes"]
|
||||
chunk_prepose_nodes_idx = [
|
||||
find_idx_by_name(i.name, self.trace_indice.node_list)
|
||||
for i in chunk_prepose_nodes
|
||||
]
|
||||
# put prepose nodes ahead
|
||||
for idx, n in enumerate(chunk_prepose_nodes):
|
||||
n_idx = chunk_prepose_nodes_idx[idx]
|
||||
reorder_map[n_idx] = chunk_region_start + idx
|
||||
# put other nodes after prepose nodes
|
||||
for n in self.trace_indice.node_list[chunk_region_start : chunk_region_end + 1]:
|
||||
if n in chunk_prepose_nodes:
|
||||
continue
|
||||
n_idx = find_idx_by_name(n.name, self.trace_indice.node_list)
|
||||
pos = sum([n_idx < i for i in chunk_prepose_nodes_idx])
|
||||
reorder_map[n_idx] = n_idx + pos
|
||||
|
||||
return reorder_map
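
    # Illustrative sketch (not part of the original commit): with a region covering
    # nodes 10..13 and prepose_nodes = [node_12], the map above moves node 12 to the
    # front of the region and shifts the others back by one:
    #     {10: 11, 11: 12, 12: 10, 13: 13}   # old index -> new index
    # so the preposed node can be computed once before the chunk loop starts.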
|
||||
|
||||
def _reorder_chunk_info(self, chunk_info, reorder_map):
|
||||
# update chunk info
|
||||
chunk_info["region"] = (
|
||||
chunk_info["region"][0] + len(chunk_info["args"]["prepose_nodes"]),
|
||||
chunk_info["region"][1],
|
||||
)
|
||||
new_inputs_dim = []
|
||||
for idx, input_dim in enumerate(chunk_info["inputs_dim"]):
|
||||
new_input_dim = {}
|
||||
for k, v in input_dim.items():
|
||||
new_input_dim[reorder_map[k]] = v
|
||||
new_inputs_dim.append(new_input_dim)
|
||||
chunk_info["inputs_dim"] = new_inputs_dim
|
||||
return chunk_info
|
||||
|
||||
def _update_all_reorder_map(self, reorder_map):
|
||||
for origin_idx, map_idx in self.all_reorder_map.items():
|
||||
self.all_reorder_map[origin_idx] = reorder_map[map_idx]
|
||||
|
||||
def _reorder_self_node_list(self, reorder_map):
|
||||
new_node_list = [None for _ in range(len(self.trace_indice.node_list))]
|
||||
for old_idx, new_idx in reorder_map.items():
|
||||
new_node_list[new_idx] = self.trace_indice.node_list[old_idx]
|
||||
self.trace_indice.node_list = new_node_list
|
||||
|
||||
def _reorder_idx_trace(self, reorder_map):
|
||||
# reorder list
|
||||
new_idx_trace_list = [
|
||||
None for _ in range(len(self.trace_indice.indice_trace_list))
|
||||
]
|
||||
for old_idx, new_idx in reorder_map.items():
|
||||
new_idx_trace_list[new_idx] = self.trace_indice.indice_trace_list[old_idx]
|
||||
self.trace_indice.indice_trace_list = new_idx_trace_list
|
||||
# update compute
|
||||
for idx_trace in self.trace_indice.indice_trace_list:
|
||||
compute = idx_trace["compute"]
|
||||
for dim_compute in compute:
|
||||
for idx, i in enumerate(dim_compute):
|
||||
dim_compute[idx] = reorder_map[i]
|
||||
# update source
|
||||
for idx_trace in self.trace_indice.indice_trace_list:
|
||||
source = idx_trace["source"]
|
||||
for dim_idx, dim_source in enumerate(source):
|
||||
new_dim_source = {}
|
||||
for k, v in dim_source.items():
|
||||
new_dim_source[reorder_map[k]] = v
|
||||
source[dim_idx] = new_dim_source
|
||||
|
||||
def reorder_all(self, chunk_info):
|
||||
if chunk_info is None:
|
||||
return chunk_info
|
||||
if len(chunk_info["args"]["prepose_nodes"]) == 0:
|
||||
return chunk_info
|
||||
reorder_map = self._get_reorder_map(chunk_info)
|
||||
self._update_all_reorder_map(reorder_map)
|
||||
self._reorder_idx_trace(reorder_map)
|
||||
self._reorder_self_node_list(reorder_map)
|
||||
chunk_info = self._reorder_chunk_info(chunk_info, reorder_map)
|
||||
return chunk_info
|
||||
|
||||
def reorder_node_list(self, node_list):
|
||||
new_node_list = [None for _ in range(len(node_list))]
|
||||
for old_idx, new_idx in self.all_reorder_map.items():
|
||||
new_node_list[new_idx] = node_list[old_idx]
|
||||
return new_node_list
|
||||
|
||||
def tmp_reorder(self, node_list, chunk_info):
|
||||
if len(chunk_info["args"]["prepose_nodes"]) == 0:
|
||||
return node_list, chunk_info
|
||||
reorder_map = self._get_reorder_map(chunk_info)
|
||||
|
||||
# new tmp node list
|
||||
new_node_list = [None for _ in range(len(node_list))]
|
||||
for old_idx, new_idx in reorder_map.items():
|
||||
new_node_list[new_idx] = node_list[old_idx]
|
||||
|
||||
chunk_info = self._reorder_chunk_info(chunk_info, reorder_map)
|
||||
return new_node_list, chunk_info
@@ -0,0 +1,319 @@
import copy
from typing import Dict, List, Tuple

from torch.fx.node import Node

from .estimate_memory import EstimateMemory
from .reorder_graph import ReorderGraph
from .select_chunk import SelectChunk
from .trace_flow import TraceFlow
from .trace_indice import TraceIndice
from .utils import (
    get_node_shape,
    is_non_compute_node,
    is_non_compute_node_except_placeholder,
)


class SearchChunk(object):
    """
    This is the core class for AutoChunk.

    It defines the framework of the AutoChunk search strategy.
    Chunks are selected one by one until the search stops.

    The chunk search proceeds as follows:
    1. find the peak memory node
    2. find the max chunk region according to the peak memory node
    3. find all possible chunk regions in the max chunk region
    4. find the best chunk region for the current status
    5. goto 1

    Attributes:
        gm: graph model
        print_mem (bool): print estimated memory
        trace_indice: traces the flow of every dim of every node to find all free dims
        trace_flow: determines the chunk strategy of a region
        reorder_graph: reorders nodes to improve chunk efficiency
        estimate_memory: estimates memory with chunk
        select_chunk: selects the best chunk region

    Args:
        gm: graph model
        max_memory (int): max memory in MB
        print_mem (bool): print estimated memory
    """

    def __init__(self, gm, max_memory=None, print_mem=False) -> None:
        self.gm = gm
        self.print_mem = print_mem
        self.trace_indice = TraceIndice(list(gm.graph.nodes))
        self.trace_indice.trace_indice()
        self.trace_flow = TraceFlow(self.trace_indice)
        self.reorder_graph = ReorderGraph(self.trace_indice)
        self.estimate_memory = EstimateMemory()
        self.select_chunk = SelectChunk(
            self.trace_indice,
            self.estimate_memory,
            self.reorder_graph,
            max_memory=max_memory,
        )
|
||||
|
||||
def _find_peak_node(self, mem_peak):
|
||||
max_value = max(mem_peak)
|
||||
max_idx = mem_peak.index(max_value)
|
||||
return max_idx
|
||||
|
||||
def _get_free_var_idx(self) -> List:
|
||||
"""
|
||||
Get free var index
|
||||
|
||||
Returns:
|
||||
free_var_idx (List): all indexs of free vars
|
||||
"""
|
||||
free_var_idx = []
|
||||
for idx, n in enumerate(self.trace_indice.node_list):
|
||||
if n.op == "placeholder":
|
||||
free_var_idx.append(idx)
|
||||
return free_var_idx
|
||||
|
||||
def _search_max_chunk_region(
|
||||
self, active_node: List, peak_node: Node, chunk_regions: List
|
||||
) -> Tuple:
|
||||
"""
|
||||
Search max chunk region according to peak memory node
|
||||
|
||||
Chunk region starts extending from the peak node, stops where free var num is min
|
||||
|
||||
Args:
|
||||
active_node (List): active node status for every node
|
||||
peak_node (Node): peak memory node
|
||||
chunk_regions (List): chunk region infos
|
||||
|
||||
Returns:
|
||||
chunk_region_start (int)
|
||||
chunk_region_end (int)
|
||||
"""
|
||||
free_vars = self._get_free_var_idx()
|
||||
free_var_num = len(free_vars)
|
||||
active_node_num = [len(i) for i in active_node]
|
||||
min_active_node_num = min(active_node_num[free_var_num:])
|
||||
threshold = max(free_var_num, min_active_node_num)
|
||||
|
||||
# from peak_node to free_var
|
||||
inside_flag = False
|
||||
chunk_region_start = free_var_num
|
||||
for i in range(peak_node, -1, -1):
|
||||
if active_node_num[i] <= threshold:
|
||||
inside_flag = True
|
||||
if inside_flag and active_node_num[i] > threshold:
|
||||
chunk_region_start = i + 1
|
||||
break
|
||||
|
||||
# from peak_node to len-2
|
||||
inside_flag = False
|
||||
chunk_region_end = len(active_node) - 1
|
||||
for i in range(peak_node, len(active_node)):
|
||||
if active_node_num[i] <= threshold:
|
||||
inside_flag = True
|
||||
if inside_flag and active_node_num[i] > threshold:
|
||||
chunk_region_end = i
|
||||
break
|
||||
|
||||
for i in chunk_regions:
|
||||
region = i["region"]
|
||||
if chunk_region_start >= region[0] and chunk_region_end <= region[1]:
|
||||
return None
|
||||
elif (
|
||||
region[0] <= chunk_region_start <= region[1]
|
||||
and chunk_region_end > region[1]
|
||||
):
|
||||
chunk_region_start = region[1] + 1
|
||||
elif (
|
||||
region[0] <= chunk_region_end <= region[1]
|
||||
and chunk_region_start < region[0]
|
||||
):
|
||||
chunk_region_end = region[0] - 1
|
||||
return chunk_region_start, chunk_region_end
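
    # Illustrative sketch (not part of the original commit): with 2 free vars,
    # per-node active-tensor counts [1, 1, 4, 2, 5, 7, 5, 2, 4, 1] and the peak at
    # index 5, the threshold is max(2, 1) = 2 and the walk above yields the max
    # chunk region (3, 8): it extends from the peak on each side until the active
    # count climbs back above the threshold, and is then clipped against any
    # previously selected chunk regions.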
|
||||
|
||||
def _find_chunk_info(self, input_trace, output_trace, start_idx, end_idx) -> List:
|
||||
"""
|
||||
Find chunk info for a region.
|
||||
|
||||
We are given the region start and region end, and need to find out all chunk info for it.
|
||||
We first loop every dim of start node and end node, to see if we can find dim pair,
|
||||
which is linked in a flow and not computed.
|
||||
If found, we then search flow in the whole region to find out all chunk infos.
|
||||
|
||||
Args:
|
||||
input_trace (List): node's input trace in region
|
||||
output_trace (List): node's output trace in region
|
||||
start_idx (int): region start node index
|
||||
end_idx (int): region end node index
|
||||
|
||||
Returns:
|
||||
chunk_infos: possible regions found
|
||||
"""
|
||||
start_traces = input_trace[start_idx]
|
||||
end_trace = output_trace[end_idx]
|
||||
end_node = self.trace_indice.node_list[end_idx]
|
||||
chunk_infos = []
|
||||
for end_dim, _ in enumerate(end_trace["indice"]):
|
||||
if len(start_traces) > 1:
|
||||
continue
|
||||
for start_node, start_trace in start_traces.items():
|
||||
for start_dim, _ in enumerate(start_trace["indice"]):
|
||||
# dim size cannot be 1
|
||||
if (
|
||||
get_node_shape(end_node)[end_dim] == 1
|
||||
or get_node_shape(start_node)[start_dim] == 1
|
||||
):
|
||||
continue
|
||||
# check index source align
|
||||
if not self.trace_flow.check_index_source(
|
||||
start_dim, start_node, start_idx, end_dim, end_node
|
||||
):
|
||||
continue
|
||||
# check index compute
|
||||
if not self.trace_flow.check_index_compute(
|
||||
start_idx, end_dim, end_node, end_idx
|
||||
):
|
||||
continue
|
||||
# flow search
|
||||
chunk_info = self.trace_flow.flow_search(
|
||||
start_idx, start_dim, end_idx, end_dim
|
||||
)
|
||||
if chunk_info is None:
|
||||
continue
|
||||
# check index duplicate
|
||||
if not self.trace_flow.check_index_duplicate(chunk_info):
|
||||
continue
|
||||
chunk_infos.append(chunk_info)
|
||||
return chunk_infos
|
||||
|
||||
def _search_possible_chunk_regions(
|
||||
self, max_chunk_region: Tuple, peak_node: Node
|
||||
) -> List:
|
||||
"""
|
||||
Search every possible region within the max chunk region.
|
||||
|
||||
Args:
|
||||
max_chunk_region (Tuple)
|
||||
peak_node (Node): peak memory node
|
||||
|
||||
Returns:
|
||||
possible_chunk_region (List)
|
||||
"""
|
||||
possible_chunk_region = []
|
||||
output_trace = copy.deepcopy(self.trace_indice.indice_trace_list)
|
||||
input_trace = [] # trace of a node's input nodes
|
||||
for _, n in enumerate(self.trace_indice.node_list):
|
||||
cur_trace = {}
|
||||
for arg in n.args:
|
||||
if type(arg) == type(n) and not is_non_compute_node_except_placeholder(
|
||||
arg
|
||||
):
|
||||
cur_trace[arg] = self.trace_indice._find_trace_from_node(arg)
|
||||
input_trace.append(cur_trace)
|
||||
|
||||
for start_idx in range(max_chunk_region[0], peak_node + 1):
|
||||
for end_idx in range(peak_node, max_chunk_region[1] + 1):
|
||||
# skip non compute nodes
|
||||
if is_non_compute_node(
|
||||
self.trace_indice.node_list[start_idx]
|
||||
) or is_non_compute_node(self.trace_indice.node_list[end_idx]):
|
||||
continue
|
||||
|
||||
# select free dim
|
||||
chunk_info = self._find_chunk_info(
|
||||
input_trace, output_trace, start_idx, end_idx
|
||||
)
|
||||
if len(chunk_info) > 0:
|
||||
possible_chunk_region.extend(chunk_info)
|
||||
return possible_chunk_region
|
||||
|
||||
def _step_search(
|
||||
self,
|
||||
mem_peak: List[float],
|
||||
active_node: List[List[Node]],
|
||||
chunk_infos: List[Dict],
|
||||
) -> Dict:
|
||||
"""
|
||||
Find one chunk region
|
||||
|
||||
The chunk search is as follows:
|
||||
1. find the peak memory node
|
||||
2. find the max chunk region according to the peak memory node
|
||||
3. find all possible chunk regions in the max chunk region
|
||||
4. find the best chunk region for current status
|
||||
|
||||
Args:
|
||||
mem_peak (List): peak memory for every node
|
||||
active_node (List[List[Node]]): active node for every node
|
||||
chunk_infos (List[Dict]): all chunk info
|
||||
|
||||
Returns:
|
||||
best_chunk_region (Dict)
|
||||
"""
|
||||
peak_node = self._find_peak_node(mem_peak)
|
||||
max_chunk_region = self._search_max_chunk_region(
|
||||
active_node, peak_node, chunk_infos
|
||||
)
|
||||
if max_chunk_region == None:
|
||||
return None
|
||||
possible_chunk_regions = self._search_possible_chunk_regions(
|
||||
max_chunk_region, peak_node
|
||||
)
|
||||
best_chunk_region = self.select_chunk._select_best_chunk_region(
|
||||
possible_chunk_regions, chunk_infos, peak_node, max_chunk_region, mem_peak
|
||||
)
|
||||
best_chunk_region = self.reorder_graph.reorder_all(best_chunk_region)
|
||||
return best_chunk_region
|
||||
|
||||
def _stop_search(self, init_mem_peak, mem_peak):
|
||||
sorted_init_mem_peak = sorted(init_mem_peak)
|
||||
if max(mem_peak) < sorted_init_mem_peak[int(len(sorted_init_mem_peak) * 0.5)]:
|
||||
return True
|
||||
return False
|
||||
|
||||
def search_region(self) -> Dict:
|
||||
"""
|
||||
Search all chunk regions:
|
||||
1. Estimate current memory
|
||||
2. Find best chunk for current memory
|
||||
3. goto 1
|
||||
|
||||
Returns:
|
||||
chunk_infos (Dict)
|
||||
"""
|
||||
chunk_infos = []
|
||||
(
|
||||
init_mem_peak,
|
||||
_,
|
||||
active_node,
|
||||
) = self.estimate_memory.estimate_chunk_inference_mem(
|
||||
self.trace_indice.node_list
|
||||
)
|
||||
mem_peak = init_mem_peak
|
||||
|
||||
while True:
|
||||
chunk_info = self._step_search(mem_peak, active_node, chunk_infos)
|
||||
if chunk_info is None:
|
||||
break
|
||||
chunk_infos.append(chunk_info)
|
||||
|
||||
(
|
||||
mem_peak,
|
||||
_,
|
||||
active_node,
|
||||
) = self.estimate_memory.estimate_chunk_inference_mem(
|
||||
self.trace_indice.node_list, chunk_infos
|
||||
)
|
||||
if self._stop_search(init_mem_peak, mem_peak):
|
||||
break
|
||||
if self.print_mem:
|
||||
self.print_mem = False
|
||||
self.estimate_memory.estimate_chunk_inference_mem(
|
||||
self.trace_indice.node_list, chunk_infos, print_mem=True
|
||||
)
|
||||
return chunk_infos
@@ -0,0 +1,224 @@
from .estimate_memory import EstimateMemory
from .reorder_graph import ReorderGraph
from .trace_indice import TraceIndice
from .utils import is_non_compute_node


class SelectChunk(object):
    def __init__(
        self,
        trace_indice: TraceIndice,
        estimate_memory: EstimateMemory,
        reorder_graph: ReorderGraph,
        max_memory=None,
    ):
        self.trace_indice = trace_indice
        self.estimate_memory = estimate_memory
        self.reorder_graph = reorder_graph
        if max_memory is not None:
            self.stratge = "fit_memory"
            self.max_memory = max_memory  # MB
        else:
            self.stratge = "min_memory"
|
||||
|
||||
def _select_best_chunk_region(
|
||||
self, possible_chunk_regions, chunk_infos, peak_node, max_chunk_region, mem_peak
|
||||
):
|
||||
if self.stratge == "min_memory":
|
||||
best_region = self._select_min_memory_chunk_region(
|
||||
possible_chunk_regions,
|
||||
chunk_infos,
|
||||
peak_node,
|
||||
max_chunk_region,
|
||||
mem_peak,
|
||||
)
|
||||
elif self.stratge == "fit_memory":
|
||||
best_region = self._select_fit_memory_chunk_region(
|
||||
possible_chunk_regions,
|
||||
chunk_infos,
|
||||
peak_node,
|
||||
max_chunk_region,
|
||||
mem_peak,
|
||||
)
|
||||
else:
|
||||
raise RuntimeError()
|
||||
return best_region
|
||||
|
||||
def _select_fit_memory_chunk_region(
|
||||
self, possible_chunk_regions, chunk_infos, peak_node, max_chunk_region, mem_peak
|
||||
):
|
||||
# stop chunk if max memory satisfy memory limit
|
||||
if max(mem_peak) < self.max_memory:
|
||||
return None
|
||||
|
||||
# remove illegal regions
|
||||
illegal_regions = []
|
||||
for i in possible_chunk_regions:
|
||||
if not self._is_legal_region(i, chunk_infos):
|
||||
illegal_regions.append(i)
|
||||
for i in illegal_regions:
|
||||
if i in possible_chunk_regions:
|
||||
possible_chunk_regions.remove(i)
|
||||
|
||||
if len(possible_chunk_regions) == 0:
|
||||
return None
|
||||
|
||||
# get mem for chunk region
|
||||
regions_dict = []
|
||||
for region in possible_chunk_regions:
|
||||
cur_region = region.copy()
|
||||
cur_node_list, cur_region = self.reorder_graph.tmp_reorder(
|
||||
self.trace_indice.node_list, cur_region
|
||||
)
|
||||
cur_chunk_infos = chunk_infos + [cur_region]
|
||||
cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(
|
||||
cur_node_list, cur_chunk_infos
|
||||
)[0]
|
||||
cur_chunk_region_peak = cur_mem_peak[
|
||||
max_chunk_region[0] : max_chunk_region[1] + 1
|
||||
]
|
||||
cur_chunk_region_max_peak = max(cur_chunk_region_peak)
|
||||
if cur_chunk_region_max_peak < self.max_memory:
|
||||
regions_dict.append(
|
||||
{
|
||||
"chunk_info": region,
|
||||
"chunk_max_mem": cur_chunk_region_max_peak,
|
||||
"chunk_len": self._get_compute_node_num(
|
||||
region["region"][0], region["region"][1]
|
||||
),
|
||||
"reorder_chunk_info": cur_region,
|
||||
"reorder_node_list": cur_node_list,
|
||||
}
|
||||
)
|
||||
# no region found
|
||||
if len(regions_dict) == 0:
|
||||
raise RuntimeError("Search failed. Try a larger memory threshold.")
|
||||
|
||||
# select the min chunk len
|
||||
chunk_len = [i["chunk_len"] for i in regions_dict]
|
||||
best_region_idx = chunk_len.index(min(chunk_len))
|
||||
best_region = regions_dict[best_region_idx]
|
||||
|
||||
# get max chunk size
|
||||
best_region = self._get_fit_chunk_size(best_region, chunk_infos)
|
||||
return best_region
|
||||
|
||||
def _get_fit_chunk_size(self, chunk_region_dict, chunk_infos):
|
||||
chunk_size = 1
|
||||
reorder_chunk_info = chunk_region_dict["reorder_chunk_info"]
|
||||
reorder_chunk_info["chunk_size"] = chunk_size
|
||||
cur_chunk_max_mem = 0
|
||||
# search a region
|
||||
while cur_chunk_max_mem < self.max_memory:
|
||||
chunk_size *= 2
|
||||
reorder_chunk_info["chunk_size"] = chunk_size
|
||||
cur_chunk_infos = chunk_infos + [reorder_chunk_info]
|
||||
cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(
|
||||
chunk_region_dict["reorder_node_list"], cur_chunk_infos
|
||||
)[0]
|
||||
cur_chunk_max_mem = max(
|
||||
cur_mem_peak[
|
||||
reorder_chunk_info["region"][0] : reorder_chunk_info["region"][1]
|
||||
+ 1
|
||||
]
|
||||
)
|
||||
# search exact size
|
||||
chunk_info = chunk_region_dict["chunk_info"]
|
||||
chunk_info["chunk_size"] = self._chunk_size_binary_search(
|
||||
chunk_size // 2, chunk_size, chunk_region_dict, chunk_infos
|
||||
)
|
||||
return chunk_info
|
||||
|
||||
def _chunk_size_binary_search(self, left, right, chunk_region_dict, chunk_infos):
|
||||
if left >= 16:
|
||||
gap = 4
|
||||
else:
|
||||
gap = 1
|
||||
chunk_info = chunk_region_dict["reorder_chunk_info"]
|
||||
while right >= left + gap:
|
||||
mid = int((left + right) / 2 + 0.5)
|
||||
chunk_info["chunk_size"] = mid
|
||||
cur_chunk_infos = chunk_infos + [chunk_info]
|
||||
cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(
|
||||
chunk_region_dict["reorder_node_list"], cur_chunk_infos
|
||||
)[0]
|
||||
cur_chunk_max_mem = max(
|
||||
cur_mem_peak[chunk_info["region"][0] : chunk_info["region"][1] + 1]
|
||||
)
|
||||
if cur_chunk_max_mem >= self.max_memory:
|
||||
right = mid - gap
|
||||
else:
|
||||
left = mid + gap
|
||||
return left
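
    # Illustrative sketch (not part of the original commit): _get_fit_chunk_size
    # starts from 1 and repeatedly doubles chunk_size until the estimated peak
    # exceeds max_memory, then runs the search above on [chunk_size // 2, chunk_size].
    # With left = 16 and right = 32 the step size is gap = 4, e.g.
    #     mid = 24 -> fits     -> left = 28
    #     mid = 30 -> too big  -> right = 26
    #     right (26) < left + gap (32), so the search stops and returns left = 28.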
|
||||
|
||||
def _get_compute_node_num(self, start, end):
|
||||
count = 0
|
||||
for i in self.trace_indice.node_list[start : end + 1]:
|
||||
if not is_non_compute_node(i):
|
||||
count += 1
|
||||
return count
|
||||
|
||||
def _select_min_memory_chunk_region(
|
||||
self, possible_chunk_regions, chunk_infos, peak_node, max_chunk_region, mem_peak
|
||||
):
|
||||
# remove illegal regions
|
||||
illegal_regions = []
|
||||
for i in possible_chunk_regions:
|
||||
if not self._is_legal_region(i, chunk_infos):
|
||||
illegal_regions.append(i)
|
||||
for i in illegal_regions:
|
||||
if i in possible_chunk_regions:
|
||||
possible_chunk_regions.remove(i)
|
||||
|
||||
if len(possible_chunk_regions) == 0:
|
||||
return None
|
||||
|
||||
# get mem for chunk region
|
||||
regions_dict = []
|
||||
for region in possible_chunk_regions:
|
||||
cur_region = region.copy()
|
||||
cur_node_list, cur_region = self.reorder_graph.tmp_reorder(
|
||||
self.trace_indice.node_list, cur_region
|
||||
)
|
||||
cur_chunk_infos = chunk_infos + [cur_region]
|
||||
cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(
|
||||
cur_node_list, cur_chunk_infos
|
||||
)[0]
|
||||
cur_chunk_region_peak = cur_mem_peak[
|
||||
max_chunk_region[0] : max_chunk_region[1] + 1
|
||||
]
|
||||
cur_chunk_region_max_peak = max(cur_chunk_region_peak)
|
||||
regions_dict.append(
|
||||
{
|
||||
"chunk_info": region,
|
||||
"chunk_max_mem": cur_chunk_region_max_peak,
|
||||
"chunk_len": self._get_compute_node_num(
|
||||
region["region"][0], region["region"][1]
|
||||
),
|
||||
"reorder_chunk_info": cur_region,
|
||||
"reorder_node_list": cur_node_list,
|
||||
}
|
||||
)
|
||||
|
||||
# select the min mem
|
||||
chunk_max_mem = [i["chunk_max_mem"] for i in regions_dict]
|
||||
best_region_idx = chunk_max_mem.index(min(chunk_max_mem))
|
||||
best_region = regions_dict[best_region_idx]["chunk_info"]
|
||||
if best_region is not None:
|
||||
best_region["chunk_size"] = 1
|
||||
return best_region
|
||||
|
||||
def _is_legal_region(self, cur_chunk_info, chunk_infos):
|
||||
(chunk_region_start, chunk_region_end) = cur_chunk_info["region"]
|
||||
if cur_chunk_info in chunk_infos:
|
||||
return False
|
||||
if chunk_region_end < chunk_region_start:
|
||||
return False
|
||||
for i in chunk_infos:
|
||||
region = i["region"]
|
||||
if not (
|
||||
(chunk_region_start > region[1] and chunk_region_end > region[1])
|
||||
or (chunk_region_start < region[0] and chunk_region_end < region[0])
|
||||
):
|
||||
return False
|
||||
return True
@@ -0,0 +1,420 @@
from .trace_indice import TraceIndice
from .utils import (
    find_chunk_all_input_nodes,
    find_chunk_compute_input_and_output_nodes,
    find_idx_by_name,
    get_node_shape,
    is_non_compute_node,
    is_non_compute_node_except_placeholder,
)


class TraceFlow(object):
    def __init__(self, trace_indice: TraceIndice) -> None:
        self.trace_indice = trace_indice
|
||||
|
||||
def check_index_source(self, start_dim, start_node, start_idx, end_dim, end_node):
|
||||
"""
|
||||
Check 2 given index: one index should be source of the other
|
||||
Args:
|
||||
start_idx(int): start node chunk dim
|
||||
start_node(node): start node
|
||||
end_idx(int): end node chunk dim
|
||||
end_node(node): end node
|
||||
|
||||
Returns:
|
||||
bool: True if check pass
|
||||
"""
|
||||
start_node_idx = find_idx_by_name(start_node.name, self.trace_indice.node_list)
|
||||
end_node_trace = self.trace_indice._find_trace_from_node(end_node)
|
||||
end_node_trace_source = end_node_trace["source"][end_dim]
|
||||
sorted_source = sorted(
|
||||
end_node_trace_source.items(), key=lambda d: d[0], reverse=True
|
||||
)
|
||||
for node_idx, node_dim in sorted_source:
|
||||
if node_idx == start_node_idx and start_dim in node_dim:
|
||||
return True
|
||||
# this means we have reached a node outside the loop that is not an input node
|
||||
if node_idx < start_idx:
|
||||
return False
|
||||
return False
|
||||
|
||||
def check_index_compute(self, start_idx, end_dim, end_node, end_idx):
|
||||
"""
|
||||
Check 2 given index: check they haven't been computed in the source trace.
|
||||
Args:
|
||||
start_idx(int): start node index
|
||||
end_dim(int): end node chunk dim
|
||||
end_node(node): end node
|
||||
end_idx(int): end node index
|
||||
|
||||
Returns:
|
||||
bool: True if check pass
|
||||
"""
|
||||
end_node_trace = self.trace_indice._find_trace_from_node(end_node)
|
||||
end_node_compute = end_node_trace["compute"][end_dim]
|
||||
if any(start_idx <= i <= end_idx for i in end_node_compute):
|
||||
return False
|
||||
return True
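As a rough illustration of what this check rejects (using made-up trace data, not the real TraceIndice structures): if the candidate dimension was already reduced by some node inside the region, it cannot serve as the chunk dimension.

# Sketch of the compute-trace rule (hypothetical data, not real TraceIndice output).
def dim_free_of_compute(compute_trace_for_dim, start_idx, end_idx):
    # compute_trace_for_dim lists the node indices at which this dim was reduced
    return not any(start_idx <= i <= end_idx for i in compute_trace_for_dim)

assert dim_free_of_compute([3], 10, 20) is True    # reduction happened outside the region
assert dim_free_of_compute([12], 10, 20) is False  # reduced at node 12, inside (10, 20)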
|
||||
|
||||
def get_node_chunk_dim(self, node_from, node_from_dim, node_to):
|
||||
node_from_source = self.trace_indice._find_source_trace_from_node(node_from)
|
||||
dim_source = node_from_source[node_from_dim]
|
||||
node_to_idx = find_idx_by_name(node_to.name, self.trace_indice.node_list)
|
||||
for k, v in dim_source.items():
|
||||
if k == node_to_idx:
|
||||
return v
|
||||
return None
|
||||
|
||||
def _find_inherit_dim(self, input_node, input_dim, node):
|
||||
input_node_idx = find_idx_by_name(input_node.name, self.trace_indice.node_list)
|
||||
node_trace_source = self.trace_indice._find_source_trace_from_node(node)
|
||||
for node_dim in range(len(get_node_shape(node))):
|
||||
if (
|
||||
input_node_idx in node_trace_source[node_dim]
|
||||
and input_dim[0] in node_trace_source[node_dim][input_node_idx]
|
||||
):
|
||||
return node_dim
|
||||
return None
|
||||
|
||||
def check_index_duplicate(self, chunk_infos, return_dim=False):
|
||||
input_dim_after_node = {}
|
||||
for input_node_idx, input_node in enumerate(chunk_infos["inputs"]):
|
||||
for k, v in chunk_infos["inputs_dim"][input_node_idx].items():
|
||||
inherit_dim = self._find_inherit_dim(
|
||||
input_node, v, self.trace_indice.node_list[k]
|
||||
)
|
||||
if inherit_dim is not None:
|
||||
input_dim_after_node[k] = inherit_dim
|
||||
|
||||
for node in self.trace_indice.node_list[
|
||||
chunk_infos["region"][0] : chunk_infos["region"][1] + 1
|
||||
]:
|
||||
if is_non_compute_node_except_placeholder(node):
|
||||
continue
|
||||
count = 0
|
||||
duplicate_dims = []
|
||||
node_trace_source = self.trace_indice._find_source_trace_from_node(node)
|
||||
for node_dim in range(len(get_node_shape(node))):
|
||||
duplicate_dim = []
|
||||
duplicate_flag = False
|
||||
dim_source = node_trace_source[node_dim]
|
||||
for k, v in dim_source.items():
|
||||
if chunk_infos["region"][0] <= k <= chunk_infos["region"][1]:
|
||||
if k in input_dim_after_node and input_dim_after_node[k] in v:
|
||||
duplicate_flag = True
|
||||
duplicate_dim.append((k, v))
|
||||
duplicate_dims.append(duplicate_dim)
|
||||
if duplicate_flag:
|
||||
count += 1
|
||||
|
||||
if count > 1:
|
||||
if return_dim:
|
||||
return False, duplicate_dims
|
||||
else:
|
||||
return False
|
||||
if return_dim:
|
||||
return True, None
|
||||
else:
|
||||
return True
|
||||
|
||||
def _assign_single_node_flow(
|
||||
self,
|
||||
arg_node,
|
||||
start_idx,
|
||||
end_idx,
|
||||
cur_node_dim,
|
||||
cur_node_compute,
|
||||
cur_node_source,
|
||||
cur_node_fix_dim,
|
||||
all_node_info,
|
||||
next_node_list,
|
||||
):
|
||||
arg_idx = find_idx_by_name(arg_node.name, self.trace_indice.node_list)
|
||||
# args outside the chunk range are treated as inputs
|
||||
if not (start_idx <= arg_idx < end_idx):
|
||||
return True
|
||||
|
||||
# find arg dim
|
||||
if cur_node_dim is not None:
|
||||
# dim is computed
|
||||
if arg_idx in cur_node_compute[cur_node_dim]:
|
||||
return False
|
||||
if arg_idx not in cur_node_source[cur_node_dim]:
|
||||
arg_dim = None
|
||||
else:
|
||||
arg_dim = cur_node_source[cur_node_dim][arg_idx][0]
|
||||
else:
|
||||
arg_dim = None
|
||||
|
||||
# get fix dim
|
||||
arg_fix_dim = []
|
||||
if cur_node_dim is not None:
|
||||
for i in cur_node_fix_dim:
|
||||
fix_dim_source = cur_node_source[i]
|
||||
if arg_idx in fix_dim_source:
|
||||
arg_fix_dim.append(fix_dim_source[arg_idx][0])
|
||||
|
||||
# if already in node_info, arg dim must be same
|
||||
if arg_node in all_node_info:
|
||||
if all_node_info[arg_node]["chunk_dim"] != arg_dim:
|
||||
return False
|
||||
all_node_info[arg_node]["fix_dim"] = list(
|
||||
set(all_node_info[arg_node]["fix_dim"] + arg_fix_dim)
|
||||
)
|
||||
# else add it to list
|
||||
else:
|
||||
all_node_info[arg_node] = {"chunk_dim": arg_dim, "fix_dim": arg_fix_dim}
|
||||
|
||||
next_node_list.append(arg_node)
|
||||
return True
|
||||
|
||||
def _get_all_node_info(self, end_dim, start_idx, end_idx):
|
||||
cur_node_list = [
|
||||
self.trace_indice.node_list[end_idx]
|
||||
] # start from the last node
|
||||
all_node_info = {cur_node_list[0]: {"chunk_dim": end_dim, "fix_dim": []}}
|
||||
|
||||
while len(cur_node_list) > 0:
|
||||
next_node_list = []
|
||||
|
||||
for cur_node in cur_node_list:
|
||||
# get cur node info
|
||||
cur_node_chunk_dim = all_node_info[cur_node]["chunk_dim"]
|
||||
cur_node_fix_dim = all_node_info[cur_node]["fix_dim"]
|
||||
if cur_node_chunk_dim is not None:
|
||||
cur_node_compute = self.trace_indice._find_compute_trace_from_node(
|
||||
cur_node
|
||||
)
|
||||
cur_node_source = self.trace_indice._find_source_trace_from_node(
|
||||
cur_node
|
||||
)
|
||||
else:
|
||||
cur_node_compute = cur_node_source = None
|
||||
|
||||
# get all valid args
|
||||
arg_list = []
|
||||
for arg in cur_node.args:
|
||||
if type(arg) != type(cur_node):
|
||||
continue
|
||||
if is_non_compute_node(arg):
|
||||
continue
|
||||
arg_list.append(arg)
|
||||
flow_flag = self._assign_single_node_flow(
|
||||
arg,
|
||||
start_idx,
|
||||
end_idx,
|
||||
cur_node_chunk_dim,
|
||||
cur_node_compute,
|
||||
cur_node_source,
|
||||
cur_node_fix_dim,
|
||||
all_node_info,
|
||||
next_node_list,
|
||||
)
|
||||
if not flow_flag:
|
||||
return None
|
||||
|
||||
if len(arg_list) == 2:
|
||||
if any(i in cur_node.name for i in ["add", "mul"]):
|
||||
for arg in arg_list:
|
||||
if not (
|
||||
start_idx
|
||||
<= find_idx_by_name(
|
||||
arg.name, self.trace_indice.node_list
|
||||
)
|
||||
< end_idx
|
||||
):
|
||||
continue
|
||||
arg_chunk_dim = all_node_info[arg]["chunk_dim"]
|
||||
arg_fix_dim = all_node_info[arg]["fix_dim"]
|
||||
arg_shape = get_node_shape(arg)
|
||||
# add all dim as fix dim except chunk dim
|
||||
for i, shape in enumerate(arg_shape):
|
||||
if shape != 1 and i != cur_node_chunk_dim:
|
||||
if i == arg_chunk_dim:
|
||||
return None
|
||||
if i not in arg_fix_dim:
|
||||
arg_fix_dim.append(i)
|
||||
elif "einsum" in cur_node.name:
|
||||
pass
|
||||
elif "matmul" in cur_node.name:
|
||||
pass
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
cur_node_list = next_node_list
|
||||
return all_node_info
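For orientation, a hedged sketch of the dictionary this method returns; the keys are real torch.fx nodes in practice, shown here as strings, and the entries are purely illustrative.

# Hypothetical shape of the value returned by _get_all_node_info.
all_node_info_example = {
    "softmax_1": {"chunk_dim": 1, "fix_dim": []},  # chunked along dim 1
    "linear_3": {"chunk_dim": 1, "fix_dim": [3]},  # dim 3 must stay whole inside the loop
    "add_2": {"chunk_dim": None, "fix_dim": []},   # no chunk dim -> candidate for prepose
}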
|
||||
|
||||
def _get_input_nodes_dim(self, inputs, start_idx, end_idx, all_node_info):
|
||||
inputs_dim = []
|
||||
remove_inputs = []
|
||||
for input_node in inputs:
|
||||
input_dict = {}
|
||||
input_node_idx = find_idx_by_name(
|
||||
input_node.name, self.trace_indice.node_list
|
||||
)
|
||||
for user in input_node.users.keys():
|
||||
if is_non_compute_node(user):
|
||||
continue
|
||||
user_idx = find_idx_by_name(user.name, self.trace_indice.node_list)
|
||||
if start_idx <= user_idx <= end_idx:
|
||||
chunk_dim = all_node_info[user]["chunk_dim"]
|
||||
if chunk_dim is not None:
|
||||
user_source = self.trace_indice._find_source_trace_from_node(
|
||||
user
|
||||
)[chunk_dim]
|
||||
if input_node_idx in user_source:
|
||||
input_dict[user_idx] = user_source[input_node_idx]
|
||||
else:
|
||||
return None, None
|
||||
if len(input_dict) == 0:
|
||||
remove_inputs.append(input_node)
|
||||
else:
|
||||
inputs_dim.append(input_dict)
|
||||
for i in remove_inputs:
|
||||
if i in inputs:
|
||||
inputs.remove(i)
|
||||
return inputs, inputs_dim
|
||||
|
||||
def _get_prepose_nodes(self, all_node_info, start_idx, end_idx):
|
||||
# get all possible prepose nodes
|
||||
maybe_prepose_nodes = []
|
||||
for node, node_info in all_node_info.items():
|
||||
if node_info["chunk_dim"] is None:
|
||||
maybe_prepose_nodes.append(node)
|
||||
maybe_prepose_nodes.sort(
|
||||
key=lambda x: find_idx_by_name(x.name, self.trace_indice.node_list),
|
||||
reverse=True,
|
||||
) # from last node to first node
|
||||
prepose_nodes = []
|
||||
# treat every node as a root and search its args; if all are legal, mark the root and its args as prepose nodes
|
||||
while len(maybe_prepose_nodes) > 0:
|
||||
tmp_cur_prepose_nodes = [maybe_prepose_nodes[0]]
|
||||
tmp_cur_related_prepose_nodes = []
|
||||
prepose_flag = True
|
||||
|
||||
# loop over all of the current node's args until we leave the chunk region
|
||||
while len(tmp_cur_prepose_nodes) > 0:
|
||||
if not prepose_flag:
|
||||
break
|
||||
tmp_next_prepose_nodes = []
|
||||
tmp_cur_related_prepose_nodes.extend(tmp_cur_prepose_nodes)
|
||||
for cur_prepose_node in tmp_cur_prepose_nodes:
|
||||
if not prepose_flag:
|
||||
break
|
||||
for cur_prepose_node_arg in cur_prepose_node.args:
|
||||
if type(cur_prepose_node_arg) != type(cur_prepose_node):
|
||||
continue
|
||||
# out of loop
|
||||
if not (
|
||||
start_idx
|
||||
<= find_idx_by_name(
|
||||
cur_prepose_node_arg.name, self.trace_indice.node_list
|
||||
)
|
||||
< end_idx
|
||||
):
|
||||
continue
|
||||
# compute op in loop
|
||||
elif cur_prepose_node_arg in all_node_info:
|
||||
if all_node_info[cur_prepose_node_arg]["chunk_dim"] is None:
|
||||
tmp_next_prepose_nodes.append(cur_prepose_node_arg)
|
||||
else:
|
||||
prepose_flag = False
|
||||
break
|
||||
# non compute op
|
||||
else:
|
||||
tmp_next_prepose_nodes.append(cur_prepose_node_arg)
|
||||
tmp_cur_prepose_nodes = tmp_next_prepose_nodes
|
||||
|
||||
if not prepose_flag:
|
||||
maybe_prepose_nodes.remove(maybe_prepose_nodes[0])
|
||||
continue
|
||||
else:
|
||||
for n in tmp_cur_related_prepose_nodes:
|
||||
if n not in prepose_nodes:
|
||||
prepose_nodes.append(n)
|
||||
if n in maybe_prepose_nodes:
|
||||
maybe_prepose_nodes.remove(n)
|
||||
# sort by index
|
||||
prepose_nodes.sort(
|
||||
key=lambda x: find_idx_by_name(x.name, self.trace_indice.node_list)
|
||||
)
|
||||
|
||||
return prepose_nodes
|
||||
|
||||
def _get_non_chunk_inputs(self, chunk_info, start_idx, end_idx):
|
||||
# we need to log input nodes to avoid deleting them in the loop
|
||||
chunk_node_list = self.trace_indice.node_list[start_idx : end_idx + 1]
|
||||
# we also need to keep prepose nodes' args out of non_chunk_inputs
|
||||
for n in chunk_info["args"]["prepose_nodes"]:
|
||||
chunk_node_list.remove(n)
|
||||
non_chunk_inputs = find_chunk_all_input_nodes(chunk_node_list)
|
||||
for i in non_chunk_inputs:
|
||||
if i not in chunk_info["inputs"]:
|
||||
chunk_info["inputs_non_chunk"].append(i)
|
||||
return chunk_info
|
||||
|
||||
def flow_search(self, start_idx, start_dim, end_idx, end_dim):
|
||||
inputs, outputs = find_chunk_compute_input_and_output_nodes(
|
||||
self.trace_indice.node_list[start_idx : end_idx + 1]
|
||||
)
|
||||
# only a single output is supported
|
||||
if len(outputs) > 1:
|
||||
return None
|
||||
|
||||
# get every node's chunk dim and fix dim
|
||||
all_node_info = self._get_all_node_info(end_dim, start_idx, end_idx)
|
||||
if all_node_info is None:
|
||||
return None
|
||||
|
||||
# get input nodes' chunk dim
|
||||
inputs, inputs_dim = self._get_input_nodes_dim(
|
||||
inputs, start_idx, end_idx, all_node_info
|
||||
)
|
||||
if inputs is None:
|
||||
return None
|
||||
|
||||
chunk_info = {
|
||||
"region": (start_idx, end_idx),
|
||||
"inputs": inputs,
|
||||
"inputs_non_chunk": [],
|
||||
"inputs_dim": inputs_dim,
|
||||
"outputs": outputs,
|
||||
"outputs_dim": end_dim,
|
||||
"node_chunk_dim": all_node_info,
|
||||
"args": {},
|
||||
}
|
||||
|
||||
# move nodes that do not depend on the chunk dim ahead of the loop
|
||||
chunk_info["args"]["prepose_nodes"] = self._get_prepose_nodes(
|
||||
all_node_info, start_idx, end_idx
|
||||
)
|
||||
|
||||
# find non chunk inputs
|
||||
chunk_info = self._get_non_chunk_inputs(chunk_info, start_idx, end_idx)
|
||||
|
||||
# reassign reshape sizes, as some sizes may have changed due to chunking
|
||||
chunk_info = self._reassign_reshape_size(chunk_info)
|
||||
|
||||
return chunk_info
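For orientation, a hedged sketch of the chunk_info dictionary that flow_search returns; values are placeholders rather than real torch.fx nodes, and the indices are invented.

# Illustrative chunk_info layout (placeholder values, not real fx nodes).
chunk_info_example = {
    "region": (24, 57),             # node-index range covered by the generated loop
    "inputs": ["input_a"],          # inputs that get sliced per chunk
    "inputs_non_chunk": [],         # inputs passed through untouched
    "inputs_dim": [{30: [1]}],      # which dim of each input feeds the chunk dim, per user node
    "outputs": ["output_b"],
    "outputs_dim": 1,               # output dim the loop iterates over
    "node_chunk_dim": {},           # per-node chunk/fix dims from _get_all_node_info
    "args": {"prepose_nodes": []},  # nodes hoisted ahead of the loop
    "reshape_size": {},             # filled in by _reassign_reshape_size
}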
|
||||
|
||||
def _reassign_reshape_size(self, chunk_info):
|
||||
chunk_region = chunk_info["region"]
|
||||
reshape_size = {}
|
||||
chunk_shape = get_node_shape(chunk_info["outputs"][0])[
|
||||
chunk_info["outputs_dim"]
|
||||
]
|
||||
for node in self.trace_indice.node_list[chunk_region[0] : chunk_region[1] + 1]:
|
||||
if any(i in node.name for i in ["reshape", "view"]):
|
||||
reshape_args = node.args[1:]
|
||||
reshape_log = self.trace_indice.indice_view_list[node]
|
||||
chunk_dim = chunk_info["node_chunk_dim"][node]["chunk_dim"]
|
||||
reshape_size[node.name] = {}
|
||||
for reshape_arg_dim, reshape_arg in enumerate(reshape_args):
|
||||
if reshape_arg_dim in reshape_log["dim_to"]:
|
||||
continue
|
||||
if reshape_arg_dim == chunk_dim:
|
||||
reshape_size[node.name][reshape_arg.name] = (
|
||||
"min(chunk_size, %d - chunk_idx)" % chunk_shape
|
||||
)
|
||||
chunk_info["reshape_size"] = reshape_size
|
||||
return chunk_info
|
|
@ -0,0 +1,559 @@
|
|||
import copy
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
from torch.fx.node import Node
|
||||
|
||||
from .utils import find_idx_by_name, get_node_shape
|
||||
|
||||
|
||||
class TraceIndice(object):
|
||||
"""
|
||||
Trace all indice information for every node.
|
||||
|
||||
Indice is a logical concept. Equal dims can be treated as one indice.
|
||||
eg. dim(x1) = [a, b, c]
|
||||
dim(x2) = [d, e, f]
|
||||
and we have x3 = x1 * x2.
|
||||
then a=d, b=e, c=f, due to the broadcast property,
|
||||
dim(x1)=dim(x2)=dim(x3)=[a, b, c]
|
||||
This class will record every node's dims' indice, compute and source.
|
||||
|
||||
Attributes:
|
||||
node_list (List)
|
||||
indice_trace_list (List): [{"indice": [...], "compute": [...], "source": [...]}, {...}]
|
||||
indice_view_list (Dict): not used for now
|
||||
indice_count (int): record indice number
|
||||
|
||||
Args:
|
||||
node_list (List)
|
||||
"""
|
||||
|
||||
def __init__(self, node_list: List) -> None:
|
||||
self.node_list = node_list
|
||||
self.indice_trace_list = self._init_indice_trace_list()
|
||||
self.indice_view_list = {}
|
||||
self.indice_count = -1
|
||||
|
||||
def _init_indice_trace_list(self):
|
||||
indice_trace_list = []
|
||||
for n in self.node_list:
|
||||
if get_node_shape(n) is not None:
|
||||
cur_trace = {
|
||||
"indice": [None for _ in range(len(get_node_shape(n)))],
|
||||
"compute": [[] for _ in range(len(get_node_shape(n)))],
|
||||
"source": [{} for _ in range(len(get_node_shape(n)))],
|
||||
}
|
||||
else:
|
||||
cur_trace = {"indice": [], "compute": [], "source": []}
|
||||
indice_trace_list.append(cur_trace)
|
||||
return indice_trace_list
|
||||
|
||||
def _add_indice(self):
|
||||
"""
|
||||
Update the count and return it. To record the idx number.
|
||||
|
||||
Returns:
|
||||
indice_count: int
|
||||
"""
|
||||
self.indice_count += 1
|
||||
return self.indice_count
|
||||
|
||||
def _del_dim(self, idx, dim_idx):
|
||||
self.indice_trace_list[idx]["indice"].pop(dim_idx)
|
||||
self.indice_trace_list[idx]["compute"].pop(dim_idx)
|
||||
self.indice_trace_list[idx]["source"].pop(dim_idx)
|
||||
|
||||
def _add_dim(self, node_idx, dim_idx):
|
||||
self.indice_trace_list[node_idx]["indice"].insert(dim_idx, self._add_indice())
|
||||
self.indice_trace_list[node_idx]["compute"].insert(dim_idx, [])
|
||||
self.indice_trace_list[node_idx]["source"].insert(dim_idx, {})
|
||||
|
||||
def _transform_indice(self, node, node_dim):
|
||||
node_idx = self._find_indice_trace_from_node(node)
|
||||
dims = list(range(len(node_idx)))
|
||||
return dims[node_dim]
|
||||
|
||||
def _inherit_indice(self, node_from, node_from_dim, node_to, node_to_dim):
|
||||
node_from_dim = self._transform_indice(node_from, node_from_dim)
|
||||
node_to_dim = self._transform_indice(node_to, node_to_dim)
|
||||
node_from_trace = self._find_trace_from_node(node_from)
|
||||
node_to_trace = self._find_trace_from_node(node_to)
|
||||
node_to_trace["indice"][node_to_dim] = node_from_trace["indice"][node_from_dim]
|
||||
node_to_trace["compute"][node_to_dim] = copy.deepcopy(
|
||||
node_from_trace["compute"][node_from_dim]
|
||||
)
|
||||
self._add_source(node_from, node_from_dim, node_to, node_to_dim, init=True)
|
||||
|
||||
def _inherit_all_computation(self, node_from, node_to):
|
||||
node_from_compute = self._find_compute_trace_from_node(node_from)
|
||||
node_to_compute = self._find_compute_trace_from_node(node_to)
|
||||
assert len(node_from_compute) == len(node_to_compute)
|
||||
for i in range(len(node_from_compute)):
|
||||
self._add_source(node_from, i, node_to, i)
|
||||
node_to_compute[i] = copy.deepcopy(node_from_compute[i])
|
||||
|
||||
def _add_source(self, node_from, node_from_dim, node_to, node_to_dim, init=False):
|
||||
node_from_dim = self._transform_indice(node_from, node_from_dim)
|
||||
node_from_trace_source = self._find_source_trace_from_node(node_from)
|
||||
node_to_dim = self._transform_indice(node_to, node_to_dim)
|
||||
node_to_trace_source = self._find_source_trace_from_node(node_to)
|
||||
node_from_idx = find_idx_by_name(node_from.name, self.node_list)
|
||||
if init:
|
||||
node_to_trace_source[node_to_dim] = {}
|
||||
# add dim to cur new source
|
||||
if node_from_idx not in node_to_trace_source[node_to_dim]:
|
||||
node_to_trace_source[node_to_dim][node_from_idx] = [node_from_dim]
|
||||
else:
|
||||
if node_from_dim not in node_to_trace_source[node_to_dim][node_from_idx]:
|
||||
node_to_trace_source[node_to_dim][node_from_idx].append(node_from_dim)
|
||||
# update inputs source
|
||||
for node_idx, node_dim in node_from_trace_source[node_from_dim].items():
|
||||
if node_idx not in node_to_trace_source[node_to_dim]:
|
||||
node_to_trace_source[node_to_dim][node_idx] = copy.deepcopy(node_dim)
|
||||
else:
|
||||
for d in node_dim:
|
||||
if d not in node_to_trace_source[node_to_dim][node_idx]:
|
||||
node_to_trace_source[node_to_dim][node_idx].append(d)
|
||||
|
||||
def _mark_computation_from_node(self, node_from, node_to, exclude=None):
|
||||
if exclude is None:
|
||||
exclude = []
|
||||
else:
|
||||
exclude = [self._transform_indice(node_to, i) for i in exclude]
|
||||
node_from_compute = self._find_compute_trace_from_node(node_from)
|
||||
node_to_compute = self._find_compute_trace_from_node(node_to)
|
||||
# assert len(node_from_compute) == len(node_to_compute)
|
||||
for i in range(-1, -min(len(node_from_compute), len(node_to_compute)) - 1, -1):
|
||||
if self._transform_indice(node_to, i) in exclude:
|
||||
continue
|
||||
self._add_source(node_from, i, node_to, i)
|
||||
for j in node_from_compute[i]:
|
||||
if j not in node_to_compute[i]:
|
||||
node_to_compute[i].append(j)
|
||||
|
||||
def _mark_computation(self, node, idx, dim):
|
||||
"""
|
||||
Mark some dims of node as computed.
|
||||
|
||||
Args:
|
||||
node (node)
|
||||
idx (int): node index
|
||||
dim (list or int): dims to be marked as computed
|
||||
"""
|
||||
if isinstance(dim, int):
|
||||
dim = [dim]
|
||||
dims = list(range(len(get_node_shape(node))))
|
||||
for d in dim:
|
||||
cur_dim = dims[d]
|
||||
if idx not in self.indice_trace_list[idx]["compute"][cur_dim]:
|
||||
self.indice_trace_list[idx]["compute"][cur_dim].append(idx)
|
||||
|
||||
def _find_trace_from_node(self, node):
|
||||
"""
|
||||
Find node idx and compute trace by the node.
|
||||
|
||||
Args:
|
||||
node (node)
|
||||
Returns:
|
||||
node_dict (dict): the node's full trace, i.e. {"indice": ..., "compute": ..., "source": ...}.
|
||||
"""
|
||||
node_idx = find_idx_by_name(node.name, self.node_list)
|
||||
node_dict = self.indice_trace_list[node_idx]
|
||||
return node_dict
|
||||
|
||||
def _find_source_trace_from_node(self, node):
|
||||
"""
|
||||
Find node source trace by the node.
|
||||
|
||||
Args:
|
||||
node (node)
|
||||
Returns:
|
||||
source (list): source trace of the node, one dict per dim.
|
||||
"""
|
||||
node_idx = find_idx_by_name(node.name, self.node_list)
|
||||
node_dict = self.indice_trace_list[node_idx]
|
||||
return node_dict["source"]
|
||||
|
||||
def _find_indice_trace_from_node(self, node):
|
||||
"""
|
||||
Find node idx trace by the node.
|
||||
|
||||
Args:
|
||||
node (node)
|
||||
Returns:
|
||||
idx (list): idx of the node
|
||||
"""
|
||||
node_idx = find_idx_by_name(node.name, self.node_list)
|
||||
return self.indice_trace_list[node_idx]["indice"]
|
||||
|
||||
def _find_compute_trace_from_node(self, node):
|
||||
"""
|
||||
Find node compute trace by the node.
|
||||
|
||||
Args:
|
||||
node (node)
|
||||
Returns:
|
||||
compute (list): computed idx of the node.
|
||||
"""
|
||||
node_idx = find_idx_by_name(node.name, self.node_list)
|
||||
return self.indice_trace_list[node_idx]["compute"]
|
||||
|
||||
def _assign_indice_as_input(self, node, node_idx, input_node=None):
|
||||
"""
|
||||
Assign node's trace as its input node.
|
||||
|
||||
Args:
|
||||
node (node)
|
||||
node_idx (int)
|
||||
"""
|
||||
if input_node is None:
|
||||
input_node = node.args[0]
|
||||
input_node_idx = find_idx_by_name(input_node.name, self.node_list)
|
||||
input_node_idx_trace = self.indice_trace_list[input_node_idx]["indice"]
|
||||
|
||||
new_idx_trace = copy.deepcopy(input_node_idx_trace)
|
||||
self.indice_trace_list[node_idx]["indice"] = new_idx_trace
|
||||
|
||||
self._inherit_all_computation(input_node, node)
|
||||
|
||||
def _assign_all_indice(self, node, node_idx):
|
||||
"""
|
||||
Add new indice for all node's dims.
|
||||
|
||||
Args:
|
||||
node (node)
|
||||
node_idx (int)
|
||||
"""
|
||||
shape = node.meta["tensor_meta"].shape
|
||||
new_trace = []
|
||||
for _ in shape:
|
||||
new_trace.append(self._add_indice())
|
||||
self.indice_trace_list[node_idx]["indice"] = new_trace
|
||||
|
||||
def _assign_transpose_indice(self, node, node_idx):
|
||||
"""
|
||||
Assign indice for transpose op.
|
||||
1. swap input's dim according to transpose args
|
||||
2. inherit input's computation
|
||||
|
||||
Args:
|
||||
node (node)
|
||||
node_idx (int)
|
||||
"""
|
||||
input_node = node.args[0]
|
||||
transpose_dim = node.args[1:]
|
||||
|
||||
self._assign_indice_as_input(node, node_idx, input_node)
|
||||
self._inherit_indice(input_node, transpose_dim[1], node, transpose_dim[0])
|
||||
self._inherit_indice(input_node, transpose_dim[0], node, transpose_dim[1])
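A shape-level illustration of the swap performed here (plain tensors, not tracer internals): transpose(1, 2) exchanges exactly the two named dims, so only those two indice slots are swapped while dim 0 keeps its indice.

import torch

x = torch.randn(2, 3, 4)
y = x.transpose(1, 2)
assert y.shape == (2, 4, 3)  # dims 1 and 2 swapped, dim 0 unchanged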
|
||||
|
||||
def _assign_permute_indice(self, node, node_idx):
|
||||
"""
|
||||
Assign indice for permute op.
|
||||
1. swap input's dim according to permute args
|
||||
2. inherit input's computation
|
||||
|
||||
Args:
|
||||
node (node)
|
||||
node_idx (int)
|
||||
"""
|
||||
permute_dim = node.args[1:]
|
||||
input_node = node.args[0]
|
||||
|
||||
self._assign_indice_as_input(node, node_idx, input_node)
|
||||
for idx, d in enumerate(permute_dim):
|
||||
self._inherit_indice(input_node, d, node, idx)
|
||||
|
||||
def _assign_linear_indice(self, node, node_idx):
|
||||
"""
|
||||
Assign indice for linear op.
|
||||
1. copy trace from input node and change last indice according to weight
|
||||
2. mark equal for input node last indice, weight first dim and bias dim.
|
||||
3. inherit input's computation, mark computation for last dim.
|
||||
|
||||
Args:
|
||||
node (node)
|
||||
node_idx (int)
|
||||
"""
|
||||
if len(node.args) == 2:
|
||||
_, weight = node.args
|
||||
else:
|
||||
_, weight, _ = node.args
|
||||
|
||||
self._assign_indice_as_input(node, node_idx)
|
||||
self._inherit_indice(weight, 1, node, -1)
|
||||
|
||||
self._mark_computation(node, node_idx, [-1])
|
||||
|
||||
def _assign_matmul_indice(self, node, node_idx):
|
||||
"""
|
||||
Assign indice for matmul op.
|
||||
1. copy trace from matmul_left and change last indice according to matmul_right. (assert they have same length)
|
||||
2. mark equal for input matmul_left -1 indice and matmul_right -2 dim.
|
||||
3. inherit matmul_left and matmul_right computation, mark computation for last dim.
|
||||
|
||||
Args:
|
||||
node (node)
|
||||
node_idx (int)
|
||||
"""
|
||||
matmul_left, matmul_right = node.args
|
||||
|
||||
assert len(get_node_shape(matmul_left)) == len(get_node_shape(matmul_right))
|
||||
self._assign_indice_as_input(node, node_idx, matmul_left)
|
||||
self._inherit_indice(matmul_right, -1, node, -1)
|
||||
|
||||
self._mark_computation_from_node(matmul_right, node, [-1, -2])
|
||||
self._mark_computation(node, node_idx, [-1])
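A shape-level sketch of the rule above (plain tensors, not tracer internals): the output keeps the left operand's leading dims, takes its last dim from the right operand, and the contracted dim is the one marked as computed.

import torch

left = torch.randn(2, 8, 16)     # indices [a, b, k]
right = torch.randn(2, 16, 32)   # indices [a, k, c]
out = torch.matmul(left, right)  # indices [a, b, c]; k is the computed (contracted) dim
assert out.shape == (2, 8, 32)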
|
||||
|
||||
def _assign_layernorm_indice(self, node, idx):
|
||||
"""
|
||||
Assign indice for layernorm op.
|
||||
1. assign indice as input node
|
||||
2. inherit computation and mark last 2 dims as computed.
|
||||
|
||||
Args:
|
||||
node (node)
|
||||
node_idx (int)
|
||||
"""
|
||||
self._assign_indice_as_input(node, idx)
|
||||
self._mark_computation(node, idx, [-1])
|
||||
|
||||
def _assign_elementwise_indice(self, node, idx):
|
||||
"""
|
||||
Assign indice for element-wise ops (e.g. relu, sigmoid, add, mul).
|
||||
1. assign indice as input node
|
||||
2. inherit computation from all input nodes.
|
||||
|
||||
Args:
|
||||
node (node)
|
||||
node_idx (int)
|
||||
"""
|
||||
self._assign_indice_as_input(node, idx)
|
||||
nodes_in = []
|
||||
for node_in in node.args:
|
||||
if type(node_in) == type(node):
|
||||
nodes_in.append(node_in)
|
||||
self._mark_computation_from_node(node_in, node)
|
||||
assert len(nodes_in) <= 2
|
||||
|
||||
def _assign_no_change_indice(self, node, idx):
|
||||
self._assign_indice_as_input(node, idx)
|
||||
for node_in in node.args:
|
||||
if type(node_in) == type(node):
|
||||
self._mark_computation_from_node(node_in, node)
|
||||
|
||||
def _assign_einsum_indice(self, node, idx):
|
||||
"""
|
||||
Assign indice for einsum op.
|
||||
|
||||
Args:
|
||||
node (node)
|
||||
node_idx (int)
|
||||
"""
|
||||
patterns = node.args[0]
|
||||
input_nodes = node.args[1:]
|
||||
|
||||
patterns = patterns.replace(" ", "")
|
||||
left, right = patterns.split("->")
|
||||
left = left.split(",")
|
||||
|
||||
all_index = []
|
||||
for i in left:
|
||||
for c in i:
|
||||
all_index.append(c)
|
||||
all_index = set(all_index)
|
||||
|
||||
for right_idx, right_indice in enumerate(right):
|
||||
for left_idx, left_str in enumerate(left):
|
||||
if right_indice in left_str:
|
||||
source_idx = left_str.index(right_indice)
|
||||
self._inherit_indice(
|
||||
input_nodes[left_idx], source_idx, node, right_idx
|
||||
)
|
||||
|
||||
def _assign_softmax_indice(self, node, idx):
|
||||
"""
|
||||
Assign indice for softmax op.
|
||||
1. assign indice as input node
|
||||
2. inherit computation and mark softmax dim as computed.
|
||||
|
||||
Args:
|
||||
node (node)
|
||||
node_idx (int)
|
||||
"""
|
||||
self._assign_indice_as_input(node, idx)
|
||||
self._mark_computation(node, idx, [node.kwargs["dim"]])
|
||||
|
||||
def _assign_unsqueeze_indice(self, node, node_idx):
|
||||
"""
|
||||
Assign indice for unsqueeze op.
|
||||
1. assign new indice for unsqueeze dim
|
||||
|
||||
Args:
|
||||
node (node)
|
||||
node_idx (int)
|
||||
"""
|
||||
self._del_dim(node_idx, -1)
|
||||
self._assign_indice_as_input(node, node_idx)
|
||||
self._add_dim(node_idx, node.args[1])
|
||||
|
||||
def _assign_dropout_indice(self, node, node_idx):
|
||||
"""
|
||||
Assign indice for dropout op.
|
||||
1. assign indice as input node
|
||||
|
||||
Args:
|
||||
node (node)
|
||||
node_idx (int)
|
||||
"""
|
||||
self._assign_indice_as_input(node, node_idx)
|
||||
|
||||
def _assign_ones_like_indice(self, node, node_idx):
|
||||
"""
|
||||
Assign indice for oneslike op.
|
||||
1. assign new indice for all dim
|
||||
|
||||
Args:
|
||||
node (node)
|
||||
node_idx (int)
|
||||
"""
|
||||
self._assign_all_indice(node, node_idx)
|
||||
|
||||
def _assign_view_reshape_indice(self, node, node_idx):
|
||||
"""
|
||||
Assign indice for view and reshape op.
|
||||
1. get origin shape and target shape by meta info.
|
||||
2. compute the real value of -1 in target shape.
|
||||
3. determine the changed dim, and assign indice for the generated dim.
|
||||
4. log changed dim and generated dim for restore
|
||||
5. inherit computation.
|
||||
6. TODO: look into view list to see whether the view is associated with other,
|
||||
if so assign equal dim according to the previous view.
|
||||
|
||||
Args:
|
||||
node (node)
|
||||
node_idx (int)
|
||||
"""
|
||||
# get data, turn into number
|
||||
origin_node = node.args[0]
|
||||
origin_shape = origin_node.meta["tensor_meta"].shape
|
||||
target_shape = []
|
||||
for i in range(1, len(node.args)):
|
||||
if isinstance(node.args[i], int):
|
||||
target_shape.append(node.args[i])
|
||||
else:
|
||||
target_shape.append(node.args[i].meta["fwd_out"][0])
|
||||
|
||||
# compute the value of -1
|
||||
if -1 in target_shape:
|
||||
origin_product = 1
|
||||
for i in origin_shape:
|
||||
origin_product *= i
|
||||
target_product = -1
|
||||
for i in target_shape:
|
||||
target_product *= i
|
||||
shape_idx = target_shape.index(-1)
|
||||
target_shape[shape_idx] = origin_product // target_product
|
||||
|
||||
# determine changed dim
|
||||
len_diff = len(origin_shape) - len(target_shape)
|
||||
if len_diff == 1:
|
||||
# dim merge
|
||||
dim_equal = [i == j for i, j in zip(origin_shape[:-1], target_shape)]
|
||||
dim_to = [dim_equal.index(False)]
|
||||
dim_from = [dim_equal.index(False), dim_equal.index(False) + 1]
|
||||
self._add_dim(node_idx, -1)
|
||||
elif len_diff == -1:
|
||||
# dim expand
|
||||
dim_equal = [i == j for i, j in zip(origin_shape, target_shape[:-1])]
|
||||
dim_from = [dim_equal.index(False)]
|
||||
dim_to = [dim_equal.index(False), dim_equal.index(False) + 1]
|
||||
self._del_dim(node_idx, -1)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"shape"
|
||||
+ str(origin_shape)
|
||||
+ "and"
|
||||
+ str(target_shape)
|
||||
+ "view not implemented"
|
||||
)
|
||||
|
||||
# get new indice
|
||||
origin_trace = self._find_indice_trace_from_node(origin_node)
|
||||
self._assign_indice_as_input(node, node_idx, origin_node)
|
||||
dim_from.reverse()
|
||||
for i in dim_from:
|
||||
self._del_dim(node_idx, i)
|
||||
for i in dim_to:
|
||||
self._add_dim(node_idx, i)
|
||||
|
||||
# inherit computation
|
||||
compute_log = self._find_compute_trace_from_node(origin_node)
|
||||
for i in dim_from:
|
||||
if origin_trace[i] in compute_log:
|
||||
for j in dim_to:
|
||||
self._mark_computation(node, node_idx, [j])
|
||||
break
|
||||
|
||||
# log view, not used now
|
||||
view_dict = {
|
||||
"idx_from": [origin_trace[i] for i in dim_from],
|
||||
"dim_from": dim_from,
|
||||
"idx_to": [self.indice_trace_list[node_idx]["indice"][i] for i in dim_to],
|
||||
"dim_to": dim_to,
|
||||
}
|
||||
self.indice_view_list[node] = view_dict
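A hedged example of the two cases handled above, with arbitrary shapes: a single dim merge (len_diff == 1) and a single dim expand (len_diff == -1); the first mismatching position is what identifies dim_from and dim_to.

import torch

x = torch.randn(2, 3, 4, 5)
merged = x.reshape(2, 12, 5)         # dims 1 and 2 merge into one dim (len_diff == 1)
expanded = x.reshape(2, 3, 2, 2, 5)  # dim 2 expands into two dims (len_diff == -1)
assert merged.shape == (2, 12, 5)
assert expanded.shape == (2, 3, 2, 2, 5)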
|
||||
|
||||
def trace_indice(self):
|
||||
for idx, node in enumerate(self.node_list):
|
||||
if node.op == "placeholder":
|
||||
self._assign_all_indice(node, idx)
|
||||
elif node.op == "call_method":
|
||||
if "transpose" in node.name:
|
||||
self._assign_transpose_indice(node, idx)
|
||||
elif "permute" in node.name:
|
||||
self._assign_permute_indice(node, idx)
|
||||
elif "view" in node.name or "reshape" in node.name:
|
||||
self._assign_view_reshape_indice(node, idx)
|
||||
elif "unsqueeze" in node.name:
|
||||
self._assign_unsqueeze_indice(node, idx)
|
||||
elif any(i in node.name for i in ["to", "contiguous"]):
|
||||
self._assign_no_change_indice(node, idx)
|
||||
else:
|
||||
raise NotImplementedError(node.name, "method not implemented yet!")
|
||||
elif node.op == "call_function":
|
||||
if "linear" in node.name:
|
||||
self._assign_linear_indice(node, idx)
|
||||
elif "matmul" in node.name:
|
||||
self._assign_matmul_indice(node, idx)
|
||||
elif "softmax" in node.name:
|
||||
self._assign_softmax_indice(node, idx)
|
||||
elif any(n in node.name for n in ["mul", "add", "sigmoid", "relu"]):
|
||||
self._assign_elementwise_indice(node, idx)
|
||||
elif "ones_like" in node.name:
|
||||
self._assign_ones_like_indice(node, idx)
|
||||
elif "dropout" in node.name:
|
||||
self._assign_dropout_indice(node, idx)
|
||||
elif "einsum" in node.name:
|
||||
self._assign_einsum_indice(node, idx)
|
||||
elif "getattr" in node.name:
|
||||
continue # get attr like shape
|
||||
elif "getitem" in node.name:
|
||||
continue # get item in list
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
node.name, "function not implemented yet!"
|
||||
)
|
||||
elif node.op == "call_module":
|
||||
if any(n in node.name for n in ["layernorm", "norm"]):
|
||||
self._assign_layernorm_indice(node, idx)
|
||||
else:
|
||||
raise NotImplementedError(node.name, "module not implemented yet!")
|
||||
elif node.op == "get_attr":
|
||||
self._assign_all_indice(node, idx) # get param
|
||||
elif node.op == "output":
|
||||
continue
|
||||
else:
|
||||
raise NotImplementedError(node.op, "op not implemented yet!")
|
|
@ -0,0 +1,95 @@
|
|||
from typing import Any, Callable, Dict, Iterable, List, Tuple
|
||||
|
||||
from torch.fx.node import Node
|
||||
|
||||
|
||||
def is_non_compute_node(node):
|
||||
if any(i in node.op for i in ["placeholder", "get_attr", "output"]) or any(
|
||||
i in node.name for i in ["getitem", "getattr"]
|
||||
):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def get_node_shape(node):
|
||||
if hasattr(node.meta["tensor_meta"], "shape"):
|
||||
return node.meta["tensor_meta"].shape
|
||||
return None
|
||||
|
||||
|
||||
def is_non_compute_node_except_placeholder(node):
|
||||
if any(i in node.op for i in ["get_attr", "output"]) or any(
|
||||
i in node.name for i in ["getitem", "getattr"]
|
||||
):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def is_non_compute_node_except_placeholder_output(node):
|
||||
if any(i in node.op for i in ["get_attr"]) or any(
|
||||
i in node.name for i in ["getitem", "getattr"]
|
||||
):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def find_idx_by_name(name, nodes_list):
|
||||
for idx, node in enumerate(nodes_list):
|
||||
if node.name == name:
|
||||
return idx
|
||||
raise RuntimeError("name %s not found in node list" % name)
|
||||
|
||||
|
||||
def delete_free_var_from_last_use(user_to_last_uses):
|
||||
for key, value in user_to_last_uses.items():
|
||||
for n in value:
|
||||
if n.op == "placeholder":
|
||||
user_to_last_uses[key].remove(n)
|
||||
|
||||
|
||||
def find_chunk_all_input_nodes(nodes: List[Node]):
|
||||
"""
|
||||
Find all input nodes of the given node list.
Input nodes are nodes outside the list that are used by nodes in the list.
|
||||
"""
|
||||
input_nodes = []
|
||||
for node in nodes:
|
||||
for input_node in node._input_nodes.keys():
|
||||
if input_node not in nodes and input_node not in input_nodes:
|
||||
input_nodes.append(input_node)
|
||||
return input_nodes
|
||||
|
||||
|
||||
def find_chunk_compute_input_and_output_nodes(nodes: List[Node]):
|
||||
"""
|
||||
Find the compute input and output nodes of the given node list.
Input nodes are nodes outside the list that are used by nodes in the list.
Output nodes are nodes in the list whose results are used by nodes outside the list.
|
||||
"""
|
||||
input_nodes = []
|
||||
output_nodes = []
|
||||
|
||||
# if a node has an input node which is not in the node list
|
||||
# we treat that input node as the input of the checkpoint function
|
||||
for node in nodes:
|
||||
for input_node in node._input_nodes.keys():
|
||||
if (
|
||||
input_node not in nodes
|
||||
and input_node not in input_nodes
|
||||
and not is_non_compute_node_except_placeholder(input_node)
|
||||
):
|
||||
input_nodes.append(input_node)
|
||||
|
||||
# if a node has a user node which is not in the node list
|
||||
# we treat that user node as the node receiving the current node output
|
||||
for node in nodes:
|
||||
for output_node in node.users.keys():
|
||||
if (
|
||||
output_node not in nodes
|
||||
and node not in output_nodes
|
||||
and not is_non_compute_node_except_placeholder_output(output_node)
|
||||
):
|
||||
output_nodes.append(node)
|
||||
|
||||
return input_nodes, output_nodes
|
|
@ -0,0 +1,122 @@
|
|||
import time
|
||||
|
||||
import torch
|
||||
import torch.fx
|
||||
|
||||
from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
|
||||
from colossalai.fx import ColoTracer
|
||||
from colossalai.fx.graph_module import ColoGraphModule
|
||||
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
|
||||
from colossalai.fx.profiler import MetaTensor
|
||||
from tests.test_autochunk.evoformer.evoformer import evoformer_base
|
||||
from tests.test_autochunk.openfold.evoformer import EvoformerBlock
|
||||
|
||||
|
||||
def _benchmark_evoformer(model: torch.nn.Module, node, pair, title, chunk_size=None):
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
now_mem = torch.cuda.memory_allocated() / 1024**2
|
||||
|
||||
loop = 3
|
||||
with torch.no_grad():
|
||||
for _ in range(loop // 2 + 1):
|
||||
if chunk_size:
|
||||
model(node, pair, chunk_size)
|
||||
else:
|
||||
model(node, pair)
|
||||
torch.cuda.synchronize()
|
||||
time1 = time.time()
|
||||
for _ in range(loop):
|
||||
if chunk_size:
|
||||
model(node, pair, chunk_size)
|
||||
else:
|
||||
model(node, pair)
|
||||
torch.cuda.synchronize()
|
||||
time2 = time.time()
|
||||
|
||||
new_max_mem = torch.cuda.max_memory_allocated() / 1024**2
|
||||
print(
|
||||
"%s: time %.4fs, mem %dMB"
|
||||
% (title, (time2 - time1) / loop, new_max_mem - now_mem)
|
||||
)
|
||||
|
||||
|
||||
def _build_autochunk(model, max_memory, node, pair):
|
||||
# trace the module and replace codegen
|
||||
graph = ColoTracer().trace(
|
||||
model,
|
||||
meta_args={
|
||||
"node": node.to(torch.device("meta")),
|
||||
"pair": pair.to(torch.device("meta")),
|
||||
},
|
||||
)
|
||||
|
||||
gm_prop = torch.fx.symbolic_trace(model) # must use symbolic_trace
|
||||
interp = MetaInfoProp(gm_prop)
|
||||
interp.propagate(
|
||||
MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0")
|
||||
)
|
||||
|
||||
# run propagation again to fill meta info into the graph module; not strictly necessary
|
||||
gm = torch.fx.GraphModule(model, graph)
|
||||
interp = MetaInfoProp(gm)
|
||||
interp.propagate(
|
||||
MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0")
|
||||
)
|
||||
|
||||
# set code_gen
|
||||
codegen = AutoChunkCodeGen(gm_prop, max_memory, print_mem=False)
|
||||
graph.set_codegen(codegen)
|
||||
gm = ColoGraphModule(model, graph)
|
||||
gm.recompile()
|
||||
|
||||
# print
|
||||
# code = graph.python_code("self").src
|
||||
# print(code)
|
||||
return gm
|
||||
|
||||
|
||||
def _build_openfold():
|
||||
model = EvoformerBlock(
|
||||
c_m=256,
|
||||
c_z=128,
|
||||
c_hidden_msa_att=32,
|
||||
c_hidden_opm=32,
|
||||
c_hidden_mul=128,
|
||||
c_hidden_pair_att=32,
|
||||
no_heads_msa=8,
|
||||
no_heads_pair=4,
|
||||
transition_n=4,
|
||||
msa_dropout=0.15,
|
||||
pair_dropout=0.15,
|
||||
inf=1e4,
|
||||
eps=1e-4,
|
||||
is_multimer=False,
|
||||
).cuda()
|
||||
return model
|
||||
|
||||
|
||||
def benchmark_evoformer():
|
||||
# init data and model
|
||||
msa_len = 256
|
||||
pair_len = 512
|
||||
node = torch.randn(1, msa_len, pair_len, 256).cuda()
|
||||
pair = torch.randn(1, pair_len, pair_len, 128).cuda()
|
||||
model = evoformer_base().cuda()
|
||||
|
||||
# build autochunk model
|
||||
# max_memory = 1000 # MB, fit memory mode
|
||||
max_memory = None # min memory mode
|
||||
autochunk = _build_autochunk(evoformer_base().cuda(), max_memory, node, pair)
|
||||
|
||||
# build openfold
|
||||
chunk_size = 64
|
||||
openfold = _build_openfold()
|
||||
|
||||
# benchmark
|
||||
_benchmark_evoformer(model, node, pair, "base")
|
||||
_benchmark_evoformer(openfold, node, pair, "openfold", chunk_size=chunk_size)
|
||||
_benchmark_evoformer(autochunk, node, pair, "autochunk")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
benchmark_evoformer()
|
|
@ -0,0 +1,59 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from .msa import MSAStack
|
||||
from .ops import OutProductMean
|
||||
from .triangle import PairStack
|
||||
|
||||
|
||||
def print_memory(init_mem, text=None):
|
||||
now_mem = torch.cuda.memory_allocated() / 1024 ** 2 - init_mem
|
||||
max_mem = torch.cuda.max_memory_allocated() / 1024 ** 2 - init_mem
|
||||
print("%s now:%.2f max:%.2f" % ("" if text is None else text, now_mem, max_mem))
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
|
||||
|
||||
class EvoformerBlock(nn.Module):
|
||||
|
||||
def __init__(self, d_node, d_pair):
|
||||
super(EvoformerBlock, self).__init__()
|
||||
|
||||
self.msa_stack = MSAStack(d_node, d_pair, p_drop=0.15)
|
||||
self.communication = OutProductMean(n_feat=d_node, n_feat_out=d_pair, n_feat_proj=32)
|
||||
self.pair_stack = PairStack(d_pair=d_pair)
|
||||
|
||||
def forward(self, node, pair):
|
||||
node = self.msa_stack(node, pair)
|
||||
pair = pair + self.communication(node)
|
||||
pair = self.pair_stack(pair)
|
||||
return node, pair
|
||||
|
||||
|
||||
class Evoformer(nn.Module):
|
||||
|
||||
def __init__(self, d_node, d_pair):
|
||||
super(Evoformer, self).__init__()
|
||||
|
||||
self.blocks = nn.ModuleList()
|
||||
for _ in range(1):
|
||||
self.blocks.append(EvoformerBlock(d_node, d_pair))
|
||||
|
||||
def forward(self, node, pair):
|
||||
for b in self.blocks:
|
||||
node, pair = b(node, pair)
|
||||
return node, pair
|
||||
|
||||
|
||||
def evoformer_tiny():
|
||||
return Evoformer(d_node=64, d_pair=32)
|
||||
|
||||
|
||||
def evoformer_base():
|
||||
return Evoformer(d_node=256, d_pair=128)
|
||||
|
||||
|
||||
def evoformer_large():
|
||||
return Evoformer(d_node=512, d_pair=256)
|
||||
|
||||
|
||||
__all__ = ['Evoformer', 'evoformer_base', 'evoformer_large']
|
|
@ -0,0 +1,29 @@
|
|||
import math
|
||||
|
||||
import numpy as np
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def glorot_uniform_af(x, gain=1.0):
|
||||
"""
|
||||
initialize tensors the same as xavier_initializer in PyTorch, but the dimensions are different:
|
||||
In PyTorch:
|
||||
[feature_out, feature_in, n_head ...]
|
||||
In Jax:
|
||||
[... n_head, feature_in, feature_out]
|
||||
However, the original AlphaFold2 code uses the Jax-style initializer to initialize tensors shaped like:
|
||||
[feature_in, n_head, feature_out]
|
||||
|
||||
In this function, we keep this feature to initialize [feature_in, n_head, ..., feature_out] tensors
|
||||
"""
|
||||
fan_in, fan_out = x.shape[-2:]
|
||||
if len(x.shape) > 2:
|
||||
receptive_field_size = np.prod(x.shape[:-2])
|
||||
fan_in *= receptive_field_size
|
||||
fan_out *= receptive_field_size
|
||||
std = gain * math.sqrt(2.0 / float(fan_in + fan_out))
|
||||
dev = math.sqrt(3.0) * std # Calculate uniform bounds from standard deviation
|
||||
|
||||
nn.init.uniform_(x, -dev, dev)
|
||||
|
||||
return x
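A brief usage sketch under assumed shapes: initializing a [feature_in, n_head, feature_out] weight with the function above, every value stays within the uniform bound sqrt(3) * gain * sqrt(2 / (fan_in + fan_out)), where both fans are scaled by the leading feature_in dim.

import torch

w = torch.empty(64, 8, 32)      # hypothetical [feature_in, n_head, feature_out] weight
glorot_uniform_af(w, gain=1.0)  # in-place init using the function above
bound = (3.0 * 2.0 / float(8 * 64 + 32 * 64)) ** 0.5
assert w.abs().max() <= bound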
|
|
@ -0,0 +1,19 @@
|
|||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def bias_sigmod_ele(y, bias, z):
|
||||
return torch.sigmoid(y + bias) * z
|
||||
|
||||
|
||||
def bias_dropout_add(x: torch.Tensor, bias: torch.Tensor, dropmask: torch.Tensor,
|
||||
residual: torch.Tensor, prob: float) -> torch.Tensor:
|
||||
out = (x + bias) * F.dropout(dropmask, p=prob, training=False)
|
||||
out = residual + out
|
||||
return out
|
||||
|
||||
|
||||
def bias_ele_dropout_residual(ab: torch.Tensor, b: torch.Tensor, g: torch.Tensor,
|
||||
dropout_mask: torch.Tensor, Z_raw: torch.Tensor,
|
||||
prob: float) -> torch.Tensor:
|
||||
return Z_raw + F.dropout(dropout_mask, p=prob, training=True) * (g * (ab + b))
|
|
@ -0,0 +1,95 @@
|
|||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
from torch.nn import LayerNorm
|
||||
|
||||
from .kernel import bias_dropout_add
|
||||
from .ops import SelfAttention, Transition
|
||||
|
||||
|
||||
class MSARowAttentionWithPairBias(nn.Module):
|
||||
|
||||
def __init__(self, d_node, d_pair, c=32, n_head=8, p_drop=0.15):
|
||||
super(MSARowAttentionWithPairBias, self).__init__()
|
||||
self.d_node = d_node
|
||||
self.d_pair = d_pair
|
||||
self.c = c
|
||||
self.n_head = n_head
|
||||
self.p_drop = p_drop
|
||||
|
||||
self.layernormM = LayerNorm(d_node)
|
||||
self.layernormZ = LayerNorm(d_pair)
|
||||
|
||||
_init_weights = torch.nn.init.normal_(torch.zeros([n_head, d_pair]),
|
||||
std=1.0 / math.sqrt(d_pair))
|
||||
self.linear_b_weights = nn.parameter.Parameter(data=_init_weights, requires_grad=True)
|
||||
|
||||
self.attention = SelfAttention(qkv_dim=d_node,
|
||||
c=c,
|
||||
n_head=n_head,
|
||||
out_dim=d_node,
|
||||
gating=True,
|
||||
last_bias_fuse=True)
|
||||
|
||||
self.out_bias = nn.parameter.Parameter(data=torch.zeros((d_node,)), requires_grad=True)
|
||||
|
||||
def forward(self, M_raw, Z):
|
||||
## Input projections
|
||||
M = self.layernormM(M_raw)
|
||||
Z = self.layernormZ(Z)
|
||||
b = F.linear(Z, self.linear_b_weights)
|
||||
b = b.permute(0, 3, 1, 2)
|
||||
# b = rearrange(b, 'b q k h -> b h q k')
|
||||
|
||||
M = self.attention(M, b)
|
||||
dropout_mask = torch.ones_like(M[:, 0:1, :, :]).to(M.device).to(M.dtype)
|
||||
|
||||
return bias_dropout_add(M, self.out_bias, dropout_mask, M_raw, prob=self.p_drop)
|
||||
|
||||
|
||||
class MSAColumnAttention(nn.Module):
|
||||
|
||||
def __init__(self, d_node, c=32, n_head=8):
|
||||
super(MSAColumnAttention, self).__init__()
|
||||
self.d_node = d_node
|
||||
self.c = c
|
||||
self.n_head = n_head
|
||||
|
||||
self.layernormM = LayerNorm(d_node)
|
||||
self.attention = SelfAttention(qkv_dim=d_node,
|
||||
c=c,
|
||||
n_head=n_head,
|
||||
out_dim=d_node,
|
||||
gating=True)
|
||||
|
||||
def forward(self, M_raw):
|
||||
M = M_raw.transpose(-2, -3)
|
||||
M = self.layernormM(M)
|
||||
|
||||
M = self.attention(M)
|
||||
|
||||
M = M.transpose(-2, -3)
|
||||
return M_raw + M
|
||||
|
||||
|
||||
class MSAStack(nn.Module):
|
||||
|
||||
def __init__(self, d_node, d_pair, p_drop=0.15):
|
||||
super(MSAStack, self).__init__()
|
||||
|
||||
self.MSARowAttentionWithPairBias = MSARowAttentionWithPairBias(d_node=d_node,
|
||||
d_pair=d_pair,
|
||||
p_drop=p_drop)
|
||||
|
||||
self.MSAColumnAttention = MSAColumnAttention(d_node=d_node)
|
||||
self.MSATransition = Transition(d=d_node)
|
||||
|
||||
def forward(self, node, pair):
|
||||
node = self.MSARowAttentionWithPairBias(node, pair)
|
||||
node = self.MSAColumnAttention(node)
|
||||
node = self.MSATransition(node)
|
||||
|
||||
return node
|
|
@ -0,0 +1,176 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
from torch.nn import LayerNorm
|
||||
|
||||
from .initializer import glorot_uniform_af
|
||||
from .kernel import bias_sigmod_ele
|
||||
|
||||
|
||||
class DropoutRowwise(nn.Module):
|
||||
|
||||
def __init__(self, p):
|
||||
super(DropoutRowwise, self).__init__()
|
||||
self.p = p
|
||||
self.dropout = nn.Dropout(p=p)
|
||||
|
||||
def forward(self, x):
|
||||
dropout_mask = torch.ones_like(x[:, 0:1, :, :])
|
||||
dropout_mask = self.dropout(dropout_mask)
|
||||
return dropout_mask * x
|
||||
|
||||
|
||||
class DropoutColumnwise(nn.Module):
|
||||
|
||||
def __init__(self, p):
|
||||
super(DropoutColumnwise, self).__init__()
|
||||
self.p = p
|
||||
self.dropout = nn.Dropout(p=p)
|
||||
|
||||
def forward(self, x):
|
||||
dropout_mask = torch.ones_like(x[:, :, 0:1, :])
|
||||
dropout_mask = self.dropout(dropout_mask)
|
||||
return dropout_mask * x
|
||||
|
||||
|
||||
class Transition(nn.Module):
|
||||
|
||||
def __init__(self, d, n=4):
|
||||
super(Transition, self).__init__()
|
||||
self.norm = LayerNorm(d)
|
||||
self.linear1 = Linear(d, n * d, initializer='relu')
|
||||
self.linear2 = Linear(n * d, d, initializer='zeros')
|
||||
|
||||
def forward(self, src):
|
||||
x = self.norm(src)
|
||||
x = self.linear2(F.relu(self.linear1(x)))
|
||||
return src + x
|
||||
|
||||
|
||||
class OutProductMean(nn.Module):
|
||||
|
||||
def __init__(self, n_feat=64, n_feat_out=128, n_feat_proj=32):
|
||||
super(OutProductMean, self).__init__()
|
||||
|
||||
self.layernormM = LayerNorm(n_feat)
|
||||
self.linear_a = Linear(n_feat, n_feat_proj)
|
||||
self.linear_b = Linear(n_feat, n_feat_proj)
|
||||
|
||||
self.o_linear = Linear(n_feat_proj * n_feat_proj,
|
||||
n_feat_out,
|
||||
initializer='zero',
|
||||
use_bias=True)
|
||||
|
||||
def forward(self, M):
|
||||
M = self.layernormM(M)
|
||||
left_act = self.linear_a(M)
|
||||
right_act = self.linear_b(M)
|
||||
|
||||
o = torch.einsum('bsid,bsje->bijde', left_act, right_act).contiguous()
|
||||
# O = rearrange(O, 'b i j d e -> b i j (d e)')
|
||||
o = o.reshape(o.shape[0], o.shape[1], o.shape[2], -1)
|
||||
Z = self.o_linear(o)
|
||||
|
||||
return Z
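A hedged shape trace of the einsum above (batch b, MSA depth s, residues i and j, projection dims d and e, all shapes made up here): s is contracted away and the two projection dims are flattened before the output projection.

import torch

b, s, n_res, d = 1, 4, 6, 32
left_act = torch.randn(b, s, n_res, d)
right_act = torch.randn(b, s, n_res, d)
o = torch.einsum('bsid,bsje->bijde', left_act, right_act)
assert o.shape == (b, n_res, n_res, d, d)
o = o.reshape(o.shape[0], o.shape[1], o.shape[2], -1)  # flatten d and e
assert o.shape == (b, n_res, n_res, d * d)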
|
||||
|
||||
|
||||
class Linear(nn.Linear):
|
||||
"""
|
||||
A Linear layer with built-in nonstandard initializations. Called just
|
||||
like torch.nn.Linear.
|
||||
Implements the initializers in Section 1.11.4 of the AlphaFold2 supplement, plus some additional ones found
|
||||
in the code.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
feature_in: int,
|
||||
feature_out: int,
|
||||
initializer: str = 'linear',
|
||||
use_bias: bool = True,
|
||||
bias_init: float = 0.,
|
||||
):
|
||||
super(Linear, self).__init__(feature_in, feature_out, bias=use_bias)
|
||||
|
||||
self.use_bias = use_bias
|
||||
if initializer == 'linear':
|
||||
glorot_uniform_af(self.weight, gain=1.0)
|
||||
elif initializer == 'relu':
|
||||
glorot_uniform_af(self.weight, gain=2.0)
|
||||
elif initializer == 'zeros':
|
||||
nn.init.zeros_(self.weight)
|
||||
if self.use_bias:
|
||||
with torch.no_grad():
|
||||
self.bias.fill_(bias_init)
|
||||
|
||||
|
||||
class SelfAttention(nn.Module):
|
||||
"""
|
||||
Multi-Head SelfAttention dealing with [batch_size1, batch_size2, len, dim] tensors
|
||||
"""
|
||||
|
||||
def __init__(self, qkv_dim, c, n_head, out_dim, gating=True, last_bias_fuse=False):
|
||||
super(SelfAttention, self).__init__()
|
||||
self.qkv_dim = qkv_dim
|
||||
self.c = c
|
||||
self.n_head = n_head
|
||||
self.out_dim = out_dim
|
||||
self.gating = gating
|
||||
self.last_bias_fuse = last_bias_fuse
|
||||
|
||||
self.scaling = self.c**(-0.5)
|
||||
|
||||
# self.to_qkv = Linear(qkv_dim, 3 * n_head * c, initializer='linear')
|
||||
self.to_q = Linear(qkv_dim, n_head * c, initializer='linear', use_bias=False)
|
||||
self.to_k = Linear(qkv_dim, n_head * c, initializer='linear', use_bias=False)
|
||||
self.to_v = Linear(qkv_dim, n_head * c, initializer='linear', use_bias=False)
|
||||
|
||||
if gating:
|
||||
self.gating_bias = nn.parameter.Parameter(data=torch.ones((n_head * c,)))
|
||||
self.gating_linear = Linear(qkv_dim, n_head * c, initializer='zero', use_bias=False)
|
||||
|
||||
self.o_linear = Linear(n_head * c,
|
||||
out_dim,
|
||||
initializer='zero',
|
||||
use_bias=(not last_bias_fuse))
|
||||
|
||||
def forward(self, in_data, nonbatched_bias=None):
|
||||
"""
|
||||
:param in_data: [batch_size1, batch_size2, len_qkv, qkv_dim]
|
||||
:param bias: None or [batch_size1, batch_size2, n_head, len_q, len_kv]
|
||||
:param nonbatched_bias: None or [batch_size1, n_head, len_q, len_kv]
|
||||
"""
|
||||
|
||||
# qkv = self.to_qkv(in_data).chunk(3, dim=-1)
|
||||
# q, k, v = map(lambda t: rearrange(t, 'b1 b2 n (h d) -> b1 b2 h n d', h=self.n_head), qkv)
|
||||
|
||||
q = self.to_q(in_data)
|
||||
k = self.to_k(in_data)
|
||||
v = self.to_v(in_data)
|
||||
|
||||
# q, k, v = map(lambda t: rearrange(t, 'b1 b2 n (h d) -> b1 b2 h n d', h=self.n_head),
|
||||
# [q, k, v])
|
||||
q, k, v = map(lambda t: t.view(t.shape[0], t.shape[1], t.shape[2], self.n_head, -1).permute(0, 1, 3, 2, 4),
|
||||
[q, k, v])
|
||||
|
||||
q = q * self.scaling
|
||||
|
||||
logits = torch.matmul(q, k.transpose(-1, -2))
|
||||
|
||||
if nonbatched_bias is not None:
|
||||
logits += nonbatched_bias.unsqueeze(1)
|
||||
weights = torch.softmax(logits, dim=-1)
|
||||
# weights = softmax(logits)
|
||||
|
||||
weighted_avg = torch.matmul(weights, v)
|
||||
# weighted_avg = rearrange(weighted_avg, 'b1 b2 h n d -> b1 b2 n (h d)')
|
||||
weighted_avg = weighted_avg.permute(0, 1, 3, 2, 4)
|
||||
weighted_avg = weighted_avg.reshape(weighted_avg.shape[0], weighted_avg.shape[1], weighted_avg.shape[2], -1)
|
||||
|
||||
if self.gating:
|
||||
gate_values = self.gating_linear(in_data)
|
||||
weighted_avg = bias_sigmod_ele(gate_values, self.gating_bias, weighted_avg)
|
||||
|
||||
output = self.o_linear(weighted_avg)
|
||||
return output
|
|
@ -0,0 +1,192 @@
|
|||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn import LayerNorm
|
||||
|
||||
from .kernel import bias_dropout_add, bias_ele_dropout_residual
|
||||
from .ops import Linear, SelfAttention, Transition
|
||||
|
||||
|
||||
def permute_final_dims(tensor, inds):
|
||||
zero_index = -1 * len(inds)
|
||||
first_inds = list(range(len(tensor.shape[:zero_index])))
|
||||
return tensor.permute(first_inds + [zero_index + i for i in inds])
|
||||
|
||||
|
||||
class TriangleMultiplicationOutgoing(nn.Module):
|
||||
|
||||
def __init__(self, d_pair, p_drop, c=128):
|
||||
super(TriangleMultiplicationOutgoing, self).__init__()
|
||||
self.d_pair = d_pair
|
||||
self.c = c
|
||||
|
||||
self.layernorm1 = LayerNorm(d_pair)
|
||||
self.left_projection = Linear(d_pair, c)
|
||||
self.right_projection = Linear(d_pair, c)
|
||||
self.left_gate = Linear(d_pair, c, initializer='zeros', bias_init=1.)
|
||||
self.right_gate = Linear(d_pair, c, initializer='zeros', bias_init=1.)
|
||||
|
||||
self.output_gate = Linear(d_pair, d_pair, initializer='zeros', bias_init=1.)
|
||||
self.layernorm2 = LayerNorm(c)
|
||||
self.output_projection = Linear(d_pair, d_pair, initializer='zeros', use_bias=False)
|
||||
self.output_bias = nn.parameter.Parameter(data=torch.zeros((d_pair,)), requires_grad=True)
|
||||
|
||||
self.p_drop = p_drop
|
||||
|
||||
def forward(self, Z_raw):
|
||||
Z = self.layernorm1(Z_raw)
|
||||
left_proj_act = self.left_projection(Z)
|
||||
right_proj_act = self.right_projection(Z)
|
||||
|
||||
left_proj_act = left_proj_act * torch.sigmoid(self.left_gate(Z))
|
||||
right_proj_act = right_proj_act * torch.sigmoid(self.right_gate(Z))
|
||||
|
||||
g = torch.sigmoid(self.output_gate(Z))
|
||||
# p = torch.matmul(
|
||||
# permute_final_dims(left_proj_act, (2, 0, 1)),
|
||||
# permute_final_dims(right_proj_act, (2, 1, 0)),
|
||||
# )
|
||||
# ab = permute_final_dims(p, (1, 2, 0))
|
||||
|
||||
ab = torch.einsum('bikd,bjkd->bijd', left_proj_act, right_proj_act)
|
||||
ab = self.output_projection(self.layernorm2(ab))
|
||||
dropout_mask = torch.ones_like(Z[:, 0:1, :, :]).to(Z.device).to(Z.dtype)
|
||||
return bias_ele_dropout_residual(ab,
|
||||
self.output_bias,
|
||||
g,
|
||||
dropout_mask,
|
||||
Z_raw,
|
||||
prob=self.p_drop)
|
||||
|
||||
|
||||
class TriangleMultiplicationIncoming(nn.Module):
|
||||
|
||||
def __init__(self, d_pair, p_drop, c=128):
|
||||
super(TriangleMultiplicationIncoming, self).__init__()
|
||||
self.d_pair = d_pair
|
||||
self.c = c
|
||||
|
||||
self.layernorm1 = LayerNorm(d_pair)
|
||||
self.left_projection = Linear(d_pair, c)
|
||||
self.right_projection = Linear(d_pair, c)
|
||||
self.left_gate = Linear(d_pair, c, initializer='zeros', bias_init=1.)
|
||||
self.right_gate = Linear(d_pair, c, initializer='zeros', bias_init=1.)
|
||||
|
||||
self.output_gate = Linear(d_pair, d_pair, initializer='zeros', bias_init=1.)
|
||||
self.layernorm2 = LayerNorm(c)
|
||||
self.output_projection = Linear(d_pair, d_pair, initializer='zeros', use_bias=False)
|
||||
self.output_bias = nn.parameter.Parameter(data=torch.zeros((d_pair,)), requires_grad=True)
|
||||
|
||||
self.p_drop = p_drop
|
||||
|
||||
def forward(self, Z_raw):
|
||||
Z = self.layernorm1(Z_raw)
|
||||
left_proj_act = self.left_projection(Z)
|
||||
right_proj_act = self.right_projection(Z)
|
||||
|
||||
left_proj_act = left_proj_act * torch.sigmoid(self.left_gate(Z))
|
||||
right_proj_act = right_proj_act * torch.sigmoid(self.right_gate(Z))
|
||||
|
||||
g = torch.sigmoid(self.output_gate(Z))
|
||||
# p = torch.matmul(
|
||||
# permute_final_dims(left_proj_act, (2, 1, 0)),
|
||||
# permute_final_dims(right_proj_act, (2, 0, 1)),
|
||||
# )
|
||||
# ab = permute_final_dims(p, (1, 2, 0))
|
||||
|
||||
ab = torch.einsum('bkid,bkjd->bijd', left_proj_act, right_proj_act)
|
||||
ab = self.output_projection(self.layernorm2(ab))
|
||||
dropout_mask = torch.ones_like(Z[:, 0:1, :, :]).to(Z.device).to(Z.dtype)
|
||||
return bias_ele_dropout_residual(ab,
|
||||
self.output_bias,
|
||||
g,
|
||||
dropout_mask,
|
||||
Z_raw,
|
||||
prob=self.p_drop)
|
||||
|
||||
|
||||
class TriangleAttentionStartingNode(nn.Module):
|
||||
|
||||
def __init__(self, d_pair, p_drop, c=32, n_head=4):
|
||||
super(TriangleAttentionStartingNode, self).__init__()
|
||||
self.d_pair = d_pair
|
||||
self.c = c
|
||||
self.n_head = n_head
|
||||
self.p_drop = p_drop
|
||||
|
||||
self.layernorm1 = LayerNorm(d_pair)
|
||||
_init_weights = torch.nn.init.normal_(torch.zeros([d_pair, n_head]),
|
||||
std=1.0 / math.sqrt(d_pair))
|
||||
self.linear_b_weights = nn.parameter.Parameter(data=_init_weights)
|
||||
self.attention = SelfAttention(qkv_dim=d_pair,
|
||||
c=c,
|
||||
n_head=n_head,
|
||||
out_dim=d_pair,
|
||||
gating=True,
|
||||
last_bias_fuse=True)
|
||||
|
||||
self.out_bias = nn.parameter.Parameter(data=torch.zeros((d_pair,)), requires_grad=True)
|
||||
|
||||
def forward(self, Z_raw):
|
||||
Z = self.layernorm1(Z_raw)
|
||||
b = torch.einsum('bqkc,ch->bhqk', Z, self.linear_b_weights)
|
||||
|
||||
Z = self.attention(Z, b)
|
||||
|
||||
dropout_mask = torch.ones_like(Z[:, 0:1, :, :]).to(Z.device).to(Z.dtype)
|
||||
return bias_dropout_add(Z, self.out_bias, dropout_mask, Z_raw, prob=self.p_drop)
|
||||
|
||||
|
||||
class TriangleAttentionEndingNode(nn.Module):
|
||||
|
||||
def __init__(self, d_pair, p_drop, c=32, n_head=4):
|
||||
super(TriangleAttentionEndingNode, self).__init__()
|
||||
self.d_pair = d_pair
|
||||
self.c = c
|
||||
self.n_head = n_head
|
||||
self.p_drop = p_drop
|
||||
|
||||
self.layernorm1 = LayerNorm(d_pair)
|
||||
_init_weights = torch.nn.init.normal_(torch.zeros([d_pair, n_head]),
|
||||
std=1.0 / math.sqrt(d_pair))
|
||||
self.linear_b_weights = nn.parameter.Parameter(data=_init_weights)
|
||||
self.attention = SelfAttention(qkv_dim=d_pair,
|
||||
c=c,
|
||||
n_head=n_head,
|
||||
out_dim=d_pair,
|
||||
gating=True,
|
||||
last_bias_fuse=True)
|
||||
|
||||
self.out_bias = nn.parameter.Parameter(data=torch.zeros((d_pair,)), requires_grad=True)
|
||||
|
||||
def forward(self, Z_raw):
|
||||
Z = Z_raw.transpose(-2, -3)
|
||||
Z = self.layernorm1(Z)
|
||||
b = torch.einsum('bqkc,ch->bhqk', Z, self.linear_b_weights)
|
||||
|
||||
Z = self.attention(Z, b)
|
||||
|
||||
Z = Z.transpose(-2, -3)
|
||||
dropout_mask = torch.ones_like(Z[:, :, 0:1, :]).to(Z.device).to(Z.dtype)
|
||||
return bias_dropout_add(Z, self.out_bias, dropout_mask, Z_raw, prob=self.p_drop)
|
||||
|
||||
|
||||
class PairStack(nn.Module):
|
||||
|
||||
def __init__(self, d_pair, p_drop=0.25):
|
||||
super(PairStack, self).__init__()
|
||||
|
||||
self.TriangleMultiplicationOutgoing = TriangleMultiplicationOutgoing(d_pair, p_drop=p_drop)
|
||||
self.TriangleMultiplicationIncoming = TriangleMultiplicationIncoming(d_pair, p_drop=p_drop)
|
||||
self.TriangleAttentionStartingNode = TriangleAttentionStartingNode(d_pair, p_drop=p_drop)
|
||||
self.TriangleAttentionEndingNode = TriangleAttentionEndingNode(d_pair, p_drop=p_drop)
|
||||
self.PairTransition = Transition(d=d_pair)
|
||||
|
||||
def forward(self, pair):
|
||||
pair = self.TriangleMultiplicationOutgoing(pair)
|
||||
pair = self.TriangleMultiplicationIncoming(pair)
|
||||
pair = self.TriangleAttentionStartingNode(pair)
|
||||
pair = self.TriangleAttentionEndingNode(pair)
|
||||
pair = self.PairTransition(pair)
|
||||
return pair
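
# Illustrative usage sketch (not part of the original module); assumes the
# sibling .ops building blocks imported above are available and that pair
# activations use the [batch, N_res, N_res, d_pair] layout seen in the blocks.
def _demo_pair_stack():
    d_pair = 128
    stack = PairStack(d_pair=d_pair, p_drop=0.25).eval()
    pair = torch.randn(1, 32, 32, d_pair)
    with torch.no_grad():
        out = stack(pair)
    assert out.shape == pair.shape
    return out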
|
|
@ -0,0 +1,84 @@
|
|||
# Copyright 2021 AlQuraishi Laboratory
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from typing import Any, Tuple, List, Callable, Optional
|
||||
|
||||
|
||||
BLOCK_ARG = Any
|
||||
BLOCK_ARGS = List[BLOCK_ARG]
|
||||
|
||||
|
||||
def get_checkpoint_fn():
|
||||
checkpoint = torch.utils.checkpoint.checkpoint
|
||||
|
||||
return checkpoint
|
||||
|
||||
|
||||
@torch.jit.ignore
|
||||
def checkpoint_blocks(
|
||||
blocks: List[Callable],
|
||||
args: BLOCK_ARGS,
|
||||
blocks_per_ckpt: Optional[int],
|
||||
) -> BLOCK_ARGS:
|
||||
"""
|
||||
Chunk a list of blocks and run each chunk with activation
|
||||
checkpointing. We define a "block" as a callable whose only inputs are
|
||||
the outputs of the previous block.
|
||||
|
||||
Implements Subsection 1.11.8
|
||||
|
||||
Args:
|
||||
blocks:
|
||||
List of blocks
|
||||
args:
|
||||
Tuple of arguments for the first block.
|
||||
blocks_per_ckpt:
|
||||
Size of each chunk. A higher value corresponds to fewer
|
||||
checkpoints, and trades memory for speed. If None, no checkpointing
|
||||
is performed.
|
||||
Returns:
|
||||
The output of the final block
|
||||
"""
|
||||
def wrap(a):
|
||||
return (a,) if type(a) is not tuple else a
|
||||
|
||||
def exec(b, a):
|
||||
for block in b:
|
||||
a = wrap(block(*a))
|
||||
return a
|
||||
|
||||
def chunker(s, e):
|
||||
def exec_sliced(*a):
|
||||
return exec(blocks[s:e], a)
|
||||
|
||||
return exec_sliced
|
||||
|
||||
# Avoids mishaps when the blocks take just one argument
|
||||
args = wrap(args)
|
||||
|
||||
if blocks_per_ckpt is None:
|
||||
return exec(blocks, args)
|
||||
elif blocks_per_ckpt < 1 or blocks_per_ckpt > len(blocks):
|
||||
raise ValueError("blocks_per_ckpt must be between 1 and len(blocks)")
|
||||
|
||||
checkpoint = get_checkpoint_fn()
|
||||
|
||||
for s in range(0, len(blocks), blocks_per_ckpt):
|
||||
e = s + blocks_per_ckpt
|
||||
args = checkpoint(chunker(s, e), *args)
|
||||
args = wrap(args)
|
||||
|
||||
return args
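
# Illustrative usage sketch (not part of the original file): two tiny blocks,
# each consuming the previous block's outputs, checkpointed one block at a time.
def _demo_checkpoint_blocks():
    def block_a(x, y):
        return x + 1, y * 2

    def block_b(x, y):
        return x * y, y

    x0 = torch.ones(3, requires_grad=True)
    y0 = torch.full((3,), 2.0)
    # With blocks_per_ckpt=1, each block is wrapped in its own checkpoint.
    out = checkpoint_blocks([block_a, block_b], args=(x0, y0), blocks_per_ckpt=1)
    return out  # tuple holding the final block's outputs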
|
|
@ -0,0 +1,78 @@
|
|||
# Copyright 2021 AlQuraishi Laboratory
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from functools import partialmethod
|
||||
from typing import Union, List
|
||||
|
||||
|
||||
class Dropout(nn.Module):
|
||||
"""
|
||||
Implementation of dropout with the ability to share the dropout mask
|
||||
along a particular dimension.
|
||||
|
||||
If not in training mode, this module computes the identity function.
|
||||
"""
|
||||
|
||||
def __init__(self, r: float, batch_dim: Union[int, List[int]]):
|
||||
"""
|
||||
Args:
|
||||
r:
|
||||
Dropout rate
|
||||
batch_dim:
|
||||
Dimension(s) along which the dropout mask is shared
|
||||
"""
|
||||
super(Dropout, self).__init__()
|
||||
|
||||
self.r = r
|
||||
if type(batch_dim) == int:
|
||||
batch_dim = [batch_dim]
|
||||
self.batch_dim = batch_dim
|
||||
self.dropout = nn.Dropout(self.r)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
x:
|
||||
Tensor to which dropout is applied. Can have any shape
|
||||
compatible with self.batch_dim
|
||||
"""
|
||||
shape = list(x.shape)
|
||||
if self.batch_dim is not None:
|
||||
for bd in self.batch_dim:
|
||||
shape[bd] = 1
|
||||
mask = x.new_ones(shape)
|
||||
mask = self.dropout(mask)
|
||||
x *= mask
|
||||
return x
|
||||
|
||||
|
||||
class DropoutRowwise(Dropout):
|
||||
"""
|
||||
Convenience class for rowwise dropout as described in subsection
|
||||
1.11.6.
|
||||
"""
|
||||
|
||||
__init__ = partialmethod(Dropout.__init__, batch_dim=-3)
|
||||
|
||||
|
||||
class DropoutColumnwise(Dropout):
|
||||
"""
|
||||
Convenience class for columnwise dropout as described in subsection
|
||||
1.11.6.
|
||||
"""
|
||||
|
||||
__init__ = partialmethod(Dropout.__init__, batch_dim=-2)
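
# Illustrative sketch (not part of the original file): in training mode the
# rowwise variant shares one dropout mask across the N_seq dimension (dim -3),
# so every MSA row is zeroed at the same positions.
def _demo_rowwise_dropout():
    drop = DropoutRowwise(r=0.5)
    drop.train()
    x = torch.ones(1, 8, 16, 4)  # [batch, N_seq, N_res, C]
    out = drop(x)
    assert torch.equal(out[:, 0], out[:, 1])  # identical mask for every row
    return out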
|
|
@ -0,0 +1,431 @@
|
|||
# Copyright 2021 AlQuraishi Laboratory
|
||||
# Copyright 2021 DeepMind Technologies Limited
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from typing import Tuple, Optional
|
||||
from functools import partial
|
||||
|
||||
from .primitives import Linear, LayerNorm
|
||||
from .dropout import DropoutRowwise, DropoutColumnwise
|
||||
from .msa import (
|
||||
MSARowAttentionWithPairBias,
|
||||
MSAColumnAttention,
|
||||
MSAColumnGlobalAttention,
|
||||
)
|
||||
from .outer_product_mean import OuterProductMean
|
||||
from .pair_transition import PairTransition
|
||||
from .triangular_attention import (
|
||||
TriangleAttentionStartingNode,
|
||||
TriangleAttentionEndingNode,
|
||||
)
|
||||
from .triangular_multiplicative_update import (
|
||||
TriangleMultiplicationOutgoing,
|
||||
TriangleMultiplicationIncoming,
|
||||
)
|
||||
from .checkpointing import checkpoint_blocks, get_checkpoint_fn
|
||||
from .tensor_utils import chunk_layer
|
||||
|
||||
|
||||
class MSATransition(nn.Module):
|
||||
"""
|
||||
Feed-forward network applied to MSA activations after attention.
|
||||
|
||||
Implements Algorithm 9
|
||||
"""
|
||||
def __init__(self, c_m, n):
|
||||
"""
|
||||
Args:
|
||||
c_m:
|
||||
MSA channel dimension
|
||||
n:
|
||||
Factor by which c_m is multiplied to obtain the hidden channel
|
||||
dimension
|
||||
"""
|
||||
super(MSATransition, self).__init__()
|
||||
|
||||
self.c_m = c_m
|
||||
self.n = n
|
||||
|
||||
self.layer_norm = LayerNorm(self.c_m)
|
||||
self.linear_1 = Linear(self.c_m, self.n * self.c_m, init="relu")
|
||||
self.relu = nn.ReLU()
|
||||
self.linear_2 = Linear(self.n * self.c_m, self.c_m, init="final")
|
||||
|
||||
def _transition(self, m, mask):
|
||||
m = self.linear_1(m)
|
||||
m = self.relu(m)
|
||||
m = self.linear_2(m) * mask
|
||||
return m
|
||||
|
||||
@torch.jit.ignore
|
||||
def _chunk(self,
|
||||
m: torch.Tensor,
|
||||
mask: torch.Tensor,
|
||||
chunk_size: int,
|
||||
) -> torch.Tensor:
|
||||
return chunk_layer(
|
||||
self._transition,
|
||||
{"m": m, "mask": mask},
|
||||
chunk_size=chunk_size,
|
||||
no_batch_dims=len(m.shape[:-2]),
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
m: torch.Tensor,
|
||||
mask: Optional[torch.Tensor] = None,
|
||||
chunk_size: Optional[int] = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
m:
|
||||
[*, N_seq, N_res, C_m] MSA activation
|
||||
mask:
|
||||
[*, N_seq, N_res] MSA mask
|
||||
Returns:
|
||||
m:
|
||||
[*, N_seq, N_res, C_m] MSA activation update
|
||||
"""
|
||||
|
||||
# DISCREPANCY: DeepMind forgets to apply the MSA mask here.
|
||||
if mask is None:
|
||||
mask = m.new_ones(m.shape[:-1])
|
||||
|
||||
# [*, N_seq, N_res, 1]
|
||||
mask = mask.unsqueeze(-1)
|
||||
|
||||
m = self.layer_norm(m)
|
||||
|
||||
if chunk_size is not None:
|
||||
m = self._chunk(m, mask, chunk_size)
|
||||
else:
|
||||
m = self._transition(m, mask)
|
||||
|
||||
return m
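
# Illustrative sketch (not part of the original file): chunked and unchunked
# execution of the transition should agree, since chunk_layer only tiles the
# batch-like dimensions. Assumes chunk_layer from .tensor_utils behaves as
# described in this module.
def _demo_msa_transition_chunking():
    mt = MSATransition(c_m=8, n=4).eval()
    m = torch.randn(1, 5, 7, 8)  # [batch, N_seq, N_res, C_m]
    with torch.no_grad():
        full = mt(m)
        chunked = mt(m, chunk_size=3)
    assert torch.allclose(full, chunked, atol=1e-6)
    return full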
|
||||
|
||||
|
||||
class EvoformerBlockCore(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
c_m: int,
|
||||
c_z: int,
|
||||
c_hidden_opm: int,
|
||||
c_hidden_mul: int,
|
||||
c_hidden_pair_att: int,
|
||||
no_heads_msa: int,
|
||||
no_heads_pair: int,
|
||||
transition_n: int,
|
||||
pair_dropout: float,
|
||||
inf: float,
|
||||
eps: float,
|
||||
_is_extra_msa_stack: bool = False,
|
||||
is_multimer: bool = False,
|
||||
):
|
||||
super(EvoformerBlockCore, self).__init__()
|
||||
self.is_multimer = is_multimer
|
||||
self.msa_transition = MSATransition(
|
||||
c_m=c_m,
|
||||
n=transition_n,
|
||||
)
|
||||
|
||||
self.outer_product_mean = OuterProductMean(
|
||||
c_m,
|
||||
c_z,
|
||||
c_hidden_opm,
|
||||
)
|
||||
|
||||
self.tri_mul_out = TriangleMultiplicationOutgoing(
|
||||
c_z,
|
||||
c_hidden_mul,
|
||||
)
|
||||
self.tri_mul_in = TriangleMultiplicationIncoming(
|
||||
c_z,
|
||||
c_hidden_mul,
|
||||
)
|
||||
|
||||
self.tri_att_start = TriangleAttentionStartingNode(
|
||||
c_z,
|
||||
c_hidden_pair_att,
|
||||
no_heads_pair,
|
||||
inf=inf,
|
||||
)
|
||||
self.tri_att_end = TriangleAttentionEndingNode(
|
||||
c_z,
|
||||
c_hidden_pair_att,
|
||||
no_heads_pair,
|
||||
inf=inf,
|
||||
)
|
||||
|
||||
self.pair_transition = PairTransition(
|
||||
c_z,
|
||||
transition_n,
|
||||
)
|
||||
|
||||
self.ps_dropout_row_layer = DropoutRowwise(pair_dropout)
|
||||
self.ps_dropout_col_layer = DropoutColumnwise(pair_dropout)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
m: torch.Tensor,
|
||||
z: torch.Tensor,
|
||||
chunk_size: Optional[int] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# DeepMind doesn't mask these transitions in the source, so _mask_trans
|
||||
# should be disabled to better approximate the exact activations of
|
||||
# the original.
|
||||
|
||||
m = m + self.msa_transition(
|
||||
m, chunk_size=chunk_size
|
||||
)
|
||||
z = z + self.outer_product_mean(
|
||||
m, chunk_size=chunk_size
|
||||
)
|
||||
z = z + self.ps_dropout_row_layer(self.tri_mul_out(z))
|
||||
z = z + self.ps_dropout_row_layer(self.tri_mul_in(z))
|
||||
z = z + self.ps_dropout_row_layer(
|
||||
self.tri_att_start(z, chunk_size=chunk_size)
|
||||
)
|
||||
z = z + self.ps_dropout_col_layer(
|
||||
self.tri_att_end(z, chunk_size=chunk_size)
|
||||
)
|
||||
z = z + self.pair_transition(
|
||||
z, chunk_size=chunk_size
|
||||
)
|
||||
|
||||
return m, z
|
||||
|
||||
|
||||
class EvoformerBlock(nn.Module):
|
||||
def __init__(self,
|
||||
c_m: int,
|
||||
c_z: int,
|
||||
c_hidden_msa_att: int,
|
||||
c_hidden_opm: int,
|
||||
c_hidden_mul: int,
|
||||
c_hidden_pair_att: int,
|
||||
no_heads_msa: int,
|
||||
no_heads_pair: int,
|
||||
transition_n: int,
|
||||
msa_dropout: float,
|
||||
pair_dropout: float,
|
||||
inf: float,
|
||||
eps: float,
|
||||
is_multimer: bool,
|
||||
):
|
||||
super(EvoformerBlock, self).__init__()
|
||||
|
||||
self.msa_att_row = MSARowAttentionWithPairBias(
|
||||
c_m=c_m,
|
||||
c_z=c_z,
|
||||
c_hidden=c_hidden_msa_att,
|
||||
no_heads=no_heads_msa,
|
||||
inf=inf,
|
||||
)
|
||||
|
||||
self.msa_att_col = MSAColumnAttention(
|
||||
c_m,
|
||||
c_hidden_msa_att,
|
||||
no_heads_msa,
|
||||
inf=inf,
|
||||
)
|
||||
|
||||
self.msa_dropout_layer = DropoutRowwise(msa_dropout)
|
||||
|
||||
self.core = EvoformerBlockCore(
|
||||
c_m=c_m,
|
||||
c_z=c_z,
|
||||
c_hidden_opm=c_hidden_opm,
|
||||
c_hidden_mul=c_hidden_mul,
|
||||
c_hidden_pair_att=c_hidden_pair_att,
|
||||
no_heads_msa=no_heads_msa,
|
||||
no_heads_pair=no_heads_pair,
|
||||
transition_n=transition_n,
|
||||
pair_dropout=pair_dropout,
|
||||
inf=inf,
|
||||
eps=eps,
|
||||
)
|
||||
|
||||
self.outer_product_mean = OuterProductMean(
|
||||
c_m,
|
||||
c_z,
|
||||
c_hidden_opm,
|
||||
)
|
||||
self.is_multimer = is_multimer
|
||||
|
||||
def forward(self,
|
||||
m: torch.Tensor,
|
||||
z: torch.Tensor,
|
||||
chunk_size: Optional[int] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
m = m + self.msa_dropout_layer(
|
||||
self.msa_att_row(m, z=z, chunk_size=chunk_size)
|
||||
)
|
||||
m = m + self.msa_att_col(m, chunk_size=chunk_size)
|
||||
m, z = self.core(
|
||||
m,
|
||||
z,
|
||||
chunk_size=chunk_size,
|
||||
)
|
||||
|
||||
return m, z
|
||||
|
||||
|
||||
class EvoformerStack(nn.Module):
|
||||
"""
|
||||
Main Evoformer trunk.
|
||||
|
||||
Implements Algorithm 6.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
c_m: int,
|
||||
c_z: int,
|
||||
c_hidden_msa_att: int,
|
||||
c_hidden_opm: int,
|
||||
c_hidden_mul: int,
|
||||
c_hidden_pair_att: int,
|
||||
c_s: int,
|
||||
no_heads_msa: int,
|
||||
no_heads_pair: int,
|
||||
no_blocks: int,
|
||||
transition_n: int,
|
||||
msa_dropout: float,
|
||||
pair_dropout: float,
|
||||
blocks_per_ckpt: int,
|
||||
inf: float,
|
||||
eps: float,
|
||||
clear_cache_between_blocks: bool = False,
|
||||
is_multimer: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
c_m:
|
||||
MSA channel dimension
|
||||
c_z:
|
||||
Pair channel dimension
|
||||
c_hidden_msa_att:
|
||||
Hidden dimension in MSA attention
|
||||
c_hidden_opm:
|
||||
Hidden dimension in outer product mean module
|
||||
c_hidden_mul:
|
||||
Hidden dimension in multiplicative updates
|
||||
c_hidden_pair_att:
|
||||
Hidden dimension in triangular attention
|
||||
c_s:
|
||||
Channel dimension of the output "single" embedding
|
||||
no_heads_msa:
|
||||
Number of heads used for MSA attention
|
||||
no_heads_pair:
|
||||
Number of heads used for pair attention
|
||||
no_blocks:
|
||||
Number of Evoformer blocks in the stack
|
||||
transition_n:
|
||||
Factor by which to multiply c_m to obtain the MSATransition
|
||||
hidden dimension
|
||||
msa_dropout:
|
||||
Dropout rate for MSA activations
|
||||
pair_dropout:
|
||||
Dropout used for pair activations
|
||||
blocks_per_ckpt:
|
||||
Number of Evoformer blocks in each activation checkpoint
|
||||
clear_cache_between_blocks:
|
||||
Whether to clear CUDA's GPU memory cache between blocks of the
|
||||
stack. Slows down each block but can reduce fragmentation
|
||||
"""
|
||||
super(EvoformerStack, self).__init__()
|
||||
|
||||
self.blocks_per_ckpt = blocks_per_ckpt
|
||||
self.clear_cache_between_blocks = clear_cache_between_blocks
|
||||
|
||||
self.blocks = nn.ModuleList()
|
||||
|
||||
for _ in range(no_blocks):
|
||||
block = EvoformerBlock(
|
||||
c_m=c_m,
|
||||
c_z=c_z,
|
||||
c_hidden_msa_att=c_hidden_msa_att,
|
||||
c_hidden_opm=c_hidden_opm,
|
||||
c_hidden_mul=c_hidden_mul,
|
||||
c_hidden_pair_att=c_hidden_pair_att,
|
||||
no_heads_msa=no_heads_msa,
|
||||
no_heads_pair=no_heads_pair,
|
||||
transition_n=transition_n,
|
||||
msa_dropout=msa_dropout,
|
||||
pair_dropout=pair_dropout,
|
||||
inf=inf,
|
||||
eps=eps,
|
||||
is_multimer=is_multimer,
|
||||
)
|
||||
self.blocks.append(block)
|
||||
|
||||
self.linear = Linear(c_m, c_s)
|
||||
|
||||
def forward(self,
|
||||
m: torch.Tensor,
|
||||
z: torch.Tensor,
|
||||
msa_mask: torch.Tensor,
|
||||
pair_mask: torch.Tensor,
|
||||
chunk_size: int,
|
||||
_mask_trans: bool = True,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
|
||||
"""
|
||||
Args:
|
||||
m:
|
||||
[*, N_seq, N_res, C_m] MSA embedding
|
||||
z:
|
||||
[*, N_res, N_res, C_z] pair embedding
|
||||
msa_mask:
|
||||
[*, N_seq, N_res] MSA mask
|
||||
pair_mask:
|
||||
[*, N_res, N_res] pair mask
|
||||
Returns:
|
||||
m:
|
||||
[*, N_seq, N_res, C_m] MSA embedding
|
||||
z:
|
||||
[*, N_res, N_res, C_z] pair embedding
|
||||
s:
|
||||
[*, N_res, C_s] single embedding (or None if extra MSA stack)
|
||||
"""
|
||||
blocks = [
|
||||
partial(
|
||||
b,
|
||||
chunk_size=chunk_size,
|
||||
)
|
||||
for b in self.blocks
|
||||
]
|
||||
|
||||
if(self.clear_cache_between_blocks):
|
||||
def block_with_cache_clear(block, *args):
|
||||
torch.cuda.empty_cache()
|
||||
return block(*args)
|
||||
|
||||
blocks = [partial(block_with_cache_clear, b) for b in blocks]
|
||||
|
||||
m, z = checkpoint_blocks(
|
||||
blocks,
|
||||
args=(m, z),
|
||||
blocks_per_ckpt=self.blocks_per_ckpt if self.training else None,
|
||||
)
|
||||
|
||||
s = self.linear(m[..., 0, :, :])
|
||||
|
||||
return m, z, s
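
# Illustrative usage sketch (not part of the original file). The channel and
# head counts are small placeholders, not production AlphaFold values, and the
# sketch assumes the sibling modules imported above accept these arguments and
# that each block takes (m, z, chunk_size) as in EvoformerBlock.forward.
def _demo_evoformer_stack():
    stack = EvoformerStack(
        c_m=16, c_z=8, c_hidden_msa_att=8, c_hidden_opm=4, c_hidden_mul=8,
        c_hidden_pair_att=4, c_s=32, no_heads_msa=2, no_heads_pair=2,
        no_blocks=2, transition_n=2, msa_dropout=0.1, pair_dropout=0.1,
        blocks_per_ckpt=1, inf=1e9, eps=1e-8,
    ).eval()
    m = torch.randn(1, 4, 10, 16)   # [batch, N_seq, N_res, C_m]
    z = torch.randn(1, 10, 10, 8)   # [batch, N_res, N_res, C_z]
    msa_mask = torch.ones(1, 4, 10)
    pair_mask = torch.ones(1, 10, 10)
    with torch.no_grad():
        m, z, s = stack(m, z, msa_mask, pair_mask, chunk_size=None)
    return m.shape, z.shape, s.shape  # expect [1,4,10,16], [1,10,10,8], [1,10,32]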
|
|
@ -0,0 +1,331 @@
|
|||
# Copyright 2021 AlQuraishi Laboratory
|
||||
# Copyright 2021 DeepMind Technologies Limited
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from typing import Optional, List, Tuple
|
||||
|
||||
from .primitives import (
|
||||
Linear,
|
||||
LayerNorm,
|
||||
Attention,
|
||||
GlobalAttention,
|
||||
_attention_chunked_trainable,
|
||||
)
|
||||
from .checkpointing import get_checkpoint_fn
|
||||
from .tensor_utils import (
|
||||
chunk_layer,
|
||||
permute_final_dims,
|
||||
flatten_final_dims,
|
||||
)
|
||||
|
||||
|
||||
class MSAAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
c_in,
|
||||
c_hidden,
|
||||
no_heads,
|
||||
pair_bias=False,
|
||||
c_z=None,
|
||||
inf=1e9,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
c_in:
|
||||
Input channel dimension
|
||||
c_hidden:
|
||||
Per-head hidden channel dimension
|
||||
no_heads:
|
||||
Number of attention heads
|
||||
pair_bias:
|
||||
Whether to use pair embedding bias
|
||||
c_z:
|
||||
Pair embedding channel dimension. Ignored unless pair_bias
|
||||
is true
|
||||
inf:
|
||||
A large number to be used in computing the attention mask
|
||||
"""
|
||||
super(MSAAttention, self).__init__()
|
||||
|
||||
self.c_in = c_in
|
||||
self.c_hidden = c_hidden
|
||||
self.no_heads = no_heads
|
||||
self.pair_bias = pair_bias
|
||||
self.c_z = c_z
|
||||
self.inf = inf
|
||||
|
||||
self.layer_norm_m = LayerNorm(self.c_in)
|
||||
|
||||
self.layer_norm_z = None
|
||||
self.linear_z = None
|
||||
if self.pair_bias:
|
||||
self.layer_norm_z = LayerNorm(self.c_z)
|
||||
self.linear_z = Linear(
|
||||
self.c_z, self.no_heads, bias=False, init="normal"
|
||||
)
|
||||
|
||||
self.mha = Attention(
|
||||
self.c_in, self.c_in, self.c_in, self.c_hidden, self.no_heads
|
||||
)
|
||||
|
||||
@torch.jit.ignore
|
||||
def _chunk(self,
|
||||
m: torch.Tensor,
|
||||
biases: List[torch.Tensor],
|
||||
chunk_size: int,
|
||||
) -> torch.Tensor:
|
||||
return chunk_layer(
|
||||
self.mha,
|
||||
{"q_x": m, "kv_x": m, "biases": biases},
|
||||
chunk_size=chunk_size,
|
||||
no_batch_dims=len(m.shape[:-2]),
|
||||
)
|
||||
|
||||
def _prep_inputs(self,
|
||||
m: torch.Tensor,
|
||||
z: Optional[torch.Tensor],
|
||||
mask: Optional[torch.Tensor]
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
# [*, N_seq, N_res, C_m]
|
||||
m = self.layer_norm_m(m)
|
||||
|
||||
n_seq, n_res = m.shape[-3:-1]
|
||||
if mask is None:
|
||||
# [*, N_seq, N_res]
|
||||
mask = m.new_ones(
|
||||
m.shape[:-3] + (n_seq, n_res),
|
||||
)
|
||||
|
||||
# [*, N_seq, 1, 1, N_res]
|
||||
mask_bias = (self.inf * (mask - 1))[..., :, None, None, :]
|
||||
|
||||
# This step simply returns a larger view of the bias, and does not
|
||||
# consume additional memory.
|
||||
# [*, N_seq, no_heads, N_res, N_res]
|
||||
#bias = bias.expand(
|
||||
# ((-1,) * len(bias.shape[:-4])) + (-1, self.no_heads, n_res, -1)
|
||||
#)
|
||||
|
||||
if (self.pair_bias and
|
||||
z is not None and # For the
|
||||
self.layer_norm_z is not None and # benefit of
|
||||
self.linear_z is not None # TorchScript
|
||||
):
|
||||
# [*, N_res, N_res, C_z]
|
||||
z = self.layer_norm_z(z)
|
||||
|
||||
# [*, N_res, N_res, no_heads]
|
||||
z = self.linear_z(z)
|
||||
|
||||
# [*, 1, no_heads, N_res, N_res]
|
||||
z = permute_final_dims(z, (2, 0, 1)).unsqueeze(-4)
|
||||
|
||||
return m, mask_bias, z
|
||||
|
||||
|
||||
def forward(self,
|
||||
m: torch.Tensor,
|
||||
z: Optional[torch.Tensor] = None,
|
||||
mask: Optional[torch.Tensor] = None,
|
||||
chunk_size: Optional[int] = None,
|
||||
_chunk_logits: Optional[int] = None,
|
||||
_checkpoint_chunks: Optional[bool] = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
m:
|
||||
[*, N_seq, N_res, C_m] MSA embedding
|
||||
z:
|
||||
[*, N_res, N_res, C_z] pair embedding. Required only if
|
||||
pair_bias is True
|
||||
mask:
|
||||
[*, N_seq, N_res] MSA mask
|
||||
chunk_size:
|
||||
Size of chunks into which the inputs are split along their
|
||||
batch dimensions. A low value decreases memory overhead at the
|
||||
cost of slower execution. Chunking is not performed by default.
|
||||
|
||||
"""
|
||||
m, mask_bias, z = self._prep_inputs(m, z, mask)
|
||||
|
||||
biases = [mask_bias]
|
||||
if(z is not None):
|
||||
biases.append(z)
|
||||
|
||||
if chunk_size is not None:
|
||||
m = self._chunk(m, biases, chunk_size)
|
||||
else:
|
||||
m = self.mha(
|
||||
q_x=m,
|
||||
kv_x=m,
|
||||
biases=biases
|
||||
)
|
||||
|
||||
return m
|
||||
|
||||
|
||||
class MSARowAttentionWithPairBias(MSAAttention):
|
||||
"""
|
||||
Implements Algorithm 7.
|
||||
"""
|
||||
|
||||
def __init__(self, c_m, c_z, c_hidden, no_heads, inf=1e9):
|
||||
"""
|
||||
Args:
|
||||
c_m:
|
||||
Input channel dimension
|
||||
c_z:
|
||||
Pair embedding channel dimension
|
||||
c_hidden:
|
||||
Per-head hidden channel dimension
|
||||
no_heads:
|
||||
Number of attention heads
|
||||
inf:
|
||||
Large number used to construct attention masks
|
||||
"""
|
||||
super(MSARowAttentionWithPairBias, self).__init__(
|
||||
c_m,
|
||||
c_hidden,
|
||||
no_heads,
|
||||
pair_bias=True,
|
||||
c_z=c_z,
|
||||
inf=inf,
|
||||
)
|
||||
|
||||
|
||||
class MSAColumnAttention(nn.Module):
|
||||
"""
|
||||
Implements Algorithm 8.
|
||||
|
||||
By rights, this should also be a subclass of MSAAttention. Alas,
|
||||
most inheritance isn't supported by TorchScript.
|
||||
"""
|
||||
|
||||
def __init__(self, c_m, c_hidden, no_heads, inf=1e9):
|
||||
"""
|
||||
Args:
|
||||
c_m:
|
||||
MSA channel dimension
|
||||
c_hidden:
|
||||
Per-head hidden channel dimension
|
||||
no_heads:
|
||||
Number of attention heads
|
||||
inf:
|
||||
Large number used to construct attention masks
|
||||
"""
|
||||
super(MSAColumnAttention, self).__init__()
|
||||
|
||||
self.c_m = c_m
|
||||
self.c_hidden = c_hidden
|
||||
self.no_heads = no_heads
|
||||
self.inf = inf
|
||||
|
||||
self._msa_att = MSAAttention(
|
||||
c_in=c_m,
|
||||
c_hidden=c_hidden,
|
||||
no_heads=no_heads,
|
||||
pair_bias=False,
|
||||
c_z=None,
|
||||
inf=inf,
|
||||
)
|
||||
|
||||
def forward(self,
|
||||
m: torch.Tensor,
|
||||
mask: Optional[torch.Tensor] = None,
|
||||
chunk_size: Optional[int] = None
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
m:
|
||||
[*, N_seq, N_res, C_m] MSA embedding
|
||||
mask:
|
||||
[*, N_seq, N_res] MSA mask
|
||||
chunk_size:
|
||||
Size of chunks into which the inputs are split along their
|
||||
batch dimensions. A low value decreases memory overhead at the
|
||||
cost of slower execution. Chunking is not performed by default.
|
||||
"""
|
||||
# [*, N_res, N_seq, C_in]
|
||||
m = m.transpose(-2, -3)
|
||||
|
||||
m = self._msa_att(m, chunk_size=chunk_size)
|
||||
|
||||
# [*, N_seq, N_res, C_in]
|
||||
m = m.transpose(-2, -3)
|
||||
|
||||
return m
|
||||
|
||||
|
||||
class MSAColumnGlobalAttention(nn.Module):
|
||||
def __init__(
|
||||
self, c_in, c_hidden, no_heads, inf=1e9, eps=1e-10,
|
||||
):
|
||||
super(MSAColumnGlobalAttention, self).__init__()
|
||||
|
||||
self.c_in = c_in
|
||||
self.c_hidden = c_hidden
|
||||
self.no_heads = no_heads
|
||||
self.inf = inf
|
||||
self.eps = eps
|
||||
|
||||
self.layer_norm_m = nn.LayerNorm(c_in)
|
||||
|
||||
self.global_attention = GlobalAttention(
|
||||
c_in=c_in,
|
||||
c_hidden=c_hidden,
|
||||
no_heads=no_heads,
|
||||
inf=inf,
|
||||
eps=eps,
|
||||
)
|
||||
|
||||
@torch.jit.ignore
|
||||
def _chunk(self,
|
||||
m: torch.Tensor,
|
||||
mask: torch.Tensor,
chunk_size: int,
|
||||
) -> torch.Tensor:
|
||||
mha_input = {
|
||||
"m": m,
|
||||
}
|
||||
return chunk_layer(
|
||||
self.global_attention,
|
||||
mha_input,
|
||||
chunk_size=chunk_size,
|
||||
no_batch_dims=len(m.shape[:-2]),
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
m: torch.Tensor,
|
||||
chunk_size: Optional[int] = None,
|
||||
) -> torch.Tensor:
|
||||
n_seq, n_res, c_in = m.shape[-3:]
|
||||
|
||||
# [*, N_res, N_seq, C_in]
|
||||
m = m.transpose(-2, -3)

# This variant takes no mask argument, so build an all-ones MSA mask
# ([*, N_res, N_seq]) here to satisfy GlobalAttention's signature.
mask = m.new_ones(m.shape[:-1])
|
||||
|
||||
# [*, N_res, N_seq, C_in]
|
||||
m = self.layer_norm_m(m)
|
||||
|
||||
if chunk_size is not None:
|
||||
m = self._chunk(m, mask, chunk_size)
|
||||
else:
|
||||
m = self.global_attention(m=m, mask=mask)
|
||||
|
||||
# [*, N_seq, N_res, C_in]
|
||||
m = m.transpose(-2, -3)
|
||||
|
||||
return m
|
|
@ -0,0 +1,129 @@
|
|||
# Copyright 2021 AlQuraishi Laboratory
|
||||
# Copyright 2021 DeepMind Technologies Limited
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from functools import partial
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from .primitives import Linear
|
||||
from .tensor_utils import chunk_layer
|
||||
|
||||
|
||||
class OuterProductMean(nn.Module):
|
||||
"""
|
||||
Implements Algorithm 10.
|
||||
"""
|
||||
|
||||
def __init__(self, c_m, c_z, c_hidden, eps=1e-3):
|
||||
"""
|
||||
Args:
|
||||
c_m:
|
||||
MSA embedding channel dimension
|
||||
c_z:
|
||||
Pair embedding channel dimension
|
||||
c_hidden:
|
||||
Hidden channel dimension
|
||||
"""
|
||||
super(OuterProductMean, self).__init__()
|
||||
|
||||
self.c_m = c_m
|
||||
self.c_z = c_z
|
||||
self.c_hidden = c_hidden
|
||||
self.eps = eps
|
||||
|
||||
self.layer_norm = nn.LayerNorm(c_m)
|
||||
self.linear_1 = Linear(c_m, c_hidden)
|
||||
self.linear_2 = Linear(c_m, c_hidden)
|
||||
self.linear_out = Linear(c_hidden ** 2, c_z, init="final")
|
||||
|
||||
def _opm(self, a, b):
|
||||
# [*, N_res, N_res, C, C]
|
||||
outer = torch.einsum("...bac,...dae->...bdce", a, b)
|
||||
|
||||
# [*, N_res, N_res, C * C]
|
||||
outer = outer.reshape(outer.shape[:-2] + (-1,))
|
||||
|
||||
# [*, N_res, N_res, C_z]
|
||||
outer = self.linear_out(outer)
|
||||
|
||||
return outer
|
||||
|
||||
@torch.jit.ignore
|
||||
def _chunk(self,
|
||||
a: torch.Tensor,
|
||||
b: torch.Tensor,
|
||||
chunk_size: int
|
||||
) -> torch.Tensor:
|
||||
# Since the "batch dim" in this case is not a true batch dimension
|
||||
# (in that the shape of the output depends on it), we need to
|
||||
# iterate over it ourselves
|
||||
a_reshape = a.reshape((-1,) + a.shape[-3:])
|
||||
b_reshape = b.reshape((-1,) + b.shape[-3:])
|
||||
out = []
|
||||
for a_prime, b_prime in zip(a_reshape, b_reshape):
|
||||
outer = chunk_layer(
|
||||
partial(self._opm, b=b_prime),
|
||||
{"a": a_prime},
|
||||
chunk_size=chunk_size,
|
||||
no_batch_dims=1,
|
||||
)
|
||||
out.append(outer)
|
||||
outer = torch.stack(out, dim=0)
|
||||
outer = outer.reshape(a.shape[:-3] + outer.shape[1:])
|
||||
|
||||
return outer
|
||||
|
||||
def forward(self,
|
||||
m: torch.Tensor,
|
||||
mask: Optional[torch.Tensor] = None,
|
||||
chunk_size: Optional[int] = None
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
m:
|
||||
[*, N_seq, N_res, C_m] MSA embedding
|
||||
mask:
|
||||
[*, N_seq, N_res] MSA mask
|
||||
Returns:
|
||||
[*, N_res, N_res, C_z] pair embedding update
|
||||
"""
|
||||
if mask is None:
|
||||
mask = m.new_ones(m.shape[:-1])
|
||||
|
||||
# [*, N_seq, N_res, C_m]
|
||||
m = self.layer_norm(m)
|
||||
|
||||
# [*, N_seq, N_res, C]
|
||||
mask = mask.unsqueeze(-1)
|
||||
a = self.linear_1(m) * mask
|
||||
b = self.linear_2(m) * mask
|
||||
|
||||
a = a.transpose(-2, -3)
|
||||
b = b.transpose(-2, -3)
|
||||
|
||||
if chunk_size is not None:
|
||||
outer = self._chunk(a, b, chunk_size)
|
||||
else:
|
||||
outer = self._opm(a, b)
|
||||
|
||||
# [*, N_res, N_res, 1]
|
||||
norm = torch.einsum("...abc,...adc->...bdc", mask, mask)
|
||||
|
||||
# [*, N_res, N_res, C_z]
|
||||
outer = outer / (self.eps + norm)
|
||||
|
||||
return outer
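
# Illustrative self-check (not part of the original file): the chunked path
# should reproduce the unchunked outer-product mean, since chunking only tiles
# the residue dimension of `a`.
def _demo_outer_product_mean_chunking():
    opm = OuterProductMean(c_m=8, c_z=6, c_hidden=4).eval()
    m = torch.randn(1, 5, 9, 8)  # [batch, N_seq, N_res, C_m]
    with torch.no_grad():
        full = opm(m)
        chunked = opm(m, chunk_size=2)
    assert torch.allclose(full, chunked, atol=1e-5)
    return full.shape  # [1, 9, 9, 6]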
|
|
@ -0,0 +1,99 @@
|
|||
# Copyright 2021 AlQuraishi Laboratory
|
||||
# Copyright 2021 DeepMind Technologies Limited
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from .primitives import Linear, LayerNorm
|
||||
from .tensor_utils import chunk_layer
|
||||
|
||||
|
||||
class PairTransition(nn.Module):
|
||||
"""
|
||||
Implements Algorithm 15.
|
||||
"""
|
||||
|
||||
def __init__(self, c_z, n):
|
||||
"""
|
||||
Args:
|
||||
c_z:
|
||||
Pair transition channel dimension
|
||||
n:
|
||||
Factor by which c_z is multiplied to obtain hidden channel
|
||||
dimension
|
||||
"""
|
||||
super(PairTransition, self).__init__()
|
||||
|
||||
self.c_z = c_z
|
||||
self.n = n
|
||||
|
||||
self.layer_norm = LayerNorm(self.c_z)
|
||||
self.linear_1 = Linear(self.c_z, self.n * self.c_z, init="relu")
|
||||
self.relu = nn.ReLU()
|
||||
self.linear_2 = Linear(self.n * self.c_z, c_z, init="final")
|
||||
|
||||
def _transition(self, z, mask):
|
||||
# [*, N_res, N_res, C_hidden]
|
||||
z = self.linear_1(z)
|
||||
z = self.relu(z)
|
||||
|
||||
# [*, N_res, N_res, C_z]
|
||||
z = self.linear_2(z) * mask
|
||||
|
||||
return z
|
||||
|
||||
@torch.jit.ignore
|
||||
def _chunk(self,
|
||||
z: torch.Tensor,
|
||||
mask: torch.Tensor,
|
||||
chunk_size: int,
|
||||
) -> torch.Tensor:
|
||||
return chunk_layer(
|
||||
self._transition,
|
||||
{"z": z, "mask": mask},
|
||||
chunk_size=chunk_size,
|
||||
no_batch_dims=len(z.shape[:-2]),
|
||||
)
|
||||
|
||||
|
||||
def forward(self,
|
||||
z: torch.Tensor,
|
||||
mask: Optional[torch.Tensor] = None,
|
||||
chunk_size: Optional[int] = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
z:
|
||||
[*, N_res, N_res, C_z] pair embedding
|
||||
Returns:
|
||||
[*, N_res, N_res, C_z] pair embedding update
|
||||
"""
|
||||
# DISCREPANCY: DeepMind forgets to apply the mask in this module.
|
||||
if mask is None:
|
||||
mask = z.new_ones(z.shape[:-1])
|
||||
|
||||
# [*, N_res, N_res, 1]
|
||||
mask = mask.unsqueeze(-1)
|
||||
|
||||
# [*, N_res, N_res, C_z]
|
||||
z = self.layer_norm(z)
|
||||
|
||||
if chunk_size is not None:
|
||||
z = self._chunk(z, mask, chunk_size)
|
||||
else:
|
||||
z = self._transition(z=z, mask=mask)
|
||||
|
||||
return z
|
|
@ -0,0 +1,529 @@
|
|||
# Copyright 2021 AlQuraishi Laboratory
|
||||
# Copyright 2021 DeepMind Technologies Limited
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from functools import partial
|
||||
import math
|
||||
from typing import Optional, Callable, List, Tuple, Sequence
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from .checkpointing import get_checkpoint_fn
|
||||
from .tensor_utils import (
|
||||
permute_final_dims,
|
||||
flatten_final_dims,
|
||||
_chunk_slice,
|
||||
)
|
||||
|
||||
|
||||
def _prod(nums):
|
||||
out = 1
|
||||
for n in nums:
|
||||
out = out * n
|
||||
return out
|
||||
|
||||
|
||||
def _calculate_fan(linear_weight_shape, fan="fan_in"):
|
||||
fan_out, fan_in = linear_weight_shape
|
||||
|
||||
if fan == "fan_in":
|
||||
f = fan_in
|
||||
elif fan == "fan_out":
|
||||
f = fan_out
|
||||
elif fan == "fan_avg":
|
||||
f = (fan_in + fan_out) / 2
|
||||
else:
|
||||
raise ValueError("Invalid fan option")
|
||||
|
||||
return f
|
||||
|
||||
|
||||
def glorot_uniform_init_(weights):
|
||||
nn.init.xavier_uniform_(weights, gain=1)
|
||||
|
||||
|
||||
def final_init_(weights):
|
||||
with torch.no_grad():
|
||||
weights.fill_(0.0)
|
||||
|
||||
|
||||
def gating_init_(weights):
|
||||
with torch.no_grad():
|
||||
weights.fill_(0.0)
|
||||
|
||||
|
||||
def normal_init_(weights):
|
||||
torch.nn.init.kaiming_normal_(weights, nonlinearity="linear")
|
||||
|
||||
|
||||
def ipa_point_weights_init_(weights):
|
||||
with torch.no_grad():
|
||||
softplus_inverse_1 = 0.541324854612918
|
||||
weights.fill_(softplus_inverse_1)
|
||||
|
||||
|
||||
class Linear(nn.Linear):
|
||||
"""
|
||||
A Linear layer with built-in nonstandard initializations. Called just
|
||||
like torch.nn.Linear.
|
||||
|
||||
Implements the initializers in 1.11.4, plus some additional ones found
|
||||
in the code.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_dim: int,
|
||||
out_dim: int,
|
||||
bias: bool = True,
|
||||
init: str = "default",
|
||||
init_fn: Optional[Callable[[torch.Tensor, torch.Tensor], None]] = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
in_dim:
|
||||
The final dimension of inputs to the layer
|
||||
out_dim:
|
||||
The final dimension of layer outputs
|
||||
bias:
|
||||
Whether to learn an additive bias. True by default
|
||||
init:
|
||||
The initializer to use. Choose from:
|
||||
|
||||
"default": LeCun fan-in truncated normal initialization
|
||||
"relu": He initialization w/ truncated normal distribution
|
||||
"glorot": Fan-average Glorot uniform initialization
|
||||
"gating": Weights=0, Bias=1
|
||||
"normal": Normal initialization with std=1/sqrt(fan_in)
|
||||
"final": Weights=0, Bias=0
|
||||
|
||||
Overridden by init_fn if the latter is not None.
|
||||
init_fn:
|
||||
A custom initializer taking weight and bias as inputs.
|
||||
Overrides init if not None.
|
||||
"""
|
||||
super(Linear, self).__init__(in_dim, out_dim, bias=bias)
|
||||
|
||||
if bias:
|
||||
with torch.no_grad():
|
||||
self.bias.fill_(0)
|
||||
|
||||
if init_fn is not None:
|
||||
init_fn(self.weight, self.bias)
|
||||
else:
|
||||
if init == "default":
|
||||
normal_init_(self.weight)
|
||||
elif init == "relu":
|
||||
normal_init_(self.weight)
|
||||
elif init == "glorot":
|
||||
glorot_uniform_init_(self.weight)
|
||||
elif init == "gating":
|
||||
gating_init_(self.weight)
|
||||
if bias:
|
||||
with torch.no_grad():
|
||||
self.bias.fill_(1.0)
|
||||
elif init == "normal":
|
||||
normal_init_(self.weight)
|
||||
elif init == "final":
|
||||
final_init_(self.weight)
|
||||
else:
|
||||
raise ValueError("Invalid init string.")
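
# Illustrative sketch (not part of the original file): the effect of the
# special init strings on freshly constructed layers.
def _demo_linear_inits(c_in: int = 8, c_out: int = 16):
    gating = Linear(c_in, c_out, init="gating")
    final = Linear(c_in, c_out, init="final")
    assert torch.all(gating.weight == 0) and torch.all(gating.bias == 1)
    assert torch.all(final.weight == 0) and torch.all(final.bias == 0)
    # "glorot" instead draws weights from a fan-averaged uniform distribution.
    glorot = Linear(c_in, c_out, bias=False, init="glorot")
    return gating, final, glorot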
|
||||
|
||||
|
||||
class LayerNorm(nn.Module):
|
||||
|
||||
def __init__(self, c_in, eps=1e-5):
|
||||
super(LayerNorm, self).__init__()
|
||||
|
||||
self.c_in = (c_in,)
|
||||
self.eps = eps
|
||||
|
||||
self.weight = nn.Parameter(torch.ones(c_in))
|
||||
self.bias = nn.Parameter(torch.zeros(c_in))
|
||||
|
||||
def forward(self, x):
|
||||
out = nn.functional.layer_norm(
|
||||
x,
|
||||
self.c_in,
|
||||
self.weight,
|
||||
self.bias,
|
||||
self.eps,
|
||||
)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
@torch.jit.ignore
|
||||
def softmax(t: torch.Tensor, dim: int = -1) -> torch.Tensor:
|
||||
"""
|
||||
Softmax, but without automatic casting to fp32 when the input is of
|
||||
type bfloat16
|
||||
"""
|
||||
s = torch.nn.functional.softmax(t, dim=dim)
|
||||
|
||||
return s
|
||||
|
||||
|
||||
#@torch.jit.script
|
||||
def _attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
|
||||
biases: List[torch.Tensor]) -> torch.Tensor:
|
||||
# [*, H, Q, C_hidden]
|
||||
query = permute_final_dims(query, (1, 0, 2))
|
||||
|
||||
# [*, H, C_hidden, K]
|
||||
key = permute_final_dims(key, (1, 2, 0))
|
||||
|
||||
# [*, H, V, C_hidden]
|
||||
value = permute_final_dims(value, (1, 0, 2))
|
||||
|
||||
# [*, H, Q, K]
|
||||
a = torch.matmul(query, key)
|
||||
|
||||
for b in biases:
|
||||
a += b
|
||||
|
||||
a = softmax(a, -1)
|
||||
|
||||
# [*, H, Q, C_hidden]
|
||||
a = torch.matmul(a, value)
|
||||
|
||||
# [*, Q, H, C_hidden]
|
||||
a = a.transpose(-2, -3)
|
||||
|
||||
return a
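
# Illustrative self-check (not part of the original file): _attention should
# match a plain softmax(QK^T + bias)V computed in [*, H, Q, K] layout. Note
# that scaling by 1/sqrt(c_hidden) happens in Attention._prep_qkv, not here.
def _demo_attention_reference():
    q = torch.randn(2, 6, 3, 4)     # [*, Q, H, C_hidden]
    k = torch.randn(2, 6, 3, 4)     # [*, K, H, C_hidden]
    v = torch.randn(2, 6, 3, 4)     # [*, V, H, C_hidden]
    bias = torch.randn(2, 3, 6, 6)  # [*, H, Q, K]
    out = _attention(q, k, v, [bias])

    qh, kh, vh = (t.transpose(-2, -3) for t in (q, k, v))  # [*, H, L, C_hidden]
    a = torch.softmax(qh @ kh.transpose(-1, -2) + bias, dim=-1)
    ref = (a @ vh).transpose(-2, -3)  # back to [*, Q, H, C_hidden]
    assert torch.allclose(out, ref, atol=1e-6)
    return out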
|
||||
|
||||
|
||||
@torch.jit.ignore
|
||||
def _attention_chunked_trainable(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
biases,
|
||||
chunk_size,
|
||||
chunk_dim,
|
||||
checkpoint,
|
||||
):
|
||||
if (checkpoint and len(biases) > 2):
|
||||
raise ValueError("Checkpointed version permits only permits two bias terms")
|
||||
|
||||
def _checkpointable_attention(q, k, v, b1, b2):
|
||||
bs = [b for b in [b1, b2] if b is not None]
|
||||
return _attention(q, k, v, bs)
|
||||
|
||||
o_chunks = []
|
||||
checkpoint_fn = get_checkpoint_fn()
|
||||
count = query.shape[chunk_dim]
|
||||
for start in range(0, count, chunk_size):
|
||||
end = start + chunk_size
|
||||
idx = [slice(None)] * len(query.shape)
|
||||
idx[chunk_dim] = slice(start, end)
|
||||
idx_tup = tuple(idx)
|
||||
q_chunk = query[idx_tup]
|
||||
k_chunk = key[idx_tup]
|
||||
v_chunk = value[idx_tup]
|
||||
|
||||
def _slice_bias(b):
|
||||
idx[chunk_dim] = (slice(start, end) if b.shape[chunk_dim] != 1 else slice(None))
|
||||
return b[tuple(idx)]
|
||||
|
||||
if (checkpoint):
|
||||
bias_1_chunk, bias_2_chunk = [
|
||||
_slice_bias(b) if b is not None else None for b in (biases + [None, None])[:2]
|
||||
]
|
||||
|
||||
o_chunk = checkpoint_fn(_checkpointable_attention, q_chunk, k_chunk, v_chunk,
|
||||
bias_1_chunk, bias_2_chunk)
|
||||
else:
|
||||
bias_chunks = [_slice_bias(b) for b in biases]
|
||||
|
||||
o_chunk = _attention(q_chunk, k_chunk, v_chunk, bias_chunks)
|
||||
|
||||
o_chunks.append(o_chunk)
|
||||
|
||||
o = torch.cat(o_chunks, dim=chunk_dim)
|
||||
return o
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
"""
|
||||
Standard multi-head attention using AlphaFold's default layer
|
||||
initialization. Allows multiple bias vectors.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
c_q: int,
|
||||
c_k: int,
|
||||
c_v: int,
|
||||
c_hidden: int,
|
||||
no_heads: int,
|
||||
gating: bool = True,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
c_q:
|
||||
Input dimension of query data
|
||||
c_k:
|
||||
Input dimension of key data
|
||||
c_v:
|
||||
Input dimension of value data
|
||||
c_hidden:
|
||||
Per-head hidden dimension
|
||||
no_heads:
|
||||
Number of attention heads
|
||||
gating:
|
||||
Whether the output should be gated using query data
|
||||
"""
|
||||
super(Attention, self).__init__()
|
||||
|
||||
self.c_q = c_q
|
||||
self.c_k = c_k
|
||||
self.c_v = c_v
|
||||
self.c_hidden = c_hidden
|
||||
self.no_heads = no_heads
|
||||
self.gating = gating
|
||||
|
||||
# DISCREPANCY: c_hidden is not the per-head channel dimension, as
|
||||
# stated in the supplement, but the overall channel dimension.
|
||||
|
||||
self.linear_q = Linear(self.c_q, self.c_hidden * self.no_heads, bias=False, init="glorot")
|
||||
self.linear_k = Linear(self.c_k, self.c_hidden * self.no_heads, bias=False, init="glorot")
|
||||
self.linear_v = Linear(self.c_v, self.c_hidden * self.no_heads, bias=False, init="glorot")
|
||||
self.linear_o = Linear(self.c_hidden * self.no_heads, self.c_q, init="final")
|
||||
|
||||
self.linear_g = None
|
||||
if self.gating:
|
||||
self.linear_g = Linear(self.c_q, self.c_hidden * self.no_heads, init="gating")
|
||||
|
||||
self.sigmoid = nn.Sigmoid()
|
||||
|
||||
def _prep_qkv(self, q_x: torch.Tensor,
|
||||
kv_x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
# [*, Q/K/V, H * C_hidden]
|
||||
q = self.linear_q(q_x)
|
||||
k = self.linear_k(kv_x)
|
||||
v = self.linear_v(kv_x)
|
||||
|
||||
# [*, Q/K, H, C_hidden]
|
||||
q = q.view(q.shape[:-1] + (self.no_heads, -1))
|
||||
k = k.view(k.shape[:-1] + (self.no_heads, -1))
|
||||
v = v.view(v.shape[:-1] + (self.no_heads, -1))
|
||||
|
||||
q /= math.sqrt(self.c_hidden)
|
||||
|
||||
return q, k, v
|
||||
|
||||
def _wrap_up(self, o: torch.Tensor, q_x: torch.Tensor) -> torch.Tensor:
|
||||
if (self.linear_g is not None):
|
||||
g = self.sigmoid(self.linear_g(q_x))
|
||||
|
||||
# [*, Q, H, C_hidden]
|
||||
g = g.view(g.shape[:-1] + (self.no_heads, -1))
|
||||
o = o * g
|
||||
|
||||
# [*, Q, H * C_hidden]
|
||||
o = flatten_final_dims(o, 2)
|
||||
|
||||
# [*, Q, C_q]
|
||||
o = self.linear_o(o)
|
||||
|
||||
return o
|
||||
|
||||
def forward(
|
||||
self,
|
||||
q_x: torch.Tensor,
|
||||
kv_x: torch.Tensor,
|
||||
biases: Optional[List[torch.Tensor]] = None,
|
||||
use_lma: bool = False,
|
||||
q_chunk_size: Optional[int] = None,
|
||||
kv_chunk_size: Optional[int] = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
q_x:
|
||||
[*, Q, C_q] query data
|
||||
kv_x:
|
||||
[*, K, C_k] key data
|
||||
biases:
|
||||
List of biases that broadcast to [*, H, Q, K]
|
||||
use_lma:
|
||||
Whether to use low-memory attention
|
||||
q_chunk_size:
|
||||
Query chunk size (for LMA)
|
||||
kv_chunk_size:
|
||||
Key/Value chunk size (for LMA)
|
||||
Returns:
|
||||
[*, Q, C_q] attention update
|
||||
"""
|
||||
if (biases is None):
|
||||
biases = []
|
||||
if (use_lma and (q_chunk_size is None or kv_chunk_size is None)):
|
||||
raise ValueError("If use_lma is specified, q_chunk_size and kv_chunk_size must "
|
||||
"be provided")
|
||||
|
||||
q, k, v = self._prep_qkv(q_x, kv_x)
|
||||
|
||||
if (use_lma):
|
||||
biases = [b.expand(b.shape[:-2] + (q_x.shape[-2],) + (kv_x.shape[-2],)) for b in biases]
|
||||
|
||||
o = _lma(q, k, v, biases, q_chunk_size, kv_chunk_size)
|
||||
else:
|
||||
o = _attention(q, k, v, biases)
|
||||
|
||||
o = self._wrap_up(o, q_x)
|
||||
|
||||
return o
|
||||
|
||||
|
||||
class GlobalAttention(nn.Module):
|
||||
|
||||
def __init__(self, c_in, c_hidden, no_heads, inf, eps):
|
||||
super(GlobalAttention, self).__init__()
|
||||
|
||||
self.c_in = c_in
|
||||
self.c_hidden = c_hidden
|
||||
self.no_heads = no_heads
|
||||
self.inf = inf
|
||||
self.eps = eps
|
||||
|
||||
self.linear_q = Linear(c_in, c_hidden * no_heads, bias=False, init="glorot")
|
||||
|
||||
self.linear_k = Linear(
|
||||
c_in,
|
||||
c_hidden,
|
||||
bias=False,
|
||||
init="glorot",
|
||||
)
|
||||
self.linear_v = Linear(
|
||||
c_in,
|
||||
c_hidden,
|
||||
bias=False,
|
||||
init="glorot",
|
||||
)
|
||||
self.linear_g = Linear(c_in, c_hidden * no_heads, init="gating")
|
||||
self.linear_o = Linear(c_hidden * no_heads, c_in, init="final")
|
||||
|
||||
self.sigmoid = nn.Sigmoid()
|
||||
|
||||
def forward(self, m: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
|
||||
# [*, N_res, C_in]
|
||||
q = torch.sum(m * mask.unsqueeze(-1),
|
||||
dim=-2) / (torch.sum(mask, dim=-1)[..., None] + self.eps)
|
||||
|
||||
# [*, N_res, H * C_hidden]
|
||||
q = self.linear_q(q)
|
||||
q *= (self.c_hidden**(-0.5))
|
||||
|
||||
# [*, N_res, H, C_hidden]
|
||||
q = q.view(q.shape[:-1] + (self.no_heads, -1))
|
||||
|
||||
# [*, N_res, N_seq, C_hidden]
|
||||
k = self.linear_k(m)
|
||||
v = self.linear_v(m)
|
||||
|
||||
# [*, N_res, H, N_seq]
|
||||
a = torch.matmul(
|
||||
q,
|
||||
k.transpose(-1, -2), # [*, N_res, C_hidden, N_seq]
|
||||
)
|
||||
bias = (self.inf * (mask - 1))[..., :, None, :]
|
||||
a += bias
|
||||
a = softmax(a)
|
||||
|
||||
# [*, N_res, H, C_hidden]
|
||||
o = torch.matmul(
|
||||
a,
|
||||
v,
|
||||
)
|
||||
|
||||
# [*, N_res, N_seq, H * C_hidden]
|
||||
g = self.sigmoid(self.linear_g(m))
|
||||
|
||||
# [*, N_res, N_seq, H, C_hidden]
|
||||
g = g.view(g.shape[:-1] + (self.no_heads, -1))
|
||||
|
||||
# [*, N_res, N_seq, H, C_hidden]
|
||||
o = o.unsqueeze(-3) * g
|
||||
|
||||
# [*, N_res, N_seq, H * C_hidden]
|
||||
o = o.reshape(o.shape[:-2] + (-1,))
|
||||
|
||||
# [*, N_res, N_seq, C_in]
|
||||
m = self.linear_o(o)
|
||||
|
||||
return m
|
||||
|
||||
|
||||
def _lma(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
biases: List[torch.Tensor],
|
||||
q_chunk_size: int,
|
||||
kv_chunk_size: int,
|
||||
):
|
||||
no_q, no_kv = q.shape[-3], k.shape[-3]
|
||||
|
||||
# [*, Q, H, C_hidden]
|
||||
o = q.new_zeros(q.shape)
|
||||
for q_s in range(0, no_q, q_chunk_size):
|
||||
q_chunk = q[..., q_s:q_s + q_chunk_size, :, :]
|
||||
large_bias_chunks = [b[..., q_s:q_s + q_chunk_size, :] for b in biases]
|
||||
|
||||
maxes = []
|
||||
weights = []
|
||||
values = []
|
||||
for kv_s in range(0, no_kv, kv_chunk_size):
|
||||
k_chunk = k[..., kv_s:kv_s + kv_chunk_size, :, :]
|
||||
v_chunk = v[..., kv_s:kv_s + kv_chunk_size, :, :]
|
||||
small_bias_chunks = [b[..., kv_s:kv_s + kv_chunk_size] for b in large_bias_chunks]
|
||||
|
||||
a = torch.einsum(
|
||||
"...qhd,...khd->...hqk",
|
||||
q_chunk,
|
||||
k_chunk,
|
||||
)
|
||||
|
||||
for b in small_bias_chunks:
|
||||
a += b
|
||||
|
||||
a = a.transpose(-2, -3)
|
||||
|
||||
max_a = torch.max(a, dim=-1, keepdim=True)[0]
|
||||
exp_a = torch.exp(a - max_a)
|
||||
exp_v = torch.einsum("...vhf,...qhv->...qhf", v_chunk, exp_a)
|
||||
|
||||
maxes.append(max_a.detach().squeeze(-1))
|
||||
weights.append(torch.sum(exp_a, dim=-1))
|
||||
values.append(exp_v)
|
||||
|
||||
chunk_max = torch.stack(maxes, dim=-3)
|
||||
chunk_weights = torch.stack(weights, dim=-3)
|
||||
chunk_values = torch.stack(values, dim=-4)
|
||||
|
||||
global_max = torch.max(chunk_max, dim=-3, keepdim=True)[0]
|
||||
max_diffs = torch.exp(chunk_max - global_max)
|
||||
chunk_values *= max_diffs.unsqueeze(-1)
|
||||
chunk_weights *= max_diffs
|
||||
|
||||
all_values = torch.sum(chunk_values, dim=-4)
|
||||
all_weights = torch.sum(chunk_weights.unsqueeze(-1), dim=-4)
|
||||
|
||||
q_chunk_out = all_values / all_weights
|
||||
|
||||
o[..., q_s:q_s + q_chunk_size, :, :] = q_chunk_out
|
||||
|
||||
return o
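
# Illustrative self-check (not part of the original file): low-memory attention
# should agree with the dense _attention result when the biases are already
# expanded to the full [*, H, Q, K] shape, as Attention.forward does before
# calling _lma.
def _demo_lma_matches_dense():
    q = torch.randn(1, 8, 2, 4)     # [*, Q, H, C_hidden]
    k = torch.randn(1, 8, 2, 4)
    v = torch.randn(1, 8, 2, 4)
    bias = torch.randn(1, 2, 8, 8)  # [*, H, Q, K]
    dense = _attention(q, k, v, [bias])
    lma = _lma(q, k, v, [bias], q_chunk_size=4, kv_chunk_size=4)
    assert torch.allclose(dense, lma, atol=1e-5)
    return lma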
|
|
@ -0,0 +1,408 @@
|
|||
# Copyright 2021 AlQuraishi Laboratory
|
||||
# Copyright 2021 DeepMind Technologies Limited
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from functools import partial
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from typing import Tuple, List, Callable, Any, Dict, Sequence, Optional
|
||||
|
||||
|
||||
def permute_final_dims(tensor: torch.Tensor, inds: List[int]):
|
||||
zero_index = -1 * len(inds)
|
||||
first_inds = list(range(len(tensor.shape[:zero_index])))
|
||||
return tensor.permute(first_inds + [zero_index + i for i in inds])
|
||||
|
||||
|
||||
def flatten_final_dims(t: torch.Tensor, no_dims: int):
|
||||
return t.reshape(t.shape[:-no_dims] + (-1,))
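
# Illustrative sketch (not part of the original file): permute_final_dims
# reorders only the trailing len(inds) dimensions (indices are relative to
# those trailing dims), while flatten_final_dims merges the last no_dims dims.
def _demo_final_dim_helpers():
    t = torch.randn(2, 3, 4, 5)
    assert permute_final_dims(t, (2, 0, 1)).shape == (2, 5, 3, 4)
    assert flatten_final_dims(t, 2).shape == (2, 3, 20)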
|
||||
|
||||
|
||||
def masked_mean(mask, value, dim, eps=1e-4):
|
||||
mask = mask.expand(*value.shape)
|
||||
return torch.sum(mask * value, dim=dim) / (eps + torch.sum(mask, dim=dim))
|
||||
|
||||
|
||||
def pts_to_distogram(pts, min_bin=2.3125, max_bin=21.6875, no_bins=64):
|
||||
boundaries = torch.linspace(
|
||||
min_bin, max_bin, no_bins - 1, device=pts.device
|
||||
)
|
||||
dists = torch.sqrt(
|
||||
torch.sum((pts.unsqueeze(-2) - pts.unsqueeze(-3)) ** 2, dim=-1)
|
||||
)
|
||||
return torch.bucketize(dists, boundaries)
|
||||
|
||||
|
||||
def dict_multimap(fn, dicts):
|
||||
first = dicts[0]
|
||||
new_dict = {}
|
||||
for k, v in first.items():
|
||||
all_v = [d[k] for d in dicts]
|
||||
if type(v) is dict:
|
||||
new_dict[k] = dict_multimap(fn, all_v)
|
||||
else:
|
||||
new_dict[k] = fn(all_v)
|
||||
|
||||
return new_dict
|
||||
|
||||
|
||||
def one_hot(x, v_bins):
|
||||
reshaped_bins = v_bins.view(((1,) * len(x.shape)) + (len(v_bins),))
|
||||
diffs = x[..., None] - reshaped_bins
|
||||
am = torch.argmin(torch.abs(diffs), dim=-1)
|
||||
return nn.functional.one_hot(am, num_classes=len(v_bins)).float()
|
||||
|
||||
|
||||
def batched_gather(data, inds, dim=0, no_batch_dims=0):
|
||||
ranges = []
|
||||
for i, s in enumerate(data.shape[:no_batch_dims]):
|
||||
r = torch.arange(s)
|
||||
r = r.view(*(*((1,) * i), -1, *((1,) * (len(inds.shape) - i - 1))))
|
||||
ranges.append(r)
|
||||
|
||||
remaining_dims = [
|
||||
slice(None) for _ in range(len(data.shape) - no_batch_dims)
|
||||
]
|
||||
remaining_dims[dim - no_batch_dims if dim >= 0 else dim] = inds
|
||||
ranges.extend(remaining_dims)
|
||||
return data[ranges]
|
||||
|
||||
|
||||
# With tree_map, a poor man's JAX tree_map
|
||||
def dict_map(fn, dic, leaf_type):
|
||||
new_dict = {}
|
||||
for k, v in dic.items():
|
||||
if type(v) is dict:
|
||||
new_dict[k] = dict_map(fn, v, leaf_type)
|
||||
else:
|
||||
new_dict[k] = tree_map(fn, v, leaf_type)
|
||||
|
||||
return new_dict
|
||||
|
||||
|
||||
def tree_map(fn, tree, leaf_type):
|
||||
if isinstance(tree, dict):
|
||||
return dict_map(fn, tree, leaf_type)
|
||||
elif isinstance(tree, list):
|
||||
return [tree_map(fn, x, leaf_type) for x in tree]
|
||||
elif isinstance(tree, tuple):
|
||||
return tuple([tree_map(fn, x, leaf_type) for x in tree])
|
||||
elif isinstance(tree, leaf_type):
|
||||
return fn(tree)
|
||||
else:
|
||||
print(type(tree))
|
||||
raise ValueError("Not supported")
|
||||
|
||||
|
||||
tensor_tree_map = partial(tree_map, leaf_type=torch.Tensor)
|
||||
|
||||
def _fetch_dims(tree):
|
||||
shapes = []
|
||||
tree_type = type(tree)
|
||||
if tree_type is dict:
|
||||
for v in tree.values():
|
||||
shapes.extend(_fetch_dims(v))
|
||||
elif tree_type is list or tree_type is tuple:
|
||||
for t in tree:
|
||||
shapes.extend(_fetch_dims(t))
|
||||
elif tree_type is torch.Tensor:
|
||||
shapes.append(tree.shape)
|
||||
else:
|
||||
raise ValueError("Not supported")
|
||||
|
||||
return shapes
|
||||
|
||||
|
||||
@torch.jit.ignore
|
||||
def _flat_idx_to_idx(
|
||||
flat_idx: int,
|
||||
dims: Tuple[int],
|
||||
) -> Tuple[int]:
|
||||
idx = []
|
||||
for d in reversed(dims):
|
||||
idx.append(flat_idx % d)
|
||||
flat_idx = flat_idx // d
|
||||
|
||||
return tuple(reversed(idx))
|
||||
|
||||
|
||||
@torch.jit.ignore
|
||||
def _get_minimal_slice_set(
|
||||
start: Sequence[int],
|
||||
end: Sequence[int],
|
||||
dims: int,
|
||||
start_edges: Optional[Sequence[bool]] = None,
|
||||
end_edges: Optional[Sequence[bool]] = None,
|
||||
) -> Sequence[Tuple[int]]:
|
||||
"""
|
||||
Produces an ordered sequence of tensor slices that, when used in
|
||||
sequence on a tensor with shape dims, yields tensors that contain every
|
||||
leaf in the contiguous range [start, end]. Care is taken to yield a
|
||||
short sequence of slices, and perhaps even the shortest possible (I'm
|
||||
pretty sure it's the latter).
|
||||
|
||||
end is INCLUSIVE.
|
||||
"""
|
||||
# start_edges and end_edges both indicate whether, starting from any given
|
||||
# dimension, the start/end index is at the top/bottom edge of the
|
||||
# corresponding tensor, modeled as a tree
|
||||
def reduce_edge_list(ll):
|
||||
tally = 1
|
||||
for i in range(len(ll)):
|
||||
reversed_idx = -1 * (i + 1)
|
||||
ll[reversed_idx] *= tally
|
||||
tally = ll[reversed_idx]
|
||||
|
||||
if(start_edges is None):
|
||||
start_edges = [s == 0 for s in start]
|
||||
reduce_edge_list(start_edges)
|
||||
if(end_edges is None):
|
||||
end_edges = [e == (d - 1) for e,d in zip(end, dims)]
|
||||
reduce_edge_list(end_edges)
|
||||
|
||||
# Base cases. Either start/end are empty and we're done, or the final,
|
||||
# one-dimensional tensor can be simply sliced
|
||||
if(len(start) == 0):
|
||||
return [tuple()]
|
||||
elif(len(start) == 1):
|
||||
return [(slice(start[0], end[0] + 1),)]
|
||||
|
||||
slices = []
|
||||
path = []
|
||||
|
||||
# Dimensions common to start and end can be selected directly
|
||||
for s,e in zip(start, end):
|
||||
if(s == e):
|
||||
path.append(slice(s, s + 1))
|
||||
else:
|
||||
break
|
||||
|
||||
path = tuple(path)
|
||||
divergence_idx = len(path)
|
||||
|
||||
# start == end, and we're done
|
||||
if(divergence_idx == len(dims)):
|
||||
return [tuple(path)]
|
||||
|
||||
def upper():
|
||||
sdi = start[divergence_idx]
|
||||
return [
|
||||
path + (slice(sdi, sdi + 1),) + s for s in
|
||||
_get_minimal_slice_set(
|
||||
start[divergence_idx + 1:],
|
||||
[d - 1 for d in dims[divergence_idx + 1:]],
|
||||
dims[divergence_idx + 1:],
|
||||
start_edges=start_edges[divergence_idx + 1:],
|
||||
end_edges=[1 for _ in end_edges[divergence_idx + 1:]]
|
||||
)
|
||||
]
|
||||
|
||||
def lower():
|
||||
edi = end[divergence_idx]
|
||||
return [
|
||||
path + (slice(edi, edi + 1),) + s for s in
|
||||
_get_minimal_slice_set(
|
||||
[0 for _ in start[divergence_idx + 1:]],
|
||||
end[divergence_idx + 1:],
|
||||
dims[divergence_idx + 1:],
|
||||
start_edges=[1 for _ in start_edges[divergence_idx + 1:]],
|
||||
end_edges=end_edges[divergence_idx + 1:],
|
||||
)
|
||||
]
|
||||
|
||||
# If both start and end are at the edges of the subtree rooted at
|
||||
# divergence_idx, we can just select the whole subtree at once
|
||||
if(start_edges[divergence_idx] and end_edges[divergence_idx]):
|
||||
slices.append(
|
||||
path + (slice(start[divergence_idx], end[divergence_idx] + 1),)
|
||||
)
|
||||
# If just start is at the edge, we can grab almost all of the subtree,
|
||||
# treating only the ragged bottom edge as an edge case
|
||||
elif(start_edges[divergence_idx]):
|
||||
slices.append(
|
||||
path + (slice(start[divergence_idx], end[divergence_idx]),)
|
||||
)
|
||||
slices.extend(lower())
|
||||
# Analogous to the previous case, but the top is ragged this time
|
||||
elif(end_edges[divergence_idx]):
|
||||
slices.extend(upper())
|
||||
slices.append(
|
||||
path + (slice(start[divergence_idx] + 1, end[divergence_idx] + 1),)
|
||||
)
|
||||
# If both sides of the range are ragged, we need to handle both sides
|
||||
# separately. If there's contiguous meat in between them, we can index it
|
||||
# in one big chunk
|
||||
else:
|
||||
slices.extend(upper())
|
||||
middle_ground = end[divergence_idx] - start[divergence_idx]
|
||||
if(middle_ground > 1):
|
||||
slices.append(
|
||||
path + (slice(start[divergence_idx] + 1, end[divergence_idx]),)
|
||||
)
|
||||
slices.extend(lower())
|
||||
|
||||
return [tuple(s) for s in slices]
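# Illustrative sketch (not part of the original diff): for a tensor with dims (2, 3),
# covering the contiguous leaf range from (0, 1) to (1, 1) inclusive,
# _get_minimal_slice_set((0, 1), (1, 1), (2, 3)) would be expected to yield
# [(slice(0, 1), slice(1, 3)), (slice(1, 2), slice(0, 2))]: the tail of row 0
# followed by the head of row 1.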
|
||||
|
||||
|
||||
@torch.jit.ignore
|
||||
def _chunk_slice(
|
||||
t: torch.Tensor,
|
||||
flat_start: int,
|
||||
flat_end: int,
|
||||
no_batch_dims: int,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Equivalent to
|
||||
|
||||
t.reshape((-1,) + t.shape[no_batch_dims:])[flat_start:flat_end]
|
||||
|
||||
but without the need for the initial reshape call, which can be
|
||||
memory-intensive in certain situations. The only reshape operations
|
||||
in this function are performed on sub-tensors that scale with
|
||||
(flat_end - flat_start), the chunk size.
|
||||
"""
|
||||
|
||||
batch_dims = t.shape[:no_batch_dims]
|
||||
start_idx = list(_flat_idx_to_idx(flat_start, batch_dims))
|
||||
# _get_minimal_slice_set is inclusive
|
||||
end_idx = list(_flat_idx_to_idx(flat_end - 1, batch_dims))
|
||||
|
||||
# Get an ordered list of slices to perform
|
||||
slices = _get_minimal_slice_set(
|
||||
start_idx,
|
||||
end_idx,
|
||||
batch_dims,
|
||||
)
|
||||
|
||||
sliced_tensors = [t[s] for s in slices]
|
||||
|
||||
return torch.cat(
|
||||
[s.view((-1,) + t.shape[no_batch_dims:]) for s in sliced_tensors]
|
||||
)
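# Illustrative sketch (not part of the original diff): _chunk_slice should agree with the
# naive flatten-then-slice approach, e.g.
#
#   t = torch.arange(24).view(2, 3, 4)
#   a = _chunk_slice(t, flat_start=1, flat_end=5, no_batch_dims=2)
#   b = t.reshape(-1, 4)[1:5]
#   # a and b are expected to be equal, but _chunk_slice avoids reshaping all of t first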
|
||||
|
||||
|
||||
def chunk_layer(
|
||||
layer: Callable,
|
||||
inputs: Dict[str, Any],
|
||||
chunk_size: int,
|
||||
no_batch_dims: int,
|
||||
low_mem: bool = False,
|
||||
) -> Any:
|
||||
"""
|
||||
Implements the "chunking" procedure described in section 1.11.8.
|
||||
|
||||
Layer outputs and inputs are assumed to be simple "pytrees,"
|
||||
consisting only of (arbitrarily nested) lists, tuples, and dicts with
|
||||
torch.Tensor leaves.
|
||||
|
||||
Args:
|
||||
layer:
|
||||
The layer to be applied chunk-wise
|
||||
inputs:
|
||||
A (non-nested) dictionary of keyworded inputs. All leaves must
|
||||
be tensors and must share the same batch dimensions.
|
||||
chunk_size:
|
||||
The number of sub-batches per chunk. If multiple batch
|
||||
dimensions are specified, a "sub-batch" is defined as a single
|
||||
indexing of all batch dimensions simultaneously (s.t. the
|
||||
number of sub-batches is the product of the batch dimensions).
|
||||
no_batch_dims:
|
||||
How many of the initial dimensions of each input tensor can
|
||||
be considered batch dimensions.
|
||||
low_mem:
|
||||
Avoids flattening potentially large input tensors. Unnecessary
|
||||
in most cases, and is ever so slightly slower than the default
|
||||
setting.
|
||||
Returns:
|
||||
The reassembled output of the layer on the inputs.
|
||||
"""
|
||||
if not (len(inputs) > 0):
|
||||
raise ValueError("Must provide at least one input")
|
||||
|
||||
initial_dims = [shape[:no_batch_dims] for shape in _fetch_dims(inputs)]
|
||||
orig_batch_dims = tuple([max(s) for s in zip(*initial_dims)])
|
||||
|
||||
def _prep_inputs(t):
|
||||
# TODO: make this more memory efficient; expanding and reshaping the full tensor is wasteful
|
||||
if(not low_mem):
|
||||
if not sum(t.shape[:no_batch_dims]) == no_batch_dims:
|
||||
t = t.expand(orig_batch_dims + t.shape[no_batch_dims:])
|
||||
t = t.reshape(-1, *t.shape[no_batch_dims:])
|
||||
else:
|
||||
t = t.expand(orig_batch_dims + t.shape[no_batch_dims:])
|
||||
return t
|
||||
|
||||
prepped_inputs = tensor_tree_map(_prep_inputs, inputs)
|
||||
|
||||
flat_batch_dim = 1
|
||||
for d in orig_batch_dims:
|
||||
flat_batch_dim *= d
|
||||
|
||||
no_chunks = flat_batch_dim // chunk_size + (
|
||||
flat_batch_dim % chunk_size != 0
|
||||
)
|
||||
|
||||
i = 0
|
||||
out = None
|
||||
for _ in range(no_chunks):
|
||||
# Chunk the input
|
||||
if(not low_mem):
|
||||
select_chunk = (
|
||||
lambda t: t[i : i + chunk_size] if t.shape[0] != 1 else t
|
||||
)
|
||||
else:
|
||||
select_chunk = (
|
||||
partial(
|
||||
_chunk_slice,
|
||||
flat_start=i,
|
||||
flat_end=min(flat_batch_dim, i + chunk_size),
|
||||
no_batch_dims=len(orig_batch_dims)
|
||||
)
|
||||
)
|
||||
|
||||
chunks = tensor_tree_map(select_chunk, prepped_inputs)
|
||||
|
||||
# Run the layer on the chunk
|
||||
output_chunk = layer(**chunks)
|
||||
|
||||
# Allocate space for the output
|
||||
if out is None:
|
||||
allocate = lambda t: t.new_zeros((flat_batch_dim,) + t.shape[1:])
|
||||
out = tensor_tree_map(allocate, output_chunk)
|
||||
|
||||
# Put the chunk in its pre-allocated space
|
||||
out_type = type(output_chunk)
|
||||
if out_type is dict:
|
||||
def assign(d1, d2):
|
||||
for k, v in d1.items():
|
||||
if type(v) is dict:
|
||||
assign(v, d2[k])
|
||||
else:
|
||||
v[i : i + chunk_size] = d2[k]
|
||||
|
||||
assign(out, output_chunk)
|
||||
elif out_type is tuple:
|
||||
for x1, x2 in zip(out, output_chunk):
|
||||
x1[i : i + chunk_size] = x2
|
||||
elif out_type is torch.Tensor:
|
||||
out[i : i + chunk_size] = output_chunk
|
||||
else:
|
||||
raise ValueError("Not supported")
|
||||
|
||||
i += chunk_size
|
||||
|
||||
reshape = lambda t: t.view(orig_batch_dims + t.shape[1:])
|
||||
out = tensor_tree_map(reshape, out)
|
||||
|
||||
return out
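# Hedged usage sketch (illustrative, not part of the original diff): chunk_layer runs a
# layer over sub-batches and stitches the results back together. Assuming a simple
# elementwise layer whose keyword arguments match the input dict keys:
#
#   layer = lambda msa, pair: {"msa": msa + 1, "pair": pair * 2}
#   inputs = {"msa": torch.zeros(4, 8, 16), "pair": torch.zeros(4, 8, 32)}
#   out = chunk_layer(layer, inputs, chunk_size=2, no_batch_dims=2)
#   # out["msa"].shape == (4, 8, 16); out["pair"].shape == (4, 8, 32)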
|
|
@ -0,0 +1,139 @@
|
|||
# Copyright 2021 AlQuraishi Laboratory
|
||||
# Copyright 2021 DeepMind Technologies Limited
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from functools import partialmethod, partial
|
||||
import math
|
||||
from typing import Optional, List
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from .primitives import Linear, LayerNorm, Attention
|
||||
from .tensor_utils import (
|
||||
chunk_layer,
|
||||
permute_final_dims,
|
||||
flatten_final_dims,
|
||||
)
|
||||
|
||||
|
||||
class TriangleAttention(nn.Module):
|
||||
def __init__(
|
||||
self, c_in, c_hidden, no_heads, starting, inf=1e9
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
c_in:
|
||||
Input channel dimension
|
||||
c_hidden:
|
||||
Overall hidden channel dimension (not per-head)
|
||||
no_heads:
|
||||
Number of attention heads
starting:
Whether attention acts along the starting-node axis; if False, the I and J
dimensions are transposed before and after the attention
inf:
Large value used to build the additive mask bias
|
||||
"""
|
||||
super(TriangleAttention, self).__init__()
|
||||
|
||||
self.c_in = c_in
|
||||
self.c_hidden = c_hidden
|
||||
self.no_heads = no_heads
|
||||
self.starting = starting
|
||||
self.inf = inf
|
||||
|
||||
self.layer_norm = LayerNorm(self.c_in)
|
||||
|
||||
self.linear = Linear(c_in, self.no_heads, bias=False, init="normal")
|
||||
|
||||
self.mha = Attention(
|
||||
self.c_in, self.c_in, self.c_in, self.c_hidden, self.no_heads
|
||||
)
|
||||
|
||||
@torch.jit.ignore
|
||||
def _chunk(self,
|
||||
x: torch.Tensor,
|
||||
biases: List[torch.Tensor],
|
||||
chunk_size: int,
|
||||
) -> torch.Tensor:
|
||||
mha_inputs = {
|
||||
"q_x": x,
|
||||
"kv_x": x,
|
||||
"biases": biases,
|
||||
}
|
||||
return chunk_layer(
|
||||
partial(self.mha),
|
||||
mha_inputs,
|
||||
chunk_size=chunk_size,
|
||||
no_batch_dims=len(x.shape[:-2]),
|
||||
)
|
||||
|
||||
def forward(self,
|
||||
x: torch.Tensor,
|
||||
mask: Optional[torch.Tensor] = None,
|
||||
chunk_size: Optional[int] = None
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
x:
|
||||
[*, I, J, C_in] input tensor (e.g. the pair representation)
mask:
[*, I, J] optional input mask
chunk_size:
Optional sub-batch size for chunked attention
|
||||
Returns:
|
||||
[*, I, J, C_in] output tensor
|
||||
"""
|
||||
if mask is None:
|
||||
# [*, I, J]
|
||||
mask = x.new_ones(
|
||||
x.shape[:-1],
|
||||
)
|
||||
|
||||
# Shape annotations assume self.starting. Else, I and J are flipped
|
||||
if not self.starting:
|
||||
x = x.transpose(-2, -3)
|
||||
mask = mask.transpose(-1, -2)
|
||||
|
||||
# [*, I, J, C_in]
|
||||
x = self.layer_norm(x)
|
||||
|
||||
# [*, I, 1, 1, J]
|
||||
mask_bias = (self.inf * (mask - 1))[..., :, None, None, :]
|
||||
|
||||
# [*, H, I, J]
|
||||
triangle_bias = permute_final_dims(self.linear(x), (2, 0, 1))
|
||||
|
||||
# [*, 1, H, I, J]
|
||||
triangle_bias = triangle_bias.unsqueeze(-4)
|
||||
|
||||
biases = [mask_bias, triangle_bias]
|
||||
|
||||
if chunk_size is not None:
|
||||
x = self._chunk(x, biases, chunk_size)
|
||||
else:
|
||||
x = self.mha(q_x=x, kv_x=x, biases=biases)
|
||||
|
||||
if not self.starting:
|
||||
x = x.transpose(-2, -3)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class TriangleAttentionStartingNode(TriangleAttention):
|
||||
"""
|
||||
Implements Algorithm 13.
|
||||
"""
|
||||
|
||||
__init__ = partialmethod(TriangleAttention.__init__, starting=True)
|
||||
|
||||
|
||||
class TriangleAttentionEndingNode(TriangleAttention):
|
||||
"""
|
||||
Implements Algorithm 14.
|
||||
"""
|
||||
|
||||
__init__ = partialmethod(TriangleAttention.__init__, starting=False)
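# Hedged usage sketch (illustrative, not part of the original diff; shapes follow the
# docstrings above, and the Linear/Attention primitives imported from .primitives are
# assumed to behave as documented there):
#
#   attn = TriangleAttentionStartingNode(c_in=128, c_hidden=32, no_heads=4)
#   z = torch.randn(1, 64, 64, 128)   # [*, I, J, C_in] pair representation
#   out = attn(z, chunk_size=4)       # attention computed in chunks of 4 sub-batches
#   # out.shape is expected to be (1, 64, 64, 128)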
|
|
@ -0,0 +1,127 @@
|
|||
# Copyright 2021 AlQuraishi Laboratory
|
||||
# Copyright 2021 DeepMind Technologies Limited
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from functools import partialmethod
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from .primitives import Linear, LayerNorm
|
||||
from .tensor_utils import permute_final_dims
|
||||
|
||||
|
||||
class TriangleMultiplicativeUpdate(nn.Module):
|
||||
"""
|
||||
Implements Algorithms 11 and 12.
|
||||
"""
|
||||
def __init__(self, c_z, c_hidden, _outgoing=True):
|
||||
"""
|
||||
Args:
|
||||
c_z:
|
||||
Input channel dimension
|
||||
c_hidden:
|
||||
Hidden channel dimension
|
||||
"""
|
||||
super(TriangleMultiplicativeUpdate, self).__init__()
|
||||
self.c_z = c_z
|
||||
self.c_hidden = c_hidden
|
||||
self._outgoing = _outgoing
|
||||
|
||||
self.linear_a_p = Linear(self.c_z, self.c_hidden)
|
||||
self.linear_a_g = Linear(self.c_z, self.c_hidden, init="gating")
|
||||
self.linear_b_p = Linear(self.c_z, self.c_hidden)
|
||||
self.linear_b_g = Linear(self.c_z, self.c_hidden, init="gating")
|
||||
self.linear_g = Linear(self.c_z, self.c_z, init="gating")
|
||||
self.linear_z = Linear(self.c_hidden, self.c_z, init="final")
|
||||
|
||||
self.layer_norm_in = LayerNorm(self.c_z)
|
||||
self.layer_norm_out = LayerNorm(self.c_hidden)
|
||||
|
||||
self.sigmoid = nn.Sigmoid()
|
||||
|
||||
def _combine_projections(self,
|
||||
a: torch.Tensor,
|
||||
b: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
raise NotImplementedError("This method needs to be overridden")
|
||||
|
||||
def forward(self,
|
||||
z: torch.Tensor,
|
||||
mask: Optional[torch.Tensor] = None
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
z:
|
||||
[*, N_res, N_res, C_z] input tensor
|
||||
mask:
|
||||
[*, N_res, N_res] input mask
|
||||
Returns:
|
||||
[*, N_res, N_res, C_z] output tensor
|
||||
"""
|
||||
if mask is None:
|
||||
mask = z.new_ones(z.shape[:-1])
|
||||
|
||||
mask = mask.unsqueeze(-1)
|
||||
|
||||
z = self.layer_norm_in(z)
|
||||
a = self.linear_a_p(z) * self.sigmoid(self.linear_a_g(z))
|
||||
a = a * mask
|
||||
b = self.linear_b_p(z) * self.sigmoid(self.linear_b_g(z))
|
||||
b = b * mask
|
||||
x = self._combine_projections(a, b)
|
||||
x = self.layer_norm_out(x)
|
||||
x = self.linear_z(x)
|
||||
g = self.sigmoid(self.linear_g(z))
|
||||
z = x * g
|
||||
|
||||
return z
|
||||
|
||||
|
||||
class TriangleMultiplicationOutgoing(TriangleMultiplicativeUpdate):
|
||||
"""
|
||||
Implements Algorithm 11.
|
||||
"""
|
||||
def _combine_projections(self,
|
||||
a: torch.Tensor, # [*, N_i, N_k, C]
|
||||
b: torch.Tensor, # [*, N_j, N_k, C]
|
||||
):
|
||||
# [*, C, N_i, N_j]
|
||||
p = torch.matmul(
|
||||
permute_final_dims(a, (2, 0, 1)),
|
||||
permute_final_dims(b, (2, 1, 0)),
|
||||
)
|
||||
|
||||
# [*, N_i, N_j, C]
|
||||
return permute_final_dims(p, (1, 2, 0))
|
||||
|
||||
|
||||
class TriangleMultiplicationIncoming(TriangleMultiplicativeUpdate):
|
||||
"""
|
||||
Implements Algorithm 12.
|
||||
"""
|
||||
def _combine_projections(self,
|
||||
a: torch.Tensor, # [*, N_k, N_i, C]
|
||||
b: torch.Tensor, # [*, N_k, N_j, C]
|
||||
):
|
||||
# [*, C, N_i, N_j]
|
||||
p = torch.matmul(
|
||||
permute_final_dims(a, (2, 1, 0)),
|
||||
permute_final_dims(b, (2, 0, 1)),
|
||||
)
|
||||
|
||||
# [*, N_i, N_j, C]
|
||||
return permute_final_dims(p, (1, 2, 0))
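# Illustrative note (not part of the original diff): assuming the shapes annotated in the
# signatures, the two matmuls above contract over the shared N_k dimension and are
# equivalent to einsum expressions:
#
#   outgoing: p[..., i, j, c] = sum_k a[..., i, k, c] * b[..., j, k, c]
#             i.e. torch.einsum("...ikc,...jkc->...ijc", a, b)
#   incoming: p[..., i, j, c] = sum_k a[..., k, i, c] * b[..., k, j, c]
#             i.e. torch.einsum("...kic,...kjc->...ijc", a, b)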
|
||||
|
|
@ -0,0 +1,113 @@
|
|||
from functools import partial
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.fx
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
import colossalai
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.fx import ColoTracer
|
||||
from colossalai.fx._compatibility import is_compatible_with_meta
|
||||
from colossalai.fx.codegen.activation_checkpoint_codegen import CODEGEN_AVAILABLE
|
||||
from colossalai.fx.graph_module import ColoGraphModule
|
||||
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
|
||||
from colossalai.utils import free_port
|
||||
from tests.test_autochunk.evoformer.evoformer import evoformer_base
|
||||
|
||||
if CODEGEN_AVAILABLE and is_compatible_with_meta():
|
||||
from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
|
||||
from colossalai.fx.profiler import MetaTensor
|
||||
|
||||
|
||||
def _test_fwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair):
|
||||
# for memory test
|
||||
# torch.cuda.reset_peak_memory_stats()
|
||||
# now_mem = torch.cuda.memory_allocated() / 1024**2
|
||||
# with torch.no_grad():
|
||||
# node1 = node.clone()
|
||||
# pair1 = pair.clone()
|
||||
# gm(node1, pair1)
|
||||
# new_now_mem = torch.cuda.memory_allocated() / 1024**2
|
||||
# new_max_mem = torch.cuda.max_memory_allocated() / 1024**2
|
||||
# print(
|
||||
# "autochunk now mem:%.2f max mem:%.2f"
|
||||
# % (new_now_mem - now_mem, new_max_mem - now_mem)
|
||||
# )
|
||||
|
||||
# test forward
|
||||
with torch.no_grad():
|
||||
non_fx_out = model(node, pair)
|
||||
fx_out = gm(node, pair)
|
||||
|
||||
assert torch.allclose(non_fx_out[0], fx_out[0],
|
||||
atol=1e-4), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(
|
||||
torch.abs(non_fx_out[0] - fx_out[0]))
|
||||
assert torch.allclose(non_fx_out[1], fx_out[1],
|
||||
atol=1e-4), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(
|
||||
torch.abs(non_fx_out[1] - fx_out[1]))
|
||||
|
||||
|
||||
def _test_autochunk_codegen(rank, msa_len, pair_len, max_memory):
|
||||
# launch colossalai
|
||||
colossalai.launch(
|
||||
config={},
|
||||
rank=rank,
|
||||
world_size=1,
|
||||
host="localhost",
|
||||
port=free_port(),
|
||||
backend="nccl",
|
||||
)
|
||||
|
||||
# build model and input
|
||||
model = evoformer_base().cuda()
|
||||
node = torch.randn(1, msa_len, pair_len, 256).cuda()
|
||||
pair = torch.randn(1, pair_len, pair_len, 128).cuda()
|
||||
|
||||
# trace the module and replace codegen
|
||||
graph = ColoTracer().trace(
|
||||
model,
|
||||
meta_args={
|
||||
"node": node.to(torch.device("meta")),
|
||||
"pair": pair.to(torch.device("meta")),
|
||||
},
|
||||
)
|
||||
gm_prop = torch.fx.symbolic_trace(model) # must use symbolic_trace
|
||||
interp = MetaInfoProp(gm_prop)
|
||||
interp.propagate(MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0"))
|
||||
|
||||
# run MetaInfoProp again on the traced graph module to attach meta info (not strictly necessary)
|
||||
gm = torch.fx.GraphModule(model, graph)
|
||||
interp = MetaInfoProp(gm)
|
||||
interp.propagate(MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0"))
|
||||
|
||||
codegen = AutoChunkCodeGen(gm_prop, max_memory=max_memory)
|
||||
graph.set_codegen(codegen)
|
||||
gm = ColoGraphModule(model, graph)
|
||||
gm.recompile()
|
||||
|
||||
# assert we have inserted chunk
|
||||
code = graph.python_code("self").src
|
||||
assert "chunk_size" in code
|
||||
# print(code)
|
||||
|
||||
_test_fwd(model, gm, node, pair)
|
||||
gpc.destroy()
|
||||
|
||||
|
||||
@pytest.mark.skipif(not (CODEGEN_AVAILABLE and is_compatible_with_meta()), reason='torch version is lower than 1.12.0')
|
||||
@pytest.mark.parametrize("max_memory", [None, 20, 25, 30])
|
||||
@pytest.mark.parametrize("msa_len", [32])
|
||||
@pytest.mark.parametrize("pair_len", [64])
|
||||
def test_autochunk_codegen(msa_len, pair_len, max_memory):
|
||||
run_func = partial(
|
||||
_test_autochunk_codegen,
|
||||
msa_len=msa_len,
|
||||
pair_len=pair_len,
|
||||
max_memory=max_memory,
|
||||
)
|
||||
mp.spawn(run_func, nprocs=1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
_test_autochunk_codegen(0, 32, 64, 25)
|
|
@ -0,0 +1,102 @@
|
|||
from functools import partial
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.fx
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
import colossalai
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.fx._compatibility import is_compatible_with_meta
|
||||
from colossalai.fx.codegen.activation_checkpoint_codegen import CODEGEN_AVAILABLE
|
||||
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
|
||||
from colossalai.utils import free_port
|
||||
from tests.test_autochunk.evoformer.evoformer import evoformer_base
|
||||
|
||||
if CODEGEN_AVAILABLE and is_compatible_with_meta():
|
||||
from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
|
||||
from colossalai.fx.profiler import MetaTensor
|
||||
|
||||
|
||||
def assert_chunk_infos(chunk_infos, max_memory, msa_len, pair_len):
|
||||
found_regions = [i["region"] for i in chunk_infos]
|
||||
|
||||
if msa_len == 32 and pair_len == 64:
|
||||
if max_memory is None:
|
||||
target_regions = [(142, 154), (366, 373), (233, 283), (301, 351), (127, 134), (204, 228), (167, 191),
|
||||
(161, 166), (198, 203), (6, 69)]
|
||||
elif max_memory == 20:
|
||||
target_regions = [(142, 154), (369, 373), (233, 269), (301, 351)]
|
||||
elif max_memory == 25:
|
||||
target_regions = [(144, 154), (369, 370)]
|
||||
elif max_memory == 30:
|
||||
target_regions = [(144, 154)]
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
assert len(found_regions) == len(
|
||||
target_regions), "len of found regions %s doesn't equal len of target regions %s" % (
|
||||
str(found_regions),
|
||||
str(target_regions),
|
||||
)
|
||||
for region in target_regions:
|
||||
assert (region in found_regions), "region:%s not in found regions for msa:%d, pair:%d, maxmem:%d" % (
|
||||
str(region),
|
||||
msa_len,
|
||||
pair_len,
|
||||
max_memory,
|
||||
)
|
||||
for region in found_regions:
|
||||
assert (region in target_regions), "region:%s should not be found for msa:%d, pair:%d, maxmem:%d" % (
|
||||
str(region),
|
||||
msa_len,
|
||||
pair_len,
|
||||
max_memory,
|
||||
)
|
||||
|
||||
|
||||
def _test_autochunk_search(rank, msa_len, pair_len, max_memory):
|
||||
# launch colossalai
|
||||
colossalai.launch(
|
||||
config={},
|
||||
rank=rank,
|
||||
world_size=1,
|
||||
host="localhost",
|
||||
port=free_port(),
|
||||
backend="nccl",
|
||||
)
|
||||
|
||||
# build model and input
|
||||
model = evoformer_base().cuda()
|
||||
node = torch.randn(1, msa_len, pair_len, 256).cuda()
|
||||
pair = torch.randn(1, pair_len, pair_len, 128).cuda()
|
||||
|
||||
gm_prop = torch.fx.symbolic_trace(model) # must use symbolic_trace
|
||||
interp = MetaInfoProp(gm_prop)
|
||||
interp.propagate(MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0"))
|
||||
|
||||
codegen = AutoChunkCodeGen(gm_prop, max_memory=max_memory)
|
||||
chunk_infos = codegen.chunk_infos
|
||||
assert_chunk_infos(chunk_infos, max_memory, msa_len, pair_len)
|
||||
|
||||
gpc.destroy()
|
||||
|
||||
|
||||
@pytest.mark.skipif(not (CODEGEN_AVAILABLE and is_compatible_with_meta()), reason="torch version is lower than 1.12.0")
|
||||
@pytest.mark.parametrize("max_memory", [None, 20, 25, 30])
|
||||
@pytest.mark.parametrize("msa_len", [32])
|
||||
@pytest.mark.parametrize("pair_len", [64])
|
||||
def test_autochunk_search(msa_len, pair_len, max_memory):
|
||||
run_func = partial(
|
||||
_test_autochunk_search,
|
||||
msa_len=msa_len,
|
||||
pair_len=pair_len,
|
||||
max_memory=max_memory,
|
||||
)
|
||||
mp.spawn(run_func, nprocs=1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
_test_autochunk_search(0, 32, 64, 20)
|