mirror of https://github.com/hpcaitech/ColossalAI
[autochunk] add autochunk feature
commit 93f62dd152

@ -0,0 +1,593 @@
from typing import Any, Dict, Iterable, List, Tuple

import torch

import colossalai
from colossalai.fx.codegen.activation_checkpoint_codegen import CODEGEN_AVAILABLE

if CODEGEN_AVAILABLE:
    from torch.fx.graph import (
        CodeGen,
        PythonCode,
        _custom_builtins,
        _CustomBuiltin,
        _format_target,
        _is_from_torch,
        _Namespace,
        _origin_type_map,
        inplace_methods,
        magic_methods,
    )

    from torch.fx.node import Argument, Node, _get_qualified_name, _type_repr, map_arg

from .search_chunk import SearchChunk
from .utils import delete_free_var_from_last_use, find_idx_by_name, get_node_shape


def _gen_chunk_slice_dim(chunk_dim: int, chunk_indice_name: str, shape: List) -> str:
    """
    Generate chunk slice string, e.g. [:, :, chunk_idx_name:chunk_idx_name + chunk_size, :]

    Args:
        chunk_dim (int)
        chunk_indice_name (str): chunk indice name
        shape (List): node shape

    Returns:
        new_shape (str): the generated slice string
    """
    new_shape = "["
    for idx, _ in enumerate(shape):
        if idx == chunk_dim:
            new_shape += "%s:%s + chunk_size" % (chunk_indice_name, chunk_indice_name)
        else:
            new_shape += ":"
        new_shape += ", "
    new_shape = new_shape[:-2] + "]"
    return new_shape
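# For reference, an illustration of the slice string this helper emits
# (hypothetical arguments, not from the diff):
#   _gen_chunk_slice_dim(chunk_dim=2, chunk_indice_name="chunk_idx", shape=[1, 8, 64, 32])
#   -> "[:, :, chunk_idx:chunk_idx + chunk_size, :]"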

def _gen_loop_start(
    chunk_input: List[Node], chunk_output: Node, chunk_output_dim: int, chunk_size=2
) -> str:
    """
    Generate chunk loop start

    e.g. chunk_result = torch.empty([100, 100], dtype=input_node.dtype, device=input_node.device)
         chunk_size = 32
         for chunk_idx in range(0, 100, 32):
             ......

    Args:
        chunk_input (List[Node]): chunk input node
        chunk_output (Node): chunk output node
        chunk_output_dim (int): chunk output node chunk dim
        chunk_size (int): chunk size. Defaults to 2.

    Returns:
        context (str): generated str
    """
    input_node = chunk_input[0]
    out_shape = get_node_shape(chunk_output)
    out_str = str(list(out_shape))
    context = (
        "chunk_result = torch.empty(%s, dtype=%s.dtype, device=%s.device); chunk_size = %d\nfor chunk_idx in range"
        % (out_str, input_node.name, input_node.name, chunk_size)
    )
    context += "(0, %d, chunk_size):\n" % (out_shape[chunk_output_dim])
    return context


def _gen_loop_end(
    chunk_inputs: List[Node],
    chunk_non_compute_inputs: List[Node],
    chunk_outputs: Node,
    chunk_outputs_dim: int,
    node_list: List[Node],
) -> str:
    """
    Generate chunk loop end

    e.g. chunk_result[chunk_idx:chunk_idx + chunk_size] = output_node
         output_node = chunk_result; xx = None; xx = None

    Args:
        chunk_inputs (List[Node]): chunk input node
        chunk_non_compute_inputs (List[Node]): input node without chunk
        chunk_outputs (Node): chunk output node
        chunk_outputs_dim (int): chunk output node chunk dim
        node_list (List)

    Returns:
        context (str): generated str
    """
    chunk_outputs_name = chunk_outputs.name
    chunk_outputs_idx = find_idx_by_name(chunk_outputs_name, node_list)
    chunk_output_shape = chunk_outputs.meta["tensor_meta"].shape
    chunk_slice = _gen_chunk_slice_dim(
        chunk_outputs_dim, "chunk_idx", chunk_output_shape
    )
    context = "    chunk_result%s = %s; %s = None\n" % (
        chunk_slice,
        chunk_outputs_name,
        chunk_outputs_name,
    )
    context += (
        chunk_outputs_name + " = chunk_result; chunk_result = None; chunk_size = None"
    )

    # determine if it is the last use for a chunk input
    for chunk_input in chunk_inputs + chunk_non_compute_inputs:
        if all(
            [
                find_idx_by_name(user.name, node_list) <= chunk_outputs_idx
                for user in chunk_input.users.keys()
            ]
        ):
            context += "; %s = None" % chunk_input.name

    context += "\n"
    return context
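# Taken together, the two helpers above wrap a chunk region of the generated
# forward into a loop that writes slices into a preallocated buffer. With
# hypothetical node names, chunk dim 0 and chunk size 32, the emitted text is
# roughly:
#   chunk_result = torch.empty([100, 100], dtype=inp.dtype, device=inp.device); chunk_size = 32
#   for chunk_idx in range(0, 100, chunk_size):
#       ...node code with inputs sliced by [chunk_idx:chunk_idx + chunk_size, :]...
#       chunk_result[chunk_idx:chunk_idx + chunk_size, :] = out_node; out_node = None
#   out_node = chunk_result; chunk_result = None; chunk_size = None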

def _replace_name(context: str, name_from: str, name_to: str) -> str:
    """
    replace node name
    """
    patterns = [(" ", " "), (" ", "."), (" ", ","), ("(", ")"), ("(", ","), (" ", ")")]
    for p in patterns:
        source = p[0] + name_from + p[1]
        target = p[0] + name_to + p[1]
        if source in context:
            context = context.replace(source, target)
    return context


def _replace_reshape_size(context: str, node_name: str, reshape_size_dict: Dict) -> str:
    """
    replace reshape size, some may have changed due to chunk
    """
    if node_name not in reshape_size_dict:
        return context
    for size_name, size_value in reshape_size_dict[node_name].items():
        context = context.replace(size_name, size_value)
    return context


def _replace_ones_like(
    search_chunk: SearchChunk,
    chunk_infos: List[Dict],
    region_idx: int,
    node_idx: int,
    node: Node,
    body: List[str],
) -> List[str]:
    """
    add chunk slice for new-tensor ops such as ones_like
    """
    if "ones_like" in node.name:
        meta_node = search_chunk.trace_indice.node_list[node_idx]
        chunk_dim = chunk_infos[region_idx]["node_chunk_dim"][meta_node]["chunk_dim"]
        if get_node_shape(meta_node)[chunk_dim] != 1:
            source_node = meta_node.args[0].args[0]
            if (
                source_node not in chunk_infos[region_idx]["node_chunk_dim"]
                or chunk_infos[region_idx]["node_chunk_dim"][source_node]["chunk_dim"]
                is None
            ):
                chunk_slice = _gen_chunk_slice_dim(
                    chunk_dim, "chunk_idx", get_node_shape(node)
                )
                body[-1] = _replace_name(
                    body[-1], node.args[0].name, node.args[0].name + chunk_slice
                )
    return body


def _replace_input_node(
    chunk_inputs: List[Node],
    region_idx: int,
    chunk_inputs_dim: Dict,
    node_idx: int,
    body: List[str],
) -> List[str]:
    """
    add chunk slice for input nodes
    """
    for input_node_idx, input_node in enumerate(chunk_inputs[region_idx]):
        for idx, dim in chunk_inputs_dim[region_idx][input_node_idx].items():
            if idx == node_idx:
                chunk_slice = _gen_chunk_slice_dim(
                    dim[0], "chunk_idx", get_node_shape(input_node)
                )
                body[-1] = _replace_name(
                    body[-1], input_node.name, input_node.name + chunk_slice
                )
    return body


def emit_code_with_chunk(
    body: List[str],
    nodes: Iterable[Node],
    emit_node_func,
    delete_unused_value_func,
    search_chunk: SearchChunk,
    chunk_infos: List,
):
    """
    Emit code with chunk according to chunk_infos.

    It will generate a for loop in chunk regions, and
    replace inputs and outputs of regions with chunked variables.

    Args:
        body: forward code
        nodes: graph.nodes
        emit_node_func: function to emit node
        delete_unused_value_func: function to remove the unused value
        search_chunk: the class to search all chunks
        chunk_infos: store all information about all chunks.
    """
    node_list = list(nodes)

    # chunk region
    chunk_starts = [i["region"][0] for i in chunk_infos]
    chunk_ends = [i["region"][1] for i in chunk_infos]

    # chunk inputs
    chunk_inputs = [i["inputs"] for i in chunk_infos]  # input with chunk
    chunk_inputs_non_chunk = [
        i["inputs_non_chunk"] for i in chunk_infos
    ]  # input without chunk
    chunk_inputs_dim = [i["inputs_dim"] for i in chunk_infos]  # input chunk dim
    chunk_inputs_names = [j.name for i in chunk_inputs for j in i] + [
        j.name for i in chunk_inputs_non_chunk for j in i
    ]

    # chunk outputs
    chunk_outputs = [i["outputs"][0] for i in chunk_infos]
    chunk_outputs_dim = [i["outputs_dim"] for i in chunk_infos]

    node_list = search_chunk.reorder_graph.reorder_node_list(node_list)
    node_idx = 0
    region_idx = 0
    within_chunk_region = False

    while node_idx < len(node_list):
        node = node_list[node_idx]

        # if this is a chunk start, generate the for-loop start
        if node_idx in chunk_starts:
            within_chunk_region = True
            region_idx = chunk_starts.index(node_idx)
            body.append(
                _gen_loop_start(
                    chunk_inputs[region_idx],
                    chunk_outputs[region_idx],
                    chunk_outputs_dim[region_idx],
                    chunk_infos[region_idx]["chunk_size"],
                )
            )

        if within_chunk_region:
            emit_node_func(node, body)
            # replace input var with chunk var
            body = _replace_input_node(
                chunk_inputs, region_idx, chunk_inputs_dim, node_idx, body
            )
            # ones_like
            body = _replace_ones_like(
                search_chunk, chunk_infos, region_idx, node_idx, node, body
            )
            # reassign reshape size
            body[-1] = _replace_reshape_size(
                body[-1], node.name, chunk_infos[region_idx]["reshape_size"]
            )
            body[-1] = "    " + body[-1]
            delete_unused_value_func(node, body, chunk_inputs_names)
        else:
            emit_node_func(node, body)
            if node_idx not in chunk_inputs:
                delete_unused_value_func(node, body, chunk_inputs_names)

        # generate chunk region end
        if node_idx in chunk_ends:
            body.append(
                _gen_loop_end(
                    chunk_inputs[region_idx],
                    chunk_inputs_non_chunk[region_idx],
                    chunk_outputs[region_idx],
                    chunk_outputs_dim[region_idx],
                    node_list,
                )
            )
            within_chunk_region = False

        node_idx += 1

if CODEGEN_AVAILABLE:

    class AutoChunkCodeGen(CodeGen):

        def __init__(self, meta_graph, max_memory=None, print_mem=False):
            super().__init__()
            self.meta_graph = meta_graph
            self.max_memory = max_memory
            self.meta_node = list(meta_graph.graph.nodes)
            # find the chunk regions
            self.search_chunk = SearchChunk(meta_graph, max_memory, print_mem)
            self.chunk_infos = self.search_chunk.search_region()

        def _gen_python_code(
            self, nodes, root_module: str, namespace: _Namespace
        ) -> PythonCode:
            free_vars: List[str] = []
            body: List[str] = []
            globals_: Dict[str, Any] = {}
            wrapped_fns: Dict[str, None] = {}

            # Wrap string in list to pass by reference
            maybe_return_annotation: List[str] = [""]

            def add_global(name_hint: str, obj: Any):
                """Add an obj to be tracked as a global.

                We call this for names that reference objects external to the
                Graph, like functions or types.

                Returns: the global name that should be used to reference 'obj' in generated source.
                """
                if (
                    _is_from_torch(obj) and obj != torch.device
                ):  # to support registering torch.device
                    # HACK: workaround for how torch custom ops are registered. We
                    # can't import them like normal modules so they must retain their
                    # fully qualified name.
                    return _get_qualified_name(obj)

                # normalize the name hint to get a proper identifier
                global_name = namespace.create_name(name_hint, obj)

                if global_name in globals_:
                    assert globals_[global_name] is obj
                    return global_name
                globals_[global_name] = obj
                return global_name

            # set _custom_builtins here so that we needn't import colossalai in forward
            _custom_builtins["colossalai"] = _CustomBuiltin(
                "import colossalai", colossalai
            )

            # Pre-fill the globals table with registered builtins.
            for name, (_, obj) in _custom_builtins.items():
                add_global(name, obj)

            def type_repr(o: Any):
                if o == ():
                    # Empty tuple is used for empty tuple type annotation Tuple[()]
                    return "()"

                typename = _type_repr(o)

                if hasattr(o, "__origin__"):
                    # This is a generic type, e.g. typing.List[torch.Tensor]
                    origin_type = _origin_type_map.get(o.__origin__, o.__origin__)
                    origin_typename = add_global(_type_repr(origin_type), origin_type)

                    if hasattr(o, "__args__"):
                        # Assign global names for each of the inner type variables.
                        args = [type_repr(arg) for arg in o.__args__]

                        if len(args) == 0:
                            # Bare type, such as `typing.Tuple` with no subscript
                            # This code-path used in Python < 3.9
                            return origin_typename

                        return f'{origin_typename}[{",".join(args)}]'
                    else:
                        # Bare type, such as `typing.Tuple` with no subscript
                        # This code-path used in Python 3.9+
                        return origin_typename

                # Common case: this is a regular module name like 'foo.bar.baz'
                return add_global(typename, o)

            def _format_args(
                args: Tuple[Argument, ...], kwargs: Dict[str, Argument]
            ) -> str:
                def _get_repr(arg):
                    # Handle NamedTuples (if it has `_fields`) via add_global.
                    if isinstance(arg, tuple) and hasattr(arg, "_fields"):
                        qualified_name = _get_qualified_name(type(arg))
                        global_name = add_global(qualified_name, type(arg))
                        return f"{global_name}{repr(tuple(arg))}"
                    return repr(arg)

                args_s = ", ".join(_get_repr(a) for a in args)
                kwargs_s = ", ".join(f"{k} = {_get_repr(v)}" for k, v in kwargs.items())
                if args_s and kwargs_s:
                    return f"{args_s}, {kwargs_s}"
                return args_s or kwargs_s

            # Run through reverse nodes and record the first instance of a use
            # of a given node. This represents the *last* use of the node in the
            # execution order of the program, which we will use to free unused
            # values
            node_to_last_use: Dict[Node, Node] = {}
            user_to_last_uses: Dict[Node, List[Node]] = {}

            def register_last_uses(n: Node, user: Node):
                if n not in node_to_last_use:
                    node_to_last_use[n] = user
                    user_to_last_uses.setdefault(user, []).append(n)

            for node in reversed(nodes):
                map_arg(node.args, lambda n: register_last_uses(n, node))
                map_arg(node.kwargs, lambda n: register_last_uses(n, node))

            delete_free_var_from_last_use(user_to_last_uses)

            # NOTE: we add a variable to distinguish body and ckpt_func
            def delete_unused_values(user: Node, body, to_keep=[]):
                """
                Delete values after their last use. This ensures that values that are
                not used in the remainder of the code are freed and the memory usage
                of the code is optimal.
                """
                if user.op == "placeholder":
                    return
                if user.op == "output":
                    body.append("\n")
                    return
                nodes_to_delete = user_to_last_uses.get(user, [])
                nodes_to_delete = [i for i in nodes_to_delete if i.name not in to_keep]
                if len(nodes_to_delete):
                    to_delete_str = " = ".join(
                        [repr(n) for n in nodes_to_delete] + ["None"]
                    )
                    body.append(f"; {to_delete_str}\n")
                else:
                    body.append("\n")

            # NOTE: we add a variable to distinguish body and ckpt_func
            def emit_node(node: Node, body):
                maybe_type_annotation = (
                    "" if node.type is None else f" : {type_repr(node.type)}"
                )
                if node.op == "placeholder":
                    assert isinstance(node.target, str)
                    maybe_default_arg = (
                        "" if not node.args else f" = {repr(node.args[0])}"
                    )
                    free_vars.append(
                        f"{node.target}{maybe_type_annotation}{maybe_default_arg}"
                    )
                    raw_name = node.target.replace("*", "")
                    if raw_name != repr(node):
                        body.append(f"{repr(node)} = {raw_name}\n")
                    return
                elif node.op == "call_method":
                    assert isinstance(node.target, str)
                    body.append(
                        f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.target)}"
                        f"({_format_args(node.args[1:], node.kwargs)})"
                    )
                    return
                elif node.op == "call_function":
                    assert callable(node.target)
                    # pretty print operators
                    if (
                        node.target.__module__ == "_operator"
                        and node.target.__name__ in magic_methods
                    ):
                        assert isinstance(node.args, tuple)
                        body.append(
                            f"{repr(node)}{maybe_type_annotation} = "
                            f"{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}"
                        )
                        return

                    # pretty print inplace operators; required for jit.script to work properly
                    # not currently supported in normal FX graphs, but generated by torchdynamo
                    if (
                        node.target.__module__ == "_operator"
                        and node.target.__name__ in inplace_methods
                    ):
                        body.append(
                            f"{inplace_methods[node.target.__name__].format(*(repr(a) for a in node.args))}; "
                            f"{repr(node)}{maybe_type_annotation} = {repr(node.args[0])}"
                        )
                        return

                    qualified_name = _get_qualified_name(node.target)
                    global_name = add_global(qualified_name, node.target)
                    # special case for getattr: node.args could be 2-argument or 3-argument
                    # 2-argument: attribute access; 3-argument: fall through to attrib function call with default value
                    if (
                        global_name == "getattr"
                        and isinstance(node.args, tuple)
                        and isinstance(node.args[1], str)
                        and node.args[1].isidentifier()
                        and len(node.args) == 2
                    ):
                        body.append(
                            f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}"
                        )
                        return
                    body.append(
                        f"{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})"
                    )
                    if node.meta.get("is_wrapped", False):
                        wrapped_fns.setdefault(global_name)
                    return
                elif node.op == "call_module":
                    assert isinstance(node.target, str)
                    body.append(
                        f"{repr(node)}{maybe_type_annotation} = "
                        f"{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})"
                    )
                    return
                elif node.op == "get_attr":
                    assert isinstance(node.target, str)
                    body.append(
                        f"{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}"
                    )
                    return
                elif node.op == "output":
                    if node.type is not None:
                        maybe_return_annotation[0] = f" -> {type_repr(node.type)}"
                    body.append(self.generate_output(node.args[0]))
                    return
                raise NotImplementedError(f"node: {node.op} {node.target}")

            # Modified for activation checkpointing
            ckpt_func = []

            # if any node has a list of labels for activation_checkpoint, we
            # will use nested type of activation checkpoint codegen
            emit_code_with_chunk(
                body,
                nodes,
                emit_node,
                delete_unused_values,
                self.search_chunk,
                self.chunk_infos,
            )

            if len(body) == 0:
                # If the Graph has no non-placeholder nodes, no lines for the body
                # have been emitted. To continue to have valid Python code, emit a
                # single pass statement
                body.append("pass\n")

            if len(wrapped_fns) > 0:
                wrap_name = add_global("wrap", torch.fx.wrap)
                wrap_stmts = "\n".join(
                    [f'{wrap_name}("{name}")' for name in wrapped_fns]
                )
            else:
                wrap_stmts = ""

            if self._body_transformer:
                body = self._body_transformer(body)

            for name, value in self.additional_globals():
                add_global(name, value)

            # as we need colossalai.utils.checkpoint, we need to import colossalai
            # in forward function
            prologue = self.gen_fn_def(free_vars, maybe_return_annotation[0])
            prologue = "".join(ckpt_func) + prologue
            prologue = prologue

            code = "".join(body)
            code = "\n".join("    " + line for line in code.split("\n"))
            fn_code = f"""
{wrap_stmts}

{prologue}
{code}"""
            # print(fn_code)
            return PythonCode(fn_code, globals_)
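A minimal usage sketch of AutoChunkCodeGen for orientation (illustrative, not part of this commit): the meta graph is assumed to come from ColossalAI's tracer with meta-tensor propagation so that nodes carry tensor_meta / fwd_out metadata; set_codegen, GraphModule and recompile are standard torch.fx APIs, while the rest of the wiring is an assumption.

    import torch
    from torch.fx import GraphModule

    def build_autochunk_module(model: torch.nn.Module, meta_graph, max_memory_mb=None):
        # meta_graph: FX-traced module whose nodes carry tensor_meta / fwd_out metadata
        codegen = AutoChunkCodeGen(meta_graph, max_memory=max_memory_mb, print_mem=False)
        graph = meta_graph.graph
        graph.set_codegen(codegen)    # forward() is now emitted with chunk loops
        gm = GraphModule(model, graph)
        gm.recompile()
        return gm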
@ -0,0 +1,328 @@

import copy
from typing import Any, Callable, Dict, Iterable, List, Tuple

import torch
from torch.fx.node import Node, map_arg

from colossalai.fx.profiler import activation_size, parameter_size

from .utils import (
    delete_free_var_from_last_use,
    find_idx_by_name,
    get_node_shape,
    is_non_compute_node_except_placeholder,
)


class EstimateMemory(object):
    """
    Estimate memory with chunk
    """

    def __init__(self) -> None:
        pass

    def _get_meta_node_size(self, x):
        x = x.meta["tensor_meta"]
        x = x.numel * torch.tensor([], dtype=x.dtype).element_size()
        return x

    def _get_output_node(self, n):
        out_size = activation_size(n.meta["fwd_out"])
        out_node = [n.name] if out_size > 0 else []
        return out_size, out_node

    def _get_output_node_size(self, n):
        return self._get_output_node(n)[0]

    def _add_active_node(self, n, active_list):
        new_active = self._get_output_node(n)[1]
        if n.op == "placeholder":
            new_active.append(n.name)
        for i in new_active:
            if i not in active_list:
                active_list.append(i)

    def _get_delete_node(self, user, user_to_last_uses, to_keep=None):
        delete_size = 0
        delete_node = []
        if user.op not in ("output",):
            nodes_to_delete = user_to_last_uses.get(user, [])
            if to_keep is not None:
                keep_list = []
                for n in nodes_to_delete:
                    if n.name in to_keep:
                        keep_list.append(n)
                for n in keep_list:
                    if n in nodes_to_delete:
                        nodes_to_delete.remove(n)
            if len(nodes_to_delete):
                out_node = [self._get_output_node(i) for i in nodes_to_delete]
                delete_size = sum([i[0] for i in out_node])
                for i in range(len(out_node)):
                    if out_node[i][0] > 0:
                        delete_node.append(out_node[i][1][0])
                    elif nodes_to_delete[i].op == "placeholder":
                        delete_node.append(nodes_to_delete[i].name)
                    # elif any(j in nodes_to_delete[i].name for j in ['transpose', 'permute', 'view']):
                    #     delete_node.append(nodes_to_delete[i].name)
        return delete_size, delete_node

    def _get_delete_node_size(self, user, user_to_last_uses, to_keep):
        return self._get_delete_node(user, user_to_last_uses, to_keep)[0]

    def _remove_deactive_node(self, user, user_to_last_uses, active_list):
        delete_node = self._get_delete_node(user, user_to_last_uses)[1]
        for i in delete_node:
            if i in active_list:
                active_list.remove(i)

    def _get_chunk_inputs_size(
        self, chunk_inputs, chunk_inputs_non_chunk, node_list, chunk_end_idx
    ):
        nodes_to_delete = []
        for chunk_input in chunk_inputs + chunk_inputs_non_chunk:
            chunk_input_users = chunk_input.users.keys()
            chunk_input_users_idx = [
                find_idx_by_name(i.name, node_list) for i in chunk_input_users
            ]
            if all(i <= chunk_end_idx for i in chunk_input_users_idx):
                if chunk_input not in nodes_to_delete:
                    nodes_to_delete.append(chunk_input)
        out_node = [self._get_output_node(i) for i in nodes_to_delete]
        delete_size = sum([i[0] for i in out_node])
        return delete_size

    def _get_last_usr(self, nodes):
        node_to_last_use: Dict[Node, Node] = {}
        user_to_last_uses: Dict[Node, List[Node]] = {}

        def register_last_uses(n: Node, user: Node):
            if n not in node_to_last_use:
                node_to_last_use[n] = user
                user_to_last_uses.setdefault(user, []).append(n)

        for node in reversed(nodes):
            map_arg(node.args, lambda n: register_last_uses(n, node))
            map_arg(node.kwargs, lambda n: register_last_uses(n, node))
        return user_to_last_uses

    def _get_contiguous_memory(self, node, not_contiguous_list, delete=False):
        mem = 0
        not_contiguous_ops = ["permute"]
        inherit_contiguous_ops = ["transpose", "view"]

        if node.op == "call_function" and any(
            n in node.name for n in ["matmul", "reshape"]
        ):
            for n in node.args:
                if n in not_contiguous_list:
                    # matmul won't change origin tensor, but create a tmp copy
                    mem += self._get_output_node_size(n)
        elif node.op == "call_module":
            for n in node.args:
                if n in not_contiguous_list:
                    # module will just make origin tensor to contiguous
                    if delete:
                        not_contiguous_list.remove(n)
        elif node.op == "call_method" and any(
            i in node.name for i in not_contiguous_ops
        ):
            if node not in not_contiguous_list:
                not_contiguous_list.append(node)
        return mem

    def _get_chunk_ratio(self, node, chunk_node_dim, chunk_size):
        if node not in chunk_node_dim:
            return 1.0
        node_shape = get_node_shape(node)
        chunk_dim = chunk_node_dim[node]["chunk_dim"]
        if chunk_dim is None:
            return 1.0
        else:
            return float(chunk_size) / node_shape[chunk_dim]
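    # A worked example of the ratio (hypothetical numbers):
    # get_node_shape(node) == [2, 512, 64], chunk_dim == 1, chunk_size == 128
    # -> chunk_ratio == 128 / 512 == 0.25, so inside the chunk loop this node's
    #    activation is counted at a quarter of its full size.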

    def _get_chunk_delete_node_size(
        self, user, user_to_last_uses, chunk_ratio, chunk_inputs_names
    ):
        # if any(j in user.name for j in ['transpose', 'permute', 'view']):
        #     return 0
        if user.op in ("placeholder", "output"):
            return 0
        nodes_to_delete = user_to_last_uses.get(user, [])
        delete_size = 0
        for n in nodes_to_delete:
            if n.name in chunk_inputs_names:
                continue
            delete_size += self._get_output_node_size(n) * chunk_ratio
        return delete_size

    def _print_mem_log(self, log, nodes, title=None):
        if title:
            print(title)
        for idx, (l, n) in enumerate(zip(log, nodes)):
            print("%s:%.2f \t" % (n.name, l), end="")
            if (idx + 1) % 3 == 0:
                print("")
        print("\n")

    def _print_compute_op_mem_log(self, log, nodes, title=None):
        if title:
            print(title)
        for idx, (l, n) in enumerate(zip(log, nodes)):
            if n.op in ["placeholder", "get_attr", "output"]:
                continue
            if any(i in n.name for i in ["getitem", "getattr"]):
                continue
            print("%s:%.2f \t" % (n.name, l), end="")
            if (idx + 1) % 3 == 0:
                print("")
        print("\n")

    def estimate_chunk_inference_mem(
        self,
        node_list: List,
        chunk_infos=None,
        print_mem=False,
    ):
        """
        Estimate inference memory with chunk

        Args:
            node_list (List): node list of the traced graph
            chunk_infos (Dict): Chunk information. Defaults to None.
            print_mem (bool): Whether to print peak memory of every node. Defaults to False.

        Returns:
            act_memory_peak_log (List): peak memory of every node
            act_memory_after_node_log (List): memory after executing every node
            active_node_list_log (List): active nodes of every node. active nodes refer to
                nodes generated but not deleted.
        """
        act_memory = 0.0
        act_memory_peak_log = []
        act_memory_after_node_log = []
        active_node_list = []
        active_node_list_log = []
        not_contiguous_list = []
        user_to_last_uses = self._get_last_usr(node_list)
        user_to_last_uses_no_free_var = self._get_last_usr(node_list)
        delete_free_var_from_last_use(user_to_last_uses_no_free_var)

        use_chunk = True if chunk_infos is not None else False
        chunk_within = False
        chunk_region_idx = None
        chunk_ratio = 1  # use it to estimate chunk mem
        chunk_inputs_names = []

        if use_chunk:
            chunk_regions = [i["region"] for i in chunk_infos]
            chunk_starts = [i[0] for i in chunk_regions]
            chunk_ends = [i[1] for i in chunk_regions]
            chunk_inputs = [i["inputs"] for i in chunk_infos]
            chunk_inputs_non_chunk = [i["inputs_non_chunk"] for i in chunk_infos]
            chunk_inputs_names = [j.name for i in chunk_inputs for j in i] + [
                j.name for i in chunk_inputs_non_chunk for j in i
            ]
            chunk_outputs = [i["outputs"][0] for i in chunk_infos]
            chunk_node_dim = [i["node_chunk_dim"] for i in chunk_infos]
            chunk_sizes = [
                i["chunk_size"] if "chunk_size" in i else 1 for i in chunk_infos
            ]

        for idx, node in enumerate(node_list):
            # if node in chunk start nodes, change chunk ratio and add chunk_tensor
            if use_chunk and idx in chunk_starts:
                chunk_within = True
                chunk_region_idx = chunk_starts.index(idx)
                act_memory += self._get_output_node_size(
                    chunk_outputs[chunk_region_idx]
                ) / (1024**2)

            # determine chunk ratio for current node
            if chunk_within:
                chunk_ratio = self._get_chunk_ratio(
                    node,
                    chunk_node_dim[chunk_region_idx],
                    chunk_sizes[chunk_region_idx],
                )

            # if node is placeholder, just add the size of the node
            if node.op == "placeholder":
                act_memory += self._get_meta_node_size(node) * chunk_ratio / (1024**2)
                act_memory_peak_log.append(act_memory)
            # skip output
            elif node.op == "output":
                continue
            # no change for non compute node
            elif is_non_compute_node_except_placeholder(node):
                act_memory_peak_log.append(act_memory)
            # node is a compute op
            # calculate tmp, output node and delete node memory
            else:
                # forward memory
                # TODO: contiguous_memory still not accurate for matmul, view, reshape and transpose
                act_memory += (
                    self._get_contiguous_memory(node, not_contiguous_list)
                    * chunk_ratio
                    / (1024**2)
                )
                act_memory += (
                    self._get_output_node_size(node) * chunk_ratio / (1024**2)
                )
                # record max act memory
                act_memory_peak_log.append(act_memory)
                # delete useless memory
                act_memory -= (
                    self._get_contiguous_memory(node, not_contiguous_list, delete=True)
                    * chunk_ratio
                    / (1024**2)
                )
                # delete unused vars not in chunk_input_list
                # we can't delete input nodes until chunk ends
                if chunk_within:
                    act_memory -= self._get_chunk_delete_node_size(
                        node,
                        user_to_last_uses_no_free_var,
                        chunk_ratio,
                        chunk_inputs_names,
                    ) / (1024**2)
                else:
                    act_memory -= self._get_delete_node_size(
                        node, user_to_last_uses_no_free_var, chunk_inputs_names
                    ) / (1024**2)

            # log active node, only effective without chunk
            self._add_active_node(node, active_node_list)
            self._remove_deactive_node(node, user_to_last_uses, active_node_list)

            # if node in chunk end nodes, restore chunk settings
            if use_chunk and idx in chunk_ends:
                act_memory -= (
                    self._get_output_node_size(node) * chunk_ratio / (1024**2)
                )
                act_memory -= self._get_chunk_inputs_size(
                    chunk_inputs[chunk_region_idx],
                    chunk_inputs_non_chunk[chunk_region_idx],
                    node_list,
                    chunk_regions[chunk_region_idx][1],
                ) / (1024**2)
                chunk_within = False
                chunk_ratio = 1
                chunk_region_idx = None

            act_memory_after_node_log.append(act_memory)
            active_node_list_log.append(copy.deepcopy(active_node_list))

        if print_mem:
            print("with chunk" if use_chunk else "without chunk")
            # self._print_mem_log(act_memory_peak_log, node_list, "peak")
            # self._print_mem_log(act_memory_after_node_log, node_list, "after")
            self._print_compute_op_mem_log(act_memory_peak_log, node_list, "peak")
            # self._print_compute_op_mem_log(
            #     act_memory_after_node_log, node_list, "after"
            # )

        # param_memory = parameter_size(gm)
        # all_memory = act_memory + param_memory
        return act_memory_peak_log, act_memory_after_node_log, active_node_list_log
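A hedged sketch of how the estimator is queried (illustrative, not part of this commit): node_list is assumed to be the traced node list whose nodes carry fwd_out / tensor_meta metadata, and chunk_infos a chunk plan such as SearchChunk.search_region() returns.

    estimator = EstimateMemory()
    peak_log, after_log, active_log = estimator.estimate_chunk_inference_mem(node_list)
    print("baseline peak (MB): %.2f" % max(peak_log))
    chunked_peak_log, _, _ = estimator.estimate_chunk_inference_mem(node_list, chunk_infos)
    print("chunked peak (MB): %.2f" % max(chunked_peak_log))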
@ -0,0 +1,117 @@

from .trace_indice import TraceIndice
from .utils import find_idx_by_name


class ReorderGraph(object):
    """
    Reorder node list and indice trace list
    """

    def __init__(self, trace_indice: TraceIndice) -> None:
        self.trace_indice = trace_indice
        self.all_reorder_map = {
            i: i for i in range(len(self.trace_indice.indice_trace_list))
        }

    def _get_reorder_map(self, chunk_info):
        reorder_map = {i: i for i in range(len(self.trace_indice.node_list))}

        chunk_region_start = chunk_info["region"][0]
        chunk_region_end = chunk_info["region"][1]
        chunk_prepose_nodes = chunk_info["args"]["prepose_nodes"]
        chunk_prepose_nodes_idx = [
            find_idx_by_name(i.name, self.trace_indice.node_list)
            for i in chunk_prepose_nodes
        ]
        # put prepose nodes ahead
        for idx, n in enumerate(chunk_prepose_nodes):
            n_idx = chunk_prepose_nodes_idx[idx]
            reorder_map[n_idx] = chunk_region_start + idx
        # put other nodes after prepose nodes
        for n in self.trace_indice.node_list[chunk_region_start : chunk_region_end + 1]:
            if n in chunk_prepose_nodes:
                continue
            n_idx = find_idx_by_name(n.name, self.trace_indice.node_list)
            pos = sum([n_idx < i for i in chunk_prepose_nodes_idx])
            reorder_map[n_idx] = n_idx + pos

        return reorder_map
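    # A worked example of the map (hypothetical indices): region = (4, 7) with a
    # single prepose node at index 6 rewrites the identity map to
    #   {..., 4: 5, 5: 6, 6: 4, 7: 7, ...}
    # i.e. node 6 is hoisted to position 4, ahead of the chunked loop body.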

    def _reorder_chunk_info(self, chunk_info, reorder_map):
        # update chunk info
        chunk_info["region"] = (
            chunk_info["region"][0] + len(chunk_info["args"]["prepose_nodes"]),
            chunk_info["region"][1],
        )
        new_inputs_dim = []
        for idx, input_dim in enumerate(chunk_info["inputs_dim"]):
            new_input_dim = {}
            for k, v in input_dim.items():
                new_input_dim[reorder_map[k]] = v
            new_inputs_dim.append(new_input_dim)
        chunk_info["inputs_dim"] = new_inputs_dim
        return chunk_info

    def _update_all_reorder_map(self, reorder_map):
        for origin_idx, map_idx in self.all_reorder_map.items():
            self.all_reorder_map[origin_idx] = reorder_map[map_idx]

    def _reorder_self_node_list(self, reorder_map):
        new_node_list = [None for _ in range(len(self.trace_indice.node_list))]
        for old_idx, new_idx in reorder_map.items():
            new_node_list[new_idx] = self.trace_indice.node_list[old_idx]
        self.trace_indice.node_list = new_node_list

    def _reorder_idx_trace(self, reorder_map):
        # reorder list
        new_idx_trace_list = [
            None for _ in range(len(self.trace_indice.indice_trace_list))
        ]
        for old_idx, new_idx in reorder_map.items():
            new_idx_trace_list[new_idx] = self.trace_indice.indice_trace_list[old_idx]
        self.trace_indice.indice_trace_list = new_idx_trace_list
        # update compute
        for idx_trace in self.trace_indice.indice_trace_list:
            compute = idx_trace["compute"]
            for dim_compute in compute:
                for idx, i in enumerate(dim_compute):
                    dim_compute[idx] = reorder_map[i]
        # update source
        for idx_trace in self.trace_indice.indice_trace_list:
            source = idx_trace["source"]
            for dim_idx, dim_source in enumerate(source):
                new_dim_source = {}
                for k, v in dim_source.items():
                    new_dim_source[reorder_map[k]] = v
                source[dim_idx] = new_dim_source

    def reorder_all(self, chunk_info):
        if chunk_info is None:
            return chunk_info
        if len(chunk_info["args"]["prepose_nodes"]) == 0:
            return chunk_info
        reorder_map = self._get_reorder_map(chunk_info)
        self._update_all_reorder_map(reorder_map)
        self._reorder_idx_trace(reorder_map)
        self._reorder_self_node_list(reorder_map)
        chunk_info = self._reorder_chunk_info(chunk_info, reorder_map)
        return chunk_info

    def reorder_node_list(self, node_list):
        new_node_list = [None for _ in range(len(node_list))]
        for old_idx, new_idx in self.all_reorder_map.items():
            new_node_list[new_idx] = node_list[old_idx]
        return new_node_list

    def tmp_reorder(self, node_list, chunk_info):
        if len(chunk_info["args"]["prepose_nodes"]) == 0:
            return node_list, chunk_info
        reorder_map = self._get_reorder_map(chunk_info)

        # new tmp node list
        new_node_list = [None for _ in range(len(node_list))]
        for old_idx, new_idx in reorder_map.items():
            new_node_list[new_idx] = node_list[old_idx]

        chunk_info = self._reorder_chunk_info(chunk_info, reorder_map)
        return new_node_list, chunk_info
@ -0,0 +1,319 @@

import copy
from typing import Dict, List, Tuple

from torch.fx.node import Node

from .estimate_memory import EstimateMemory
from .reorder_graph import ReorderGraph
from .select_chunk import SelectChunk
from .trace_flow import TraceFlow
from .trace_indice import TraceIndice
from .utils import (
    get_node_shape,
    is_non_compute_node,
    is_non_compute_node_except_placeholder,
)


class SearchChunk(object):
    """
    This is the core class for AutoChunk.

    It defines the framework of the strategy of AutoChunk.
    Chunks will be selected one by one until the search stops.

    The chunk search is as follows:
    1. find the peak memory node
    2. find the max chunk region according to the peak memory node
    3. find all possible chunk regions in the max chunk region
    4. find the best chunk region for current status
    5. goto 1

    Attributes:
        gm: graph model
        print_mem (bool): print estimated memory
        trace_indice: trace the flow of every dim of every node to find all free dims
        trace_flow: determine the region chunk strategy
        reorder_graph: reorder nodes to improve chunk efficiency
        estimate_memory: estimate memory with chunk
        select_chunk: select the best chunk region

    Args:
        gm: graph model
        max_memory (int): max memory in MB
        print_mem (bool): print estimated memory
    """

    def __init__(self, gm, max_memory=None, print_mem=False) -> None:
        self.gm = gm
        self.print_mem = print_mem
        self.trace_indice = TraceIndice(list(gm.graph.nodes))
        self.trace_indice.trace_indice()
        self.trace_flow = TraceFlow(self.trace_indice)
        self.reorder_graph = ReorderGraph(self.trace_indice)
        self.estimate_memory = EstimateMemory()
        self.select_chunk = SelectChunk(
            self.trace_indice,
            self.estimate_memory,
            self.reorder_graph,
            max_memory=max_memory,
        )

    def _find_peak_node(self, mem_peak):
        max_value = max(mem_peak)
        max_idx = mem_peak.index(max_value)
        return max_idx

    def _get_free_var_idx(self) -> List:
        """
        Get free var index

        Returns:
            free_var_idx (List): all indexes of free vars
        """
        free_var_idx = []
        for idx, n in enumerate(self.trace_indice.node_list):
            if n.op == "placeholder":
                free_var_idx.append(idx)
        return free_var_idx

    def _search_max_chunk_region(
        self, active_node: List, peak_node: Node, chunk_regions: List
    ) -> Tuple:
        """
        Search max chunk region according to peak memory node

        The chunk region extends from the peak node in both directions and stops
        where the active node count rises above a threshold (the max of the free
        var count and the minimum active node count).

        Args:
            active_node (List): active node status for every node
            peak_node (Node): peak memory node
            chunk_regions (List): chunk region infos

        Returns:
            chunk_region_start (int)
            chunk_region_end (int)
        """
        free_vars = self._get_free_var_idx()
        free_var_num = len(free_vars)
        active_node_num = [len(i) for i in active_node]
        min_active_node_num = min(active_node_num[free_var_num:])
        threshold = max(free_var_num, min_active_node_num)

        # from peak_node to free_var
        inside_flag = False
        chunk_region_start = free_var_num
        for i in range(peak_node, -1, -1):
            if active_node_num[i] <= threshold:
                inside_flag = True
            if inside_flag and active_node_num[i] > threshold:
                chunk_region_start = i + 1
                break

        # from peak_node to len-2
        inside_flag = False
        chunk_region_end = len(active_node) - 1
        for i in range(peak_node, len(active_node)):
            if active_node_num[i] <= threshold:
                inside_flag = True
            if inside_flag and active_node_num[i] > threshold:
                chunk_region_end = i
                break

        for i in chunk_regions:
            region = i["region"]
            if chunk_region_start >= region[0] and chunk_region_end <= region[1]:
                return None
            elif (
                region[0] <= chunk_region_start <= region[1]
                and chunk_region_end > region[1]
            ):
                chunk_region_start = region[1] + 1
            elif (
                region[0] <= chunk_region_end <= region[1]
                and chunk_region_start < region[0]
            ):
                chunk_region_end = region[0] - 1
        return chunk_region_start, chunk_region_end
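    # A worked example of the scan (hypothetical counts): with free_var_num = 2,
    # peak_node = 7 and active_node_num = [1, 2, 3, 3, 5, 3, 3, 8, 3, 3, 5, 3],
    # the threshold is max(2, 3) = 3; the backward scan hits a count above the
    # threshold at index 4 (so the region starts at 5) and the forward scan hits
    # one at index 10, giving a max chunk region of (5, 10).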

    def _find_chunk_info(self, input_trace, output_trace, start_idx, end_idx) -> List:
        """
        Find chunk info for a region.

        We are given the region start and region end, and need to find out all chunk info for it.
        We first loop every dim of start node and end node, to see if we can find a dim pair
        which is linked in a flow and not computed.
        If found, we then search the flow in the whole region to find out all chunk infos.

        Args:
            input_trace (List): node's input trace in region
            output_trace (List): node's output trace in region
            start_idx (int): region start node index
            end_idx (int): region end node index

        Returns:
            chunk_infos: possible regions found
        """
        start_traces = input_trace[start_idx]
        end_trace = output_trace[end_idx]
        end_node = self.trace_indice.node_list[end_idx]
        chunk_infos = []
        for end_dim, _ in enumerate(end_trace["indice"]):
            if len(start_traces) > 1:
                continue
            for start_node, start_trace in start_traces.items():
                for start_dim, _ in enumerate(start_trace["indice"]):
                    # dim size cannot be 1
                    if (
                        get_node_shape(end_node)[end_dim] == 1
                        or get_node_shape(start_node)[start_dim] == 1
                    ):
                        continue
                    # check index source align
                    if not self.trace_flow.check_index_source(
                        start_dim, start_node, start_idx, end_dim, end_node
                    ):
                        continue
                    # check index compute
                    if not self.trace_flow.check_index_compute(
                        start_idx, end_dim, end_node, end_idx
                    ):
                        continue
                    # flow search
                    chunk_info = self.trace_flow.flow_search(
                        start_idx, start_dim, end_idx, end_dim
                    )
                    if chunk_info is None:
                        continue
                    # check index duplicate
                    if not self.trace_flow.check_index_duplicate(chunk_info):
                        continue
                    chunk_infos.append(chunk_info)
        return chunk_infos

    def _search_possible_chunk_regions(
        self, max_chunk_region: Tuple, peak_node: Node
    ) -> List:
        """
        Search every possible region within the max chunk region.

        Args:
            max_chunk_region (Tuple)
            peak_node (Node): peak memory node

        Returns:
            possible_chunk_region (List)
        """
        possible_chunk_region = []
        output_trace = copy.deepcopy(self.trace_indice.indice_trace_list)
        input_trace = []  # trace of a node's input nodes
        for _, n in enumerate(self.trace_indice.node_list):
            cur_trace = {}
            for arg in n.args:
                if type(arg) == type(n) and not is_non_compute_node_except_placeholder(
                    arg
                ):
                    cur_trace[arg] = self.trace_indice._find_trace_from_node(arg)
            input_trace.append(cur_trace)

        for start_idx in range(max_chunk_region[0], peak_node + 1):
            for end_idx in range(peak_node, max_chunk_region[1] + 1):
                # skip non compute nodes
                if is_non_compute_node(
                    self.trace_indice.node_list[start_idx]
                ) or is_non_compute_node(self.trace_indice.node_list[end_idx]):
                    continue

                # select free dim
                chunk_info = self._find_chunk_info(
                    input_trace, output_trace, start_idx, end_idx
                )
                if len(chunk_info) > 0:
                    possible_chunk_region.extend(chunk_info)
        return possible_chunk_region

    def _step_search(
        self,
        mem_peak: List[float],
        active_node: List[List[Node]],
        chunk_infos: List[Dict],
    ) -> Dict:
        """
        Find one chunk region

        The chunk search is as follows:
        1. find the peak memory node
        2. find the max chunk region according to the peak memory node
        3. find all possible chunk regions in the max chunk region
        4. find the best chunk region for current status

        Args:
            mem_peak (List): peak memory for every node
            active_node (List[List[Node]]): active node for every node
            chunk_infos (List[Dict]): all chunk info

        Returns:
            best_chunk_region (Dict)
        """
        peak_node = self._find_peak_node(mem_peak)
        max_chunk_region = self._search_max_chunk_region(
            active_node, peak_node, chunk_infos
        )
        if max_chunk_region is None:
            return None
        possible_chunk_regions = self._search_possible_chunk_regions(
            max_chunk_region, peak_node
        )
        best_chunk_region = self.select_chunk._select_best_chunk_region(
            possible_chunk_regions, chunk_infos, peak_node, max_chunk_region, mem_peak
        )
        best_chunk_region = self.reorder_graph.reorder_all(best_chunk_region)
        return best_chunk_region

    def _stop_search(self, init_mem_peak, mem_peak):
        sorted_init_mem_peak = sorted(init_mem_peak)
        if max(mem_peak) < sorted_init_mem_peak[int(len(sorted_init_mem_peak) * 0.5)]:
            return True
        return False
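    # In other words, the search stops once the chunked peak drops below the
    # median of the initial per-node peaks. Hypothetical values:
    #   init_mem_peak = [10, 20, 30, 40, 50, 60]  -> sorted[int(6 * 0.5)] = 40
    #   the loop in search_region stops once max(mem_peak) < 40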

    def search_region(self) -> Dict:
        """
        Search all chunk regions:
        1. Estimate current memory
        2. Find best chunk for current memory
        3. goto 1

        Returns:
            chunk_infos (Dict)
        """
        chunk_infos = []
        (
            init_mem_peak,
            _,
            active_node,
        ) = self.estimate_memory.estimate_chunk_inference_mem(
            self.trace_indice.node_list
        )
        mem_peak = init_mem_peak

        while True:
            chunk_info = self._step_search(mem_peak, active_node, chunk_infos)
            if chunk_info is None:
                break
            chunk_infos.append(chunk_info)

            (
                mem_peak,
                _,
                active_node,
            ) = self.estimate_memory.estimate_chunk_inference_mem(
                self.trace_indice.node_list, chunk_infos
            )
            if self._stop_search(init_mem_peak, mem_peak):
                break
        if self.print_mem:
            self.print_mem = False
            self.estimate_memory.estimate_chunk_inference_mem(
                self.trace_indice.node_list, chunk_infos, print_mem=True
            )
        return chunk_infos
|
|
@ -0,0 +1,224 @@
|
||||||
|
from .estimate_memory import EstimateMemory
|
||||||
|
from .reorder_graph import ReorderGraph
|
||||||
|
from .trace_indice import TraceIndice
|
||||||
|
from .utils import is_non_compute_node
|
||||||
|
|
||||||
|
|
||||||
|
class SelectChunk(object):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
trace_indice: TraceIndice,
|
||||||
|
estimate_memory: EstimateMemory,
|
||||||
|
reorder_graph: ReorderGraph,
|
||||||
|
max_memory=None,
|
||||||
|
):
|
||||||
|
self.trace_indice = trace_indice
|
||||||
|
self.estimate_memory = estimate_memory
|
||||||
|
self.reorder_graph = reorder_graph
|
||||||
|
if max_memory is not None:
|
||||||
|
self.stratge = "fit_memory"
|
||||||
|
self.max_memory = max_memory # MB
|
||||||
|
else:
|
||||||
|
self.stratge = "min_memory"
|
||||||
|
|
||||||
|
def _select_best_chunk_region(
|
||||||
|
self, possible_chunk_regions, chunk_infos, peak_node, max_chunk_region, mem_peak
|
||||||
|
):
|
||||||
|
if self.stratge == "min_memory":
|
||||||
|
best_region = self._select_min_memory_chunk_region(
|
||||||
|
possible_chunk_regions,
|
||||||
|
chunk_infos,
|
||||||
|
peak_node,
|
||||||
|
max_chunk_region,
|
||||||
|
mem_peak,
|
||||||
|
)
|
||||||
|
elif self.stratge == "fit_memory":
|
||||||
|
best_region = self._select_fit_memory_chunk_region(
|
||||||
|
possible_chunk_regions,
|
||||||
|
chunk_infos,
|
||||||
|
peak_node,
|
||||||
|
max_chunk_region,
|
||||||
|
mem_peak,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise RuntimeError()
|
||||||
|
return best_region
|
||||||
|
|
||||||
|
def _select_fit_memory_chunk_region(
|
||||||
|
self, possible_chunk_regions, chunk_infos, peak_node, max_chunk_region, mem_peak
|
||||||
|
):
|
||||||
|
# stop chunk if max memory satisfy memory limit
|
||||||
|
if max(mem_peak) < self.max_memory:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# remove illegal regions
|
||||||
|
illegal_regions = []
|
||||||
|
for i in possible_chunk_regions:
|
||||||
|
if not self._is_legal_region(i, chunk_infos):
|
||||||
|
illegal_regions.append(i)
|
||||||
|
for i in illegal_regions:
|
||||||
|
if i in possible_chunk_regions:
|
||||||
|
possible_chunk_regions.remove(i)
|
||||||
|
|
||||||
|
if len(possible_chunk_regions) == 0:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# get mem for chunk region
|
||||||
|
regions_dict = []
|
||||||
|
for region in possible_chunk_regions:
|
||||||
|
cur_region = region.copy()
|
||||||
|
cur_node_list, cur_region = self.reorder_graph.tmp_reorder(
|
||||||
|
self.trace_indice.node_list, cur_region
|
||||||
|
)
|
||||||
|
cur_chunk_infos = chunk_infos + [cur_region]
|
||||||
|
cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(
|
||||||
|
cur_node_list, cur_chunk_infos
|
||||||
|
)[0]
|
||||||
|
cur_chunk_region_peak = cur_mem_peak[
|
||||||
|
max_chunk_region[0] : max_chunk_region[1] + 1
|
||||||
|
]
|
||||||
|
cur_chunk_region_max_peak = max(cur_chunk_region_peak)
|
||||||
|
if cur_chunk_region_max_peak < self.max_memory:
|
||||||
|
regions_dict.append(
|
||||||
|
{
|
||||||
|
"chunk_info": region,
|
||||||
|
"chunk_max_mem": cur_chunk_region_max_peak,
|
||||||
|
"chunk_len": self._get_compute_node_num(
|
||||||
|
region["region"][0], region["region"][1]
|
||||||
|
),
|
||||||
|
"reorder_chunk_info": cur_region,
|
||||||
|
"reorder_node_list": cur_node_list,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
# no region found
|
||||||
|
if len(regions_dict) == 0:
|
||||||
|
raise RuntimeError("Search failed. Try a larger memory threshold.")
|
||||||
|
|
||||||
|
# select the region with the fewest compute nodes (min chunk len)
|
||||||
|
chunk_len = [i["chunk_len"] for i in regions_dict]
|
||||||
|
best_region_idx = chunk_len.index(min(chunk_len))
|
||||||
|
best_region = regions_dict[best_region_idx]
|
||||||
|
|
||||||
|
# find the largest chunk size that still fits the memory limit
|
||||||
|
best_region = self._get_fit_chunk_size(best_region, chunk_infos)
|
||||||
|
return best_region
|
||||||
|
|
||||||
|
def _get_fit_chunk_size(self, chunk_region_dict, chunk_infos):
|
||||||
|
chunk_size = 1
|
||||||
|
reorder_chunk_info = chunk_region_dict["reorder_chunk_info"]
|
||||||
|
reorder_chunk_info["chunk_size"] = chunk_size
|
||||||
|
cur_chunk_max_mem = 0
|
||||||
|
# double the chunk size until the memory limit is exceeded
|
||||||
|
while cur_chunk_max_mem < self.max_memory:
|
||||||
|
chunk_size *= 2
|
||||||
|
reorder_chunk_info["chunk_size"] = chunk_size
|
||||||
|
cur_chunk_infos = chunk_infos + [reorder_chunk_info]
|
||||||
|
cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(
|
||||||
|
chunk_region_dict["reorder_node_list"], cur_chunk_infos
|
||||||
|
)[0]
|
||||||
|
cur_chunk_max_mem = max(
|
||||||
|
cur_mem_peak[
|
||||||
|
reorder_chunk_info["region"][0] : reorder_chunk_info["region"][1]
|
||||||
|
+ 1
|
||||||
|
]
|
||||||
|
)
|
||||||
|
# binary search the exact chunk size between chunk_size // 2 and chunk_size
|
||||||
|
chunk_info = chunk_region_dict["chunk_info"]
|
||||||
|
chunk_info["chunk_size"] = self._chunk_size_binary_search(
|
||||||
|
chunk_size // 2, chunk_size, chunk_region_dict, chunk_infos
|
||||||
|
)
|
||||||
|
return chunk_info
|
||||||
|
|
||||||
|
def _chunk_size_binary_search(self, left, right, chunk_region_dict, chunk_infos):
|
||||||
|
if left >= 16:
|
||||||
|
gap = 4
|
||||||
|
else:
|
||||||
|
gap = 1
|
||||||
|
chunk_info = chunk_region_dict["reorder_chunk_info"]
|
||||||
|
while right >= left + gap:
|
||||||
|
mid = int((left + right) / 2 + 0.5)
|
||||||
|
chunk_info["chunk_size"] = mid
|
||||||
|
cur_chunk_infos = chunk_infos + [chunk_info]
|
||||||
|
cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(
|
||||||
|
chunk_region_dict["reorder_node_list"], cur_chunk_infos
|
||||||
|
)[0]
|
||||||
|
cur_chunk_max_mem = max(
|
||||||
|
cur_mem_peak[chunk_info["region"][0] : chunk_info["region"][1] + 1]
|
||||||
|
)
|
||||||
|
if cur_chunk_max_mem >= self.max_memory:
|
||||||
|
right = mid - gap
|
||||||
|
else:
|
||||||
|
left = mid + gap
|
||||||
|
return left
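# Illustrative walk-through (hypothetical numbers): _get_fit_chunk_size doubles
# chunk_size (1 -> 2 -> 4 -> ...) until the estimated peak memory exceeds
# self.max_memory, say at chunk_size = 64. _chunk_size_binary_search then
# searches between 32 and 64 with step "gap" (4 here, since left >= 16) and
# returns the largest chunk size, up to that granularity, whose estimated
# peak stays under the limit.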
|
||||||
|
|
||||||
|
def _get_compute_node_num(self, start, end):
|
||||||
|
count = 0
|
||||||
|
for i in self.trace_indice.node_list[start : end + 1]:
|
||||||
|
if not is_non_compute_node(i):
|
||||||
|
count += 1
|
||||||
|
return count
|
||||||
|
|
||||||
|
def _select_min_memory_chunk_region(
|
||||||
|
self, possible_chunk_regions, chunk_infos, peak_node, max_chunk_region, mem_peak
|
||||||
|
):
|
||||||
|
# remove illegal regions
|
||||||
|
illegal_regions = []
|
||||||
|
for i in possible_chunk_regions:
|
||||||
|
if not self._is_legal_region(i, chunk_infos):
|
||||||
|
illegal_regions.append(i)
|
||||||
|
for i in illegal_regions:
|
||||||
|
if i in possible_chunk_regions:
|
||||||
|
possible_chunk_regions.remove(i)
|
||||||
|
|
||||||
|
if len(possible_chunk_regions) == 0:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# get mem for chunk region
|
||||||
|
regions_dict = []
|
||||||
|
for region in possible_chunk_regions:
|
||||||
|
cur_region = region.copy()
|
||||||
|
cur_node_list, cur_region = self.reorder_graph.tmp_reorder(
|
||||||
|
self.trace_indice.node_list, cur_region
|
||||||
|
)
|
||||||
|
cur_chunk_infos = chunk_infos + [cur_region]
|
||||||
|
cur_mem_peak = self.estimate_memory.estimate_chunk_inference_mem(
|
||||||
|
cur_node_list, cur_chunk_infos
|
||||||
|
)[0]
|
||||||
|
cur_chunk_region_peak = cur_mem_peak[
|
||||||
|
max_chunk_region[0] : max_chunk_region[1] + 1
|
||||||
|
]
|
||||||
|
cur_chunk_region_max_peak = max(cur_chunk_region_peak)
|
||||||
|
regions_dict.append(
|
||||||
|
{
|
||||||
|
"chunk_info": region,
|
||||||
|
"chunk_max_mem": cur_chunk_region_max_peak,
|
||||||
|
"chunk_len": self._get_compute_node_num(
|
||||||
|
region["region"][0], region["region"][1]
|
||||||
|
),
|
||||||
|
"reorder_chunk_info": cur_region,
|
||||||
|
"reorder_node_list": cur_node_list,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# select the region with the minimum peak memory
|
||||||
|
chunk_max_mem = [i["chunk_max_mem"] for i in regions_dict]
|
||||||
|
best_region_idx = chunk_max_mem.index(min(chunk_max_mem))
|
||||||
|
best_region = regions_dict[best_region_idx]["chunk_info"]
|
||||||
|
if best_region is not None:
|
||||||
|
best_region["chunk_size"] = 1
|
||||||
|
return best_region
|
||||||
|
|
||||||
|
def _is_legal_region(self, cur_chunk_info, chunk_infos):
|
||||||
|
(chunk_region_start, chunk_region_end) = cur_chunk_info["region"]
|
||||||
|
if cur_chunk_info in chunk_infos:
|
||||||
|
return False
|
||||||
|
if chunk_region_end < chunk_region_start:
|
||||||
|
return False
|
||||||
|
for i in chunk_infos:
|
||||||
|
region = i["region"]
|
||||||
|
if not (
|
||||||
|
(chunk_region_start > region[1] and chunk_region_end > region[1])
|
||||||
|
or (chunk_region_start < region[0] and chunk_region_end < region[0])
|
||||||
|
):
|
||||||
|
return False
|
||||||
|
return True
|
|
@ -0,0 +1,420 @@
|
||||||
|
from .trace_indice import TraceIndice
|
||||||
|
from .utils import (
|
||||||
|
find_chunk_all_input_nodes,
|
||||||
|
find_chunk_compute_input_and_output_nodes,
|
||||||
|
find_idx_by_name,
|
||||||
|
get_node_shape,
|
||||||
|
is_non_compute_node,
|
||||||
|
is_non_compute_node_except_placeholder,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TraceFlow(object):
|
||||||
|
def __init__(self, trace_indice: TraceIndice) -> None:
|
||||||
|
self.trace_indice = trace_indice
|
||||||
|
|
||||||
|
def check_index_source(self, start_dim, start_node, start_idx, end_dim, end_node):
|
||||||
|
"""
|
||||||
|
Check two given indices: one index should be the source of the other.
|
||||||
|
Args:
|
||||||
|
start_dim (int): start node chunk dim
start_node (node): start node
start_idx (int): start node index
end_dim (int): end node chunk dim
end_node (node): end node
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if check pass
|
||||||
|
"""
|
||||||
|
start_node_idx = find_idx_by_name(start_node.name, self.trace_indice.node_list)
|
||||||
|
end_node_trace = self.trace_indice._find_trace_from_node(end_node)
|
||||||
|
end_node_trace_source = end_node_trace["source"][end_dim]
|
||||||
|
sorted_source = sorted(
|
||||||
|
end_node_trace_source.items(), key=lambda d: d[0], reverse=True
|
||||||
|
)
|
||||||
|
for node_idx, node_dim in sorted_source:
|
||||||
|
if node_idx == start_node_idx and start_dim in node_dim:
|
||||||
|
return True
|
||||||
|
# we have reached a source node before the chunk region that is not the start node
|
||||||
|
if node_idx < start_idx:
|
||||||
|
return False
|
||||||
|
return False
|
||||||
|
|
||||||
|
def check_index_compute(self, start_idx, end_dim, end_node, end_idx):
|
||||||
|
"""
|
||||||
|
Check two given indices: make sure the end node's chunk dim has not been computed inside the chunk region.
|
||||||
|
Args:
|
||||||
|
start_idx (int): chunk region start node index
end_dim (int): end node chunk dim
end_node (node): end node
end_idx (int): chunk region end node index
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if check pass
|
||||||
|
"""
|
||||||
|
end_node_trace = self.trace_indice._find_trace_from_node(end_node)
|
||||||
|
end_node_compute = end_node_trace["compute"][end_dim]
|
||||||
|
if any(start_idx <= i <= end_idx for i in end_node_compute):
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
def get_node_chunk_dim(self, node_from, node_from_dim, node_to):
|
||||||
|
node_from_source = self.trace_indice._find_source_trace_from_node(node_from)
|
||||||
|
dim_source = node_from_source[node_from_dim]
|
||||||
|
node_to_idx = find_idx_by_name(node_to.name, self.trace_indice.node_list)
|
||||||
|
for k, v in dim_source.items():
|
||||||
|
if k == node_to_idx:
|
||||||
|
return v
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _find_inherit_dim(self, input_node, input_dim, node):
|
||||||
|
input_node_idx = find_idx_by_name(input_node.name, self.trace_indice.node_list)
|
||||||
|
node_trace_source = self.trace_indice._find_source_trace_from_node(node)
|
||||||
|
for node_dim in range(len(get_node_shape(node))):
|
||||||
|
if (
|
||||||
|
input_node_idx in node_trace_source[node_dim]
|
||||||
|
and input_dim[0] in node_trace_source[node_dim][input_node_idx]
|
||||||
|
):
|
||||||
|
return node_dim
|
||||||
|
return None
|
||||||
|
|
||||||
|
def check_index_duplicate(self, chunk_infos, return_dim=False):
|
||||||
|
input_dim_after_node = {}
|
||||||
|
for input_node_idx, input_node in enumerate(chunk_infos["inputs"]):
|
||||||
|
for k, v in chunk_infos["inputs_dim"][input_node_idx].items():
|
||||||
|
inherit_dim = self._find_inherit_dim(
|
||||||
|
input_node, v, self.trace_indice.node_list[k]
|
||||||
|
)
|
||||||
|
if inherit_dim:
|
||||||
|
input_dim_after_node[k] = inherit_dim
|
||||||
|
|
||||||
|
for node in self.trace_indice.node_list[
|
||||||
|
chunk_infos["region"][0] : chunk_infos["region"][1] + 1
|
||||||
|
]:
|
||||||
|
if is_non_compute_node_except_placeholder(node):
|
||||||
|
continue
|
||||||
|
count = 0
|
||||||
|
duplicate_dims = []
|
||||||
|
node_trace_source = self.trace_indice._find_source_trace_from_node(node)
|
||||||
|
for node_dim in range(len(get_node_shape(node))):
|
||||||
|
duplicate_dim = []
|
||||||
|
duplicate_flag = False
|
||||||
|
dim_source = node_trace_source[node_dim]
|
||||||
|
for k, v in dim_source.items():
|
||||||
|
if chunk_infos["region"][0] <= k <= chunk_infos["region"][1]:
|
||||||
|
if k in input_dim_after_node and input_dim_after_node[k] in v:
|
||||||
|
duplicate_flag = True
|
||||||
|
duplicate_dim.append((k, v))
|
||||||
|
duplicate_dims.append(duplicate_dim)
|
||||||
|
if duplicate_flag:
|
||||||
|
count += 1
|
||||||
|
|
||||||
|
if count > 1:
|
||||||
|
if return_dim:
|
||||||
|
return False, duplicate_dims
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
if return_dim:
|
||||||
|
return True, None
|
||||||
|
else:
|
||||||
|
return True
|
||||||
|
|
||||||
|
def _assgin_single_node_flow(
|
||||||
|
self,
|
||||||
|
arg_node,
|
||||||
|
start_idx,
|
||||||
|
end_idx,
|
||||||
|
cur_node_dim,
|
||||||
|
cur_node_compute,
|
||||||
|
cur_node_source,
|
||||||
|
cur_node_fix_dim,
|
||||||
|
all_node_info,
|
||||||
|
next_node_list,
|
||||||
|
):
|
||||||
|
arg_idx = find_idx_by_name(arg_node.name, self.trace_indice.node_list)
|
||||||
|
# the arg must lie inside the chunk region; otherwise it is an input of the region and needs no flow info
|
||||||
|
if not (start_idx <= arg_idx < end_idx):
|
||||||
|
return True
|
||||||
|
|
||||||
|
# find arg dim
|
||||||
|
if cur_node_dim is not None:
|
||||||
|
# dim is computed
|
||||||
|
if arg_idx in cur_node_compute[cur_node_dim]:
|
||||||
|
return False
|
||||||
|
if arg_idx not in cur_node_source[cur_node_dim]:
|
||||||
|
arg_dim = None
|
||||||
|
else:
|
||||||
|
arg_dim = cur_node_source[cur_node_dim][arg_idx][0]
|
||||||
|
else:
|
||||||
|
arg_dim = None
|
||||||
|
|
||||||
|
# get fix dim
|
||||||
|
arg_fix_dim = []
|
||||||
|
if cur_node_dim is not None:
|
||||||
|
for i in cur_node_fix_dim:
|
||||||
|
fix_dim_source = cur_node_source[i]
|
||||||
|
if arg_idx in fix_dim_source:
|
||||||
|
arg_fix_dim.append(fix_dim_source[arg_idx][0])
|
||||||
|
|
||||||
|
# if already in node_info, arg dim must be same
|
||||||
|
if arg_node in all_node_info:
|
||||||
|
if all_node_info[arg_node]["chunk_dim"] != arg_dim:
|
||||||
|
return False
|
||||||
|
all_node_info[arg_node]["fix_dim"] = list(
|
||||||
|
set(all_node_info[arg_node]["fix_dim"] + arg_fix_dim)
|
||||||
|
)
|
||||||
|
# else add it to list
|
||||||
|
else:
|
||||||
|
all_node_info[arg_node] = {"chunk_dim": arg_dim, "fix_dim": arg_fix_dim}
|
||||||
|
|
||||||
|
next_node_list.append(arg_node)
|
||||||
|
return True
|
||||||
|
|
||||||
|
def _get_all_node_info(self, end_dim, start_idx, end_idx):
|
||||||
|
cur_node_list = [
|
||||||
|
self.trace_indice.node_list[end_idx]
|
||||||
|
] # start from the last node
|
||||||
|
all_node_info = {cur_node_list[0]: {"chunk_dim": end_dim, "fix_dim": []}}
|
||||||
|
|
||||||
|
while len(cur_node_list) > 0:
|
||||||
|
next_node_list = []
|
||||||
|
|
||||||
|
for cur_node in cur_node_list:
|
||||||
|
# get cur node info
|
||||||
|
cur_node_chunk_dim = all_node_info[cur_node]["chunk_dim"]
|
||||||
|
cur_node_fix_dim = all_node_info[cur_node]["fix_dim"]
|
||||||
|
if cur_node_chunk_dim:
|
||||||
|
cur_node_compute = self.trace_indice._find_compute_trace_from_node(
|
||||||
|
cur_node
|
||||||
|
)
|
||||||
|
cur_node_source = self.trace_indice._find_source_trace_from_node(
|
||||||
|
cur_node
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
cur_node_compute = cur_node_source = None
|
||||||
|
|
||||||
|
# get all valid args
|
||||||
|
arg_list = []
|
||||||
|
for arg in cur_node.args:
|
||||||
|
if type(arg) != type(cur_node):
|
||||||
|
continue
|
||||||
|
if is_non_compute_node(arg):
|
||||||
|
continue
|
||||||
|
arg_list.append(arg)
|
||||||
|
flow_flag = self._assgin_single_node_flow(
|
||||||
|
arg,
|
||||||
|
start_idx,
|
||||||
|
end_idx,
|
||||||
|
cur_node_chunk_dim,
|
||||||
|
cur_node_compute,
|
||||||
|
cur_node_source,
|
||||||
|
cur_node_fix_dim,
|
||||||
|
all_node_info,
|
||||||
|
next_node_list,
|
||||||
|
)
|
||||||
|
if flow_flag == False:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if len(arg_list) == 2:
|
||||||
|
if any(i in cur_node.name for i in ["add", "mul"]):
|
||||||
|
for arg in arg_list:
|
||||||
|
if not (
|
||||||
|
start_idx
|
||||||
|
<= find_idx_by_name(
|
||||||
|
arg.name, self.trace_indice.node_list
|
||||||
|
)
|
||||||
|
< end_idx
|
||||||
|
):
|
||||||
|
continue
|
||||||
|
arg_chunk_dim = all_node_info[arg]["chunk_dim"]
|
||||||
|
arg_fix_dim = all_node_info[arg]["fix_dim"]
|
||||||
|
arg_shape = get_node_shape(arg)
|
||||||
|
# add all dims as fix dims except the chunk dim
|
||||||
|
for i, shape in enumerate(arg_shape):
|
||||||
|
if shape != 1 and i != cur_node_chunk_dim:
|
||||||
|
if i == arg_chunk_dim:
|
||||||
|
return None
|
||||||
|
if i not in arg_fix_dim:
|
||||||
|
arg_fix_dim.append(i)
|
||||||
|
elif "einsum" in cur_node.name:
|
||||||
|
pass
|
||||||
|
elif "matmul" in cur_node.name:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
raise NotImplementedError()
|
||||||
|
cur_node_list = next_node_list
|
||||||
|
return all_node_info
|
||||||
|
|
||||||
|
def _get_input_nodes_dim(self, inputs, start_idx, end_idx, all_node_info):
|
||||||
|
inputs_dim = []
|
||||||
|
remove_inputs = []
|
||||||
|
for input_node in inputs:
|
||||||
|
input_dict = {}
|
||||||
|
input_node_idx = find_idx_by_name(
|
||||||
|
input_node.name, self.trace_indice.node_list
|
||||||
|
)
|
||||||
|
for user in input_node.users.keys():
|
||||||
|
if is_non_compute_node(user):
|
||||||
|
continue
|
||||||
|
user_idx = find_idx_by_name(user.name, self.trace_indice.node_list)
|
||||||
|
if start_idx <= user_idx <= end_idx:
|
||||||
|
chunk_dim = all_node_info[user]["chunk_dim"]
|
||||||
|
if chunk_dim is not None:
|
||||||
|
user_source = self.trace_indice._find_source_trace_from_node(
|
||||||
|
user
|
||||||
|
)[chunk_dim]
|
||||||
|
if input_node_idx in user_source:
|
||||||
|
input_dict[user_idx] = user_source[input_node_idx]
|
||||||
|
else:
|
||||||
|
return None, None
|
||||||
|
if len(input_dict) == 0:
|
||||||
|
remove_inputs.append(input_node)
|
||||||
|
else:
|
||||||
|
inputs_dim.append(input_dict)
|
||||||
|
for i in remove_inputs:
|
||||||
|
if i in inputs:
|
||||||
|
inputs.remove(i)
|
||||||
|
return inputs, inputs_dim
|
||||||
|
|
||||||
|
def _get_prepose_nodes(self, all_node_info, start_idx, end_idx):
|
||||||
|
# get all possible prepose nodes
|
||||||
|
maybe_prepose_nodes = []
|
||||||
|
for node, node_info in all_node_info.items():
|
||||||
|
if node_info["chunk_dim"] is None:
|
||||||
|
maybe_prepose_nodes.append(node)
|
||||||
|
maybe_prepose_nodes.sort(
|
||||||
|
key=lambda x: find_idx_by_name(x.name, self.trace_indice.node_list),
|
||||||
|
reverse=True,
|
||||||
|
) # from last node to first node
|
||||||
|
prepose_nodes = []
|
||||||
|
# take each candidate node as a root and search its args; if all of them are legal, the root and its args become prepose nodes
|
||||||
|
while len(maybe_prepose_nodes) > 0:
|
||||||
|
tmp_cur_prepose_nodes = [maybe_prepose_nodes[0]]
|
||||||
|
tmp_cur_related_prepose_nodes = []
|
||||||
|
prepose_flag = True
|
||||||
|
|
||||||
|
# walk through the current node's args until we leave the chunk region
|
||||||
|
while len(tmp_cur_prepose_nodes) > 0:
|
||||||
|
if prepose_flag == False:
|
||||||
|
break
|
||||||
|
tmp_next_prepose_nodes = []
|
||||||
|
tmp_cur_related_prepose_nodes.extend(tmp_cur_prepose_nodes)
|
||||||
|
for cur_prepose_node in tmp_cur_prepose_nodes:
|
||||||
|
if prepose_flag == False:
|
||||||
|
break
|
||||||
|
for cur_prepose_node_arg in cur_prepose_node.args:
|
||||||
|
if type(cur_prepose_node_arg) != type(cur_prepose_node):
|
||||||
|
continue
|
||||||
|
# out of loop
|
||||||
|
if not (
|
||||||
|
start_idx
|
||||||
|
<= find_idx_by_name(
|
||||||
|
cur_prepose_node_arg.name, self.trace_indice.node_list
|
||||||
|
)
|
||||||
|
< end_idx
|
||||||
|
):
|
||||||
|
continue
|
||||||
|
# compute op in loop
|
||||||
|
elif cur_prepose_node_arg in all_node_info:
|
||||||
|
if all_node_info[cur_prepose_node_arg]["chunk_dim"] is None:
|
||||||
|
tmp_next_prepose_nodes.append(cur_prepose_node_arg)
|
||||||
|
else:
|
||||||
|
prepose_flag = False
|
||||||
|
break
|
||||||
|
# non compute op
|
||||||
|
else:
|
||||||
|
tmp_next_prepose_nodes.append(cur_prepose_node_arg)
|
||||||
|
tmp_cur_prepose_nodes = tmp_next_prepose_nodes
|
||||||
|
|
||||||
|
if prepose_flag == False:
|
||||||
|
maybe_prepose_nodes.remove(maybe_prepose_nodes[0])
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
for n in tmp_cur_related_prepose_nodes:
|
||||||
|
if n not in prepose_nodes:
|
||||||
|
prepose_nodes.append(n)
|
||||||
|
if n in maybe_prepose_nodes:
|
||||||
|
maybe_prepose_nodes.remove(n)
|
||||||
|
# sort by index
|
||||||
|
prepose_nodes.sort(
|
||||||
|
key=lambda x: find_idx_by_name(x.name, self.trace_indice.node_list)
|
||||||
|
)
|
||||||
|
|
||||||
|
return prepose_nodes
|
||||||
|
|
||||||
|
def _get_non_chunk_inputs(self, chunk_info, start_idx, end_idx):
|
||||||
|
# we need to log input nodes to avoid deleting them in the loop
|
||||||
|
chunk_node_list = self.trace_indice.node_list[start_idx : end_idx + 1]
|
||||||
|
# prepose nodes are excluded so that their args can show up as non-chunk inputs
|
||||||
|
for n in chunk_info["args"]["prepose_nodes"]:
|
||||||
|
chunk_node_list.remove(n)
|
||||||
|
non_chunk_inputs = find_chunk_all_input_nodes(chunk_node_list)
|
||||||
|
for i in non_chunk_inputs:
|
||||||
|
if i not in chunk_info["inputs"]:
|
||||||
|
chunk_info["inputs_non_chunk"].append(i)
|
||||||
|
return chunk_info
|
||||||
|
|
||||||
|
def flow_search(self, start_idx, start_dim, end_idx, end_dim):
|
||||||
|
inputs, outputs = find_chunk_compute_input_and_output_nodes(
|
||||||
|
self.trace_indice.node_list[start_idx : end_idx + 1]
|
||||||
|
)
|
||||||
|
# only a single output is supported
|
||||||
|
if len(outputs) > 1:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# get every node's chunk dim and fix dim
|
||||||
|
all_node_info = self._get_all_node_info(end_dim, start_idx, end_idx)
|
||||||
|
if all_node_info is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# get input nodes' chunk dim
|
||||||
|
inputs, inputs_dim = self._get_input_nodes_dim(
|
||||||
|
inputs, start_idx, end_idx, all_node_info
|
||||||
|
)
|
||||||
|
if inputs is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
chunk_info = {
|
||||||
|
"region": (start_idx, end_idx),
|
||||||
|
"inputs": inputs,
|
||||||
|
"inputs_non_chunk": [],
|
||||||
|
"inputs_dim": inputs_dim,
|
||||||
|
"outputs": outputs,
|
||||||
|
"outputs_dim": end_dim,
|
||||||
|
"node_chunk_dim": all_node_info,
|
||||||
|
"args": {},
|
||||||
|
}
|
||||||
|
|
||||||
|
# move nodes that do not depend on the chunk dim ahead of the loop
|
||||||
|
chunk_info["args"]["prepose_nodes"] = self._get_prepose_nodes(
|
||||||
|
all_node_info, start_idx, end_idx
|
||||||
|
)
|
||||||
|
|
||||||
|
# find non chunk inputs
|
||||||
|
chunk_info = self._get_non_chunk_inputs(chunk_info, start_idx, end_idx)
|
||||||
|
|
||||||
|
# reassign reshape sizes, since some sizes may have changed due to chunking
|
||||||
|
chunk_info = self._reassgin_reshape_size(chunk_info)
|
||||||
|
|
||||||
|
return chunk_info
|
||||||
|
|
||||||
|
def _reassgin_reshape_size(self, chunk_info):
|
||||||
|
chunk_region = chunk_info["region"]
|
||||||
|
reshape_size = {}
|
||||||
|
chunk_shape = get_node_shape(chunk_info["outputs"][0])[
|
||||||
|
chunk_info["outputs_dim"]
|
||||||
|
]
|
||||||
|
for node in self.trace_indice.node_list[chunk_region[0] : chunk_region[1] + 1]:
|
||||||
|
if any(i in node.name for i in ["reshape", "view"]):
|
||||||
|
reshape_args = node.args[1:]
|
||||||
|
reshape_log = self.trace_indice.indice_view_list[node]
|
||||||
|
chunk_dim = chunk_info["node_chunk_dim"][node]["chunk_dim"]
|
||||||
|
reshape_size[node.name] = {}
|
||||||
|
for reshape_arg_dim, reshape_arg in enumerate(reshape_args):
|
||||||
|
if reshape_arg_dim in reshape_log["dim_to"]:
|
||||||
|
continue
|
||||||
|
if reshape_arg_dim == chunk_dim:
|
||||||
|
reshape_size[node.name][reshape_arg.name] = (
|
||||||
|
"min(chunk_size, %d - chunk_idx)" % chunk_shape
|
||||||
|
)
|
||||||
|
chunk_info["reshape_size"] = reshape_size
|
||||||
|
return chunk_info
|
|
@ -0,0 +1,559 @@
|
||||||
|
import copy
|
||||||
|
from typing import Dict, List, Tuple
|
||||||
|
|
||||||
|
from torch.fx.node import Node
|
||||||
|
|
||||||
|
from .utils import find_idx_by_name, get_node_shape
|
||||||
|
|
||||||
|
|
||||||
|
class TraceIndice(object):
|
||||||
|
"""
|
||||||
|
Trace all indice information for every node.
|
||||||
|
|
||||||
|
Indice is a logical concept. Equal dims can be treated as one indice.
|
||||||
|
eg. dim(x1) = [a, b, c]
|
||||||
|
dim(x2) = [d, e, f]
|
||||||
|
and we have x3 = x1 * x2.
|
||||||
|
then a=d, b=e, c=f, due to the broadcast property,
|
||||||
|
dim(x1)=dim(x2)=dim(x3)=[a, b, c]
|
||||||
|
This class will record every node's dims' indice, compute and source.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
node_list (List)
|
||||||
|
indice_trace_list (List): [{"indice": [...], "compute": [...], "source": [...]}, {...}]
|
||||||
|
indice_view_list (Dict): not used for now
|
||||||
|
indice_count (int): record indice number
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node_list (List)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, node_list: List) -> None:
|
||||||
|
self.node_list = node_list
|
||||||
|
self.indice_trace_list = self._init_indice_trace_list()
|
||||||
|
self.indice_view_list = {}
|
||||||
|
self.indice_count = -1
|
||||||
|
|
||||||
|
def _init_indice_trace_list(self):
|
||||||
|
indice_trace_list = []
|
||||||
|
for n in self.node_list:
|
||||||
|
if get_node_shape(n) != None:
|
||||||
|
cur_trace = {
|
||||||
|
"indice": [None for _ in range(len(get_node_shape(n)))],
|
||||||
|
"compute": [[] for _ in range(len(get_node_shape(n)))],
|
||||||
|
"source": [{} for _ in range(len(get_node_shape(n)))],
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
cur_trace = {"indice": [], "compute": [], "source": []}
|
||||||
|
indice_trace_list.append(cur_trace)
|
||||||
|
return indice_trace_list
|
||||||
|
|
||||||
|
def _add_indice(self):
|
||||||
|
"""
|
||||||
|
Increase the indice count and return it, so that every new indice gets a unique id.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
indice_count: int
|
||||||
|
"""
|
||||||
|
self.indice_count += 1
|
||||||
|
return self.indice_count
|
||||||
|
|
||||||
|
def _del_dim(self, idx, dim_idx):
|
||||||
|
self.indice_trace_list[idx]["indice"].pop(dim_idx)
|
||||||
|
self.indice_trace_list[idx]["compute"].pop(dim_idx)
|
||||||
|
self.indice_trace_list[idx]["source"].pop(dim_idx)
|
||||||
|
|
||||||
|
def _add_dim(self, node_idx, dim_idx):
|
||||||
|
self.indice_trace_list[node_idx]["indice"].insert(dim_idx, self._add_indice())
|
||||||
|
self.indice_trace_list[node_idx]["compute"].insert(dim_idx, [])
|
||||||
|
self.indice_trace_list[node_idx]["source"].insert(dim_idx, {})
|
||||||
|
|
||||||
|
def _transform_indice(self, node, node_dim):
|
||||||
|
node_idx = self._find_indice_trace_from_node(node)
|
||||||
|
dims = list(range(len(node_idx)))
|
||||||
|
return dims[node_dim]
|
||||||
|
|
||||||
|
def _inherit_indice(self, node_from, node_from_dim, node_to, node_to_dim):
|
||||||
|
node_from_dim = self._transform_indice(node_from, node_from_dim)
|
||||||
|
node_to_dim = self._transform_indice(node_to, node_to_dim)
|
||||||
|
node_from_trace = self._find_trace_from_node(node_from)
|
||||||
|
node_to_trace = self._find_trace_from_node(node_to)
|
||||||
|
node_to_trace["indice"][node_to_dim] = node_from_trace["indice"][node_from_dim]
|
||||||
|
node_to_trace["compute"][node_to_dim] = copy.deepcopy(
|
||||||
|
node_from_trace["compute"][node_from_dim]
|
||||||
|
)
|
||||||
|
self._add_source(node_from, node_from_dim, node_to, node_to_dim, init=True)
|
||||||
|
|
||||||
|
def _inherit_all_computation(self, node_from, node_to):
|
||||||
|
node_from_compute = self._find_compute_trace_from_node(node_from)
|
||||||
|
node_to_compute = self._find_compute_trace_from_node(node_to)
|
||||||
|
assert len(node_from_compute) == len(node_to_compute)
|
||||||
|
for i in range(len(node_from_compute)):
|
||||||
|
self._add_source(node_from, i, node_to, i)
|
||||||
|
node_to_compute[i] = copy.deepcopy(node_from_compute[i])
|
||||||
|
|
||||||
|
def _add_source(self, node_from, node_from_dim, node_to, node_to_dim, init=False):
|
||||||
|
node_from_dim = self._transform_indice(node_from, node_from_dim)
|
||||||
|
node_from_trace_source = self._find_source_trace_from_node(node_from)
|
||||||
|
node_to_dim = self._transform_indice(node_to, node_to_dim)
|
||||||
|
node_to_trace_source = self._find_source_trace_from_node(node_to)
|
||||||
|
node_from_idx = find_idx_by_name(node_from.name, self.node_list)
|
||||||
|
if init:
|
||||||
|
node_to_trace_source[node_to_dim] = {}
|
||||||
|
# add dim to cur new source
|
||||||
|
if node_from_idx not in node_to_trace_source[node_to_dim]:
|
||||||
|
node_to_trace_source[node_to_dim][node_from_idx] = [node_from_dim]
|
||||||
|
else:
|
||||||
|
if node_from_dim not in node_to_trace_source[node_to_dim][node_from_idx]:
|
||||||
|
node_to_trace_source[node_to_dim][node_from_idx].append(node_from_dim)
|
||||||
|
# update inputs source
|
||||||
|
for node_idx, node_dim in node_from_trace_source[node_from_dim].items():
|
||||||
|
if node_idx not in node_to_trace_source[node_to_dim]:
|
||||||
|
node_to_trace_source[node_to_dim][node_idx] = copy.deepcopy(node_dim)
|
||||||
|
else:
|
||||||
|
for d in node_dim:
|
||||||
|
if d not in node_to_trace_source[node_to_dim][node_idx]:
|
||||||
|
node_to_trace_source[node_to_dim][node_idx].append(d)
|
||||||
|
|
||||||
|
def _mark_computation_from_node(self, node_from, node_to, exclude=None):
|
||||||
|
if exclude == None:
|
||||||
|
exclude = []
|
||||||
|
else:
|
||||||
|
exclude = [self._transform_indice(node_to, i) for i in exclude]
|
||||||
|
node_from_compute = self._find_compute_trace_from_node(node_from)
|
||||||
|
node_to_compute = self._find_compute_trace_from_node(node_to)
|
||||||
|
# assert len(node_from_compute) == len(node_to_compute)
|
||||||
|
for i in range(-1, -min(len(node_from_compute), len(node_to_compute)) - 1, -1):
|
||||||
|
if self._transform_indice(node_to, i) in exclude:
|
||||||
|
continue
|
||||||
|
self._add_source(node_from, i, node_to, i)
|
||||||
|
for j in node_from_compute[i]:
|
||||||
|
if j not in node_to_compute[i]:
|
||||||
|
node_to_compute[i].append(j)
|
||||||
|
|
||||||
|
def _mark_computation(self, node, idx, dim):
|
||||||
|
"""
|
||||||
|
Mark some dims of node as computed.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node (node)
|
||||||
|
idx (int): node index
|
||||||
|
dim (list or int): dims to be marked as computed
|
||||||
|
"""
|
||||||
|
if isinstance(dim, int):
|
||||||
|
dim = [dim]
|
||||||
|
dims = list(range(len(get_node_shape(node))))
|
||||||
|
for d in dim:
|
||||||
|
cur_dim = dims[d]
|
||||||
|
if idx not in self.indice_trace_list[idx]["compute"][cur_dim]:
|
||||||
|
self.indice_trace_list[idx]["compute"][cur_dim].append(idx)
|
||||||
|
|
||||||
|
def _find_trace_from_node(self, node):
|
||||||
|
"""
|
||||||
|
Find node idx and compute trace by the node.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node (node)
|
||||||
|
Returns:
|
||||||
|
idx (list): idx of the node
|
||||||
|
compute (list): computed idx of the node.
|
||||||
|
"""
|
||||||
|
node_idx = find_idx_by_name(node.name, self.node_list)
|
||||||
|
node_dict = self.indice_trace_list[node_idx]
|
||||||
|
return node_dict
|
||||||
|
|
||||||
|
def _find_source_trace_from_node(self, node):
|
||||||
|
"""
|
||||||
|
Find node source trace by the node.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node (node)
|
||||||
|
Returns:
|
||||||
|
idx (list): idx of the node
|
||||||
|
compute (list): computed idx of the node.
|
||||||
|
"""
|
||||||
|
node_idx = find_idx_by_name(node.name, self.node_list)
|
||||||
|
node_dict = self.indice_trace_list[node_idx]
|
||||||
|
return node_dict["source"]
|
||||||
|
|
||||||
|
def _find_indice_trace_from_node(self, node):
|
||||||
|
"""
|
||||||
|
Find node idx trace by the node.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node (node)
|
||||||
|
Returns:
|
||||||
|
idx (list): idx of the node
|
||||||
|
"""
|
||||||
|
node_idx = find_idx_by_name(node.name, self.node_list)
|
||||||
|
return self.indice_trace_list[node_idx]["indice"]
|
||||||
|
|
||||||
|
def _find_compute_trace_from_node(self, node):
|
||||||
|
"""
|
||||||
|
Find node compute trace by the node.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node (node)
|
||||||
|
Returns:
|
||||||
|
compute (list): computed idx of the node.
|
||||||
|
"""
|
||||||
|
node_idx = find_idx_by_name(node.name, self.node_list)
|
||||||
|
return self.indice_trace_list[node_idx]["compute"]
|
||||||
|
|
||||||
|
def _assign_indice_as_input(self, node, node_idx, input_node=None):
|
||||||
|
"""
|
||||||
|
Assign node's trace as its input node.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node (node)
|
||||||
|
node_idx (int)
|
||||||
|
"""
|
||||||
|
if input_node == None:
|
||||||
|
input_node = node.args[0]
|
||||||
|
input_node_idx = find_idx_by_name(input_node.name, self.node_list)
|
||||||
|
input_node_idx_trace = self.indice_trace_list[input_node_idx]["indice"]
|
||||||
|
|
||||||
|
new_idx_trace = copy.deepcopy(input_node_idx_trace)
|
||||||
|
self.indice_trace_list[node_idx]["indice"] = new_idx_trace
|
||||||
|
|
||||||
|
self._inherit_all_computation(input_node, node)
|
||||||
|
|
||||||
|
def _assign_all_indice(self, node, node_idx):
|
||||||
|
"""
|
||||||
|
Add new indice for all node's dims.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node (node)
|
||||||
|
node_idx (int)
|
||||||
|
"""
|
||||||
|
shape = node.meta["tensor_meta"].shape
|
||||||
|
new_trace = []
|
||||||
|
for _ in shape:
|
||||||
|
new_trace.append(self._add_indice())
|
||||||
|
self.indice_trace_list[node_idx]["indice"] = new_trace
|
||||||
|
|
||||||
|
def _assign_transpose_indice(self, node, node_idx):
|
||||||
|
"""
|
||||||
|
Assign indice for transpose op.
|
||||||
|
1. swap input's dim according to transpose args
|
||||||
|
2. inherit input's computation
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node (node)
|
||||||
|
node_idx (int)
|
||||||
|
"""
|
||||||
|
input_node = node.args[0]
|
||||||
|
tranpose_dim = node.args[1:]
|
||||||
|
|
||||||
|
self._assign_indice_as_input(node, node_idx, input_node)
|
||||||
|
self._inherit_indice(input_node, tranpose_dim[1], node, tranpose_dim[0])
|
||||||
|
self._inherit_indice(input_node, tranpose_dim[0], node, tranpose_dim[1])
|
||||||
|
|
||||||
|
def _assign_permute_indice(self, node, node_idx):
|
||||||
|
"""
|
||||||
|
Assign indice for permute op.
|
||||||
|
1. swap input's dim according to permute args
|
||||||
|
2. inherit input's computation
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node (node)
|
||||||
|
node_idx (int)
|
||||||
|
"""
|
||||||
|
permute_dim = node.args[1:]
|
||||||
|
input_node = node.args[0]
|
||||||
|
|
||||||
|
self._assign_indice_as_input(node, node_idx, input_node)
|
||||||
|
for idx, d in enumerate(permute_dim):
|
||||||
|
self._inherit_indice(input_node, d, node, idx)
|
||||||
|
|
||||||
|
def _assign_linear_indice(self, node, node_idx):
|
||||||
|
"""
|
||||||
|
Assign indice for linear op.
|
||||||
|
1. copy trace from input node and change the last indice according to the weight
|
||||||
|
2. mark equal for input node last indice, weight first dim and bias dim.
|
||||||
|
3. inherit input's computation, mark computation for last dim.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node (node)
|
||||||
|
node_idx (int)
|
||||||
|
"""
|
||||||
|
if len(node.args) == 2:
|
||||||
|
_, weight = node.args
|
||||||
|
else:
|
||||||
|
_, weight, _ = node.args
|
||||||
|
|
||||||
|
self._assign_indice_as_input(node, node_idx)
|
||||||
|
self._inherit_indice(weight, 1, node, -1)
|
||||||
|
|
||||||
|
self._mark_computation(node, node_idx, [-1])
|
||||||
|
|
||||||
|
def _assign_matmul_indice(self, node, node_idx):
|
||||||
|
"""
|
||||||
|
Assign indice for matmul op.
|
||||||
|
1. copy trace from matmul_left and change the last indice according to matmul_right (assert they have the same number of dims).
|
||||||
|
2. mark equal for input matmul_left -1 indice and matmul_right -2 dim.
|
||||||
|
3. inherit matmul_left and matmul_right computation, mark computation for last dim.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node (node)
|
||||||
|
node_idx (int)
|
||||||
|
"""
|
||||||
|
matmul_left, matmul_right = node.args
|
||||||
|
|
||||||
|
assert len(get_node_shape(matmul_left)) == len(get_node_shape(matmul_right))
|
||||||
|
self._assign_indice_as_input(node, node_idx, matmul_left)
|
||||||
|
self._inherit_indice(matmul_right, -1, node, -1)
|
||||||
|
|
||||||
|
self._mark_computation_from_node(matmul_right, node, [-1, -2])
|
||||||
|
self._mark_computation(node, node_idx, [-1])
|
||||||
|
|
||||||
|
def _assign_layernorm_indice(self, node, idx):
|
||||||
|
"""
|
||||||
|
Assign indice for layernorm op.
|
||||||
|
1. assign indice as input node
|
||||||
|
2. inherit computation and mark the last dim as computed.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node (node)
|
||||||
|
node_idx (int)
|
||||||
|
"""
|
||||||
|
self._assign_indice_as_input(node, idx)
|
||||||
|
self._mark_computation(node, idx, [-1])
|
||||||
|
|
||||||
|
def _assign_elementwise_indice(self, node, idx):
|
||||||
|
"""
|
||||||
|
Assign indice for element-wise ops (e.g. relu, sigmoid, add, mul).
|
||||||
|
1. assign indice as input node
|
||||||
|
2. inherit computation from all input nodes.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node (node)
|
||||||
|
node_idx (int)
|
||||||
|
"""
|
||||||
|
self._assign_indice_as_input(node, idx)
|
||||||
|
nodes_in = []
|
||||||
|
for node_in in node.args:
|
||||||
|
if type(node_in) == type(node):
|
||||||
|
nodes_in.append(node_in)
|
||||||
|
self._mark_computation_from_node(node_in, node)
|
||||||
|
assert len(nodes_in) <= 2
|
||||||
|
|
||||||
|
def _assgin_no_change_indice(self, node, idx):
|
||||||
|
self._assign_indice_as_input(node, idx)
|
||||||
|
for node_in in node.args:
|
||||||
|
if type(node_in) == type(node):
|
||||||
|
self._mark_computation_from_node(node_in, node)
|
||||||
|
|
||||||
|
def _assign_einsum_indice(self, node, idx):
|
||||||
|
"""
|
||||||
|
Assign indice for einsum op.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node (node)
|
||||||
|
node_idx (int)
|
||||||
|
"""
|
||||||
|
patterns = node.args[0]
|
||||||
|
input_nodes = node.args[1:]
|
||||||
|
|
||||||
|
patterns = patterns.replace(" ", "")
|
||||||
|
left, right = patterns.split("->")
|
||||||
|
left = left.split(",")
|
||||||
|
|
||||||
|
all_index = []
|
||||||
|
for i in left:
|
||||||
|
for c in i:
|
||||||
|
all_index.append(c)
|
||||||
|
all_index = set(all_index)
|
||||||
|
|
||||||
|
for right_idx, right_indice in enumerate(right):
|
||||||
|
for left_idx, left_str in enumerate(left):
|
||||||
|
if right_indice in left_str:
|
||||||
|
source_idx = left_str.index(right_indice)
|
||||||
|
self._inherit_indice(
|
||||||
|
input_nodes[left_idx], source_idx, node, right_idx
|
||||||
|
)
|
||||||
|
|
||||||
|
def _assign_softmax_indice(self, node, idx):
|
||||||
|
"""
|
||||||
|
Assign indice for softmax op.
|
||||||
|
1. assign indice as input node
|
||||||
|
2. inherit computation and mark softmax dim as computed.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node (node)
|
||||||
|
node_idx (int)
|
||||||
|
"""
|
||||||
|
self._assign_indice_as_input(node, idx)
|
||||||
|
self._mark_computation(node, idx, [node.kwargs["dim"]])
|
||||||
|
|
||||||
|
def _assign_unsqueeze_indice(self, node, node_idx):
|
||||||
|
"""
|
||||||
|
Assign indice for unsqueeze op.
|
||||||
|
1. assign new indice for unsqueeze dim
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node (node)
|
||||||
|
node_idx (int)
|
||||||
|
"""
|
||||||
|
self._del_dim(node_idx, -1)
|
||||||
|
self._assign_indice_as_input(node, node_idx)
|
||||||
|
self._add_dim(node_idx, node.args[1])
|
||||||
|
|
||||||
|
def _assign_dropout_indice(self, node, node_idx):
|
||||||
|
"""
|
||||||
|
Assign indice for dropout op.
|
||||||
|
1. assign indice as the input node (dropout does not change indice)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node (node)
|
||||||
|
node_idx (int)
|
||||||
|
"""
|
||||||
|
self._assign_indice_as_input(node, node_idx)
|
||||||
|
|
||||||
|
def _assign_ones_like_indice(self, node, node_idx):
|
||||||
|
"""
|
||||||
|
Assign indice for oneslike op.
|
||||||
|
1. assign new indice for all dim
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node (node)
|
||||||
|
node_idx (int)
|
||||||
|
"""
|
||||||
|
self._assign_all_indice(node, node_idx)
|
||||||
|
|
||||||
|
def _assign_view_reshape_indice(self, node, node_idx):
|
||||||
|
"""
|
||||||
|
Assign indice for view and reshape op.
|
||||||
|
1. get origin shape and target shape by meta info.
|
||||||
|
2. compute the real value of -1 in target shape.
|
||||||
|
3. determine the changed dim, and assign indice for the generated dim.
4. log the changed dim and generated dim so they can be restored later.
5. inherit computation.
6. TODO: look into the view list to see whether this view is associated with another view;
   if so, assign equal dims according to the previous view.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node (node)
|
||||||
|
node_idx (int)
|
||||||
|
"""
|
||||||
|
# read the target shape args and turn them into numbers
|
||||||
|
origin_node = node.args[0]
|
||||||
|
origin_shape = origin_node.meta["tensor_meta"].shape
|
||||||
|
target_shape = []
|
||||||
|
for i in range(1, len(node.args)):
|
||||||
|
if isinstance(node.args[i], int):
|
||||||
|
target_shape.append(node.args[i])
|
||||||
|
else:
|
||||||
|
target_shape.append(node.args[i].meta["fwd_out"][0])
|
||||||
|
|
||||||
|
# compute the value of -1
|
||||||
|
if -1 in target_shape:
|
||||||
|
origin_product = 1
|
||||||
|
for i in origin_shape:
|
||||||
|
origin_product *= i
|
||||||
|
target_product = -1
|
||||||
|
for i in target_shape:
|
||||||
|
target_product *= i
|
||||||
|
shape_idx = target_shape.index(-1)
|
||||||
|
target_shape[shape_idx] = origin_product // target_product
|
||||||
|
|
||||||
|
# determine changed dim
|
||||||
|
len_diff = len(origin_shape) - len(target_shape)
|
||||||
|
if len_diff == 1:
|
||||||
|
# dim merge
|
||||||
|
dim_equal = [i == j for i, j in zip(origin_shape[:-1], target_shape)]
|
||||||
|
dim_to = [dim_equal.index(False)]
|
||||||
|
dim_from = [dim_equal.index(False), dim_equal.index(False) + 1]
|
||||||
|
self._add_dim(node_idx, -1)
|
||||||
|
elif len_diff == -1:
|
||||||
|
# dim expand
|
||||||
|
dim_equal = [i == j for i, j in zip(origin_shape, target_shape[:-1])]
|
||||||
|
dim_from = [dim_equal.index(False)]
|
||||||
|
dim_to = [dim_equal.index(False), dim_equal.index(False) + 1]
|
||||||
|
self._del_dim(node_idx, -1)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"shape"
|
||||||
|
+ str(origin_shape)
|
||||||
|
+ "and"
|
||||||
|
+ str(target_shape)
|
||||||
|
+ "view not implemented"
|
||||||
|
)
|
||||||
|
|
||||||
|
# get new indice
|
||||||
|
origin_trace = self._find_indice_trace_from_node(origin_node)
|
||||||
|
self._assign_indice_as_input(node, node_idx, origin_node)
|
||||||
|
dim_from.reverse()
|
||||||
|
for i in dim_from:
|
||||||
|
self._del_dim(node_idx, i)
|
||||||
|
for i in dim_to:
|
||||||
|
self._add_dim(node_idx, i)
|
||||||
|
|
||||||
|
# inherit computation
|
||||||
|
compute_log = self._find_compute_trace_from_node(origin_node)
|
||||||
|
for i in dim_from:
|
||||||
|
if origin_trace[i] in compute_log:
|
||||||
|
for j in dim_to:
|
||||||
|
self._mark_computation(node, node_idx, [j])
|
||||||
|
break
|
||||||
|
|
||||||
|
# log view, not used now
|
||||||
|
view_dict = {
|
||||||
|
"idx_from": [origin_trace[i] for i in dim_from],
|
||||||
|
"dim_from": dim_from,
|
||||||
|
"idx_to": [self.indice_trace_list[node_idx]["indice"][i] for i in dim_to],
|
||||||
|
"dim_to": dim_to,
|
||||||
|
}
|
||||||
|
self.indice_view_list[node] = view_dict
|
||||||
|
|
||||||
|
def trace_indice(self):
|
||||||
|
for idx, node in enumerate(self.node_list):
|
||||||
|
if node.op == "placeholder":
|
||||||
|
self._assign_all_indice(node, idx)
|
||||||
|
elif node.op == "call_method":
|
||||||
|
if "transpose" in node.name:
|
||||||
|
self._assign_transpose_indice(node, idx)
|
||||||
|
elif "permute" in node.name:
|
||||||
|
self._assign_permute_indice(node, idx)
|
||||||
|
elif "view" in node.name or "reshape" in node.name:
|
||||||
|
self._assign_view_reshape_indice(node, idx)
|
||||||
|
elif "unsqueeze" in node.name:
|
||||||
|
self._assign_unsqueeze_indice(node, idx)
|
||||||
|
elif any(i in node.name for i in ["to", "contiguous"]):
|
||||||
|
self._assgin_no_change_indice(node, idx)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(node.name, "method not implemented yet!")
|
||||||
|
elif node.op == "call_function":
|
||||||
|
if "linear" in node.name:
|
||||||
|
self._assign_linear_indice(node, idx)
|
||||||
|
elif "matmul" in node.name:
|
||||||
|
self._assign_matmul_indice(node, idx)
|
||||||
|
elif "softmax" in node.name:
|
||||||
|
self._assign_softmax_indice(node, idx)
|
||||||
|
elif any(n in node.name for n in ["mul", "add", "sigmoid", "relu"]):
|
||||||
|
self._assign_elementwise_indice(node, idx)
|
||||||
|
elif "ones_like" in node.name:
|
||||||
|
self._assign_ones_like_indice(node, idx)
|
||||||
|
elif "dropout" in node.name:
|
||||||
|
self._assign_dropout_indice(node, idx)
|
||||||
|
elif "einsum" in node.name:
|
||||||
|
self._assign_einsum_indice(node, idx)
|
||||||
|
elif "getattr" in node.name:
|
||||||
|
continue # get attr like shape
|
||||||
|
elif "getitem" in node.name:
|
||||||
|
continue # get item in list
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(
|
||||||
|
node.name, "function not implemented yet!"
|
||||||
|
)
|
||||||
|
elif node.op == "call_module":
|
||||||
|
if any(n in node.name for n in ["layernorm", "norm"]):
|
||||||
|
self._assign_layernorm_indice(node, idx)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(node.name, "module not implemented yet!")
|
||||||
|
elif node.op == "get_attr":
|
||||||
|
self._assign_all_indice(node, idx) # get param
|
||||||
|
elif node.op == "output":
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(node.op, "op not implemented yet!")
|
|
@ -0,0 +1,95 @@
|
||||||
|
from typing import Any, Callable, Dict, Iterable, List, Tuple
|
||||||
|
|
||||||
|
from torch.fx.node import Node
|
||||||
|
|
||||||
|
|
||||||
|
def is_non_compute_node(node):
|
||||||
|
if any(i in node.op for i in ["placeholder", "get_attr", "output"]) or any(
|
||||||
|
i in node.name for i in ["getitem", "getattr"]
|
||||||
|
):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def get_node_shape(node):
|
||||||
|
if hasattr(node.meta["tensor_meta"], "shape"):
|
||||||
|
return node.meta["tensor_meta"].shape
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def is_non_compute_node_except_placeholder(node):
|
||||||
|
if any(i in node.op for i in ["get_attr", "output"]) or any(
|
||||||
|
i in node.name for i in ["getitem", "getattr"]
|
||||||
|
):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def is_non_compute_node_except_placeholder_output(node):
|
||||||
|
if any(i in node.op for i in ["get_attr"]) or any(
|
||||||
|
i in node.name for i in ["getitem", "getattr"]
|
||||||
|
):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def find_idx_by_name(name, nodes_list):
|
||||||
|
for idx, node in enumerate(nodes_list):
|
||||||
|
if node.name == name:
|
||||||
|
return idx
|
||||||
|
raise RuntimeError("name %s not found in node list" % name)
|
||||||
|
|
||||||
|
|
||||||
|
def delete_free_var_from_last_use(user_to_last_uses):
|
||||||
|
for key, value in user_to_last_uses.items():
|
||||||
|
for n in value:
|
||||||
|
if n.op == "placeholder":
|
||||||
|
user_to_last_uses[key].remove(n)
|
||||||
|
|
||||||
|
|
||||||
|
def find_chunk_all_input_nodes(nodes: List[Node]):
|
||||||
|
"""
|
||||||
|
Find all input nodes of the given node list.
Input nodes are nodes outside the list that are used by nodes in the list.
|
||||||
|
"""
|
||||||
|
input_nodes = []
|
||||||
|
for node in nodes:
|
||||||
|
for input_node in node._input_nodes.keys():
|
||||||
|
if input_node not in nodes and input_node not in input_nodes:
|
||||||
|
input_nodes.append(input_node)
|
||||||
|
return input_nodes
|
||||||
|
|
||||||
|
|
||||||
|
def find_chunk_compute_input_and_output_nodes(nodes: List[Node]):
|
||||||
|
"""
|
||||||
|
Find compute input and output nodes of the given node list.
Input nodes are nodes outside the list that are used by nodes in the list;
output nodes are nodes in the list whose users fall outside the list.
|
||||||
|
"""
|
||||||
|
input_nodes = []
|
||||||
|
output_nodes = []
|
||||||
|
|
||||||
|
# if a node has an input node which is not in the node list
|
||||||
|
# we treat that input node as an input of the chunk region
|
||||||
|
for node in nodes:
|
||||||
|
for input_node in node._input_nodes.keys():
|
||||||
|
if (
|
||||||
|
input_node not in nodes
|
||||||
|
and input_node not in input_nodes
|
||||||
|
and not is_non_compute_node_except_placeholder(input_node)
|
||||||
|
):
|
||||||
|
input_nodes.append(input_node)
|
||||||
|
|
||||||
|
# if a node has a user node which is not in the node list
|
||||||
|
# we treat that user node as the node receiving the current node output
|
||||||
|
for node in nodes:
|
||||||
|
for output_node in node.users.keys():
|
||||||
|
if (
|
||||||
|
output_node not in nodes
|
||||||
|
and node not in output_nodes
|
||||||
|
and not is_non_compute_node_except_placeholder_output(output_node)
|
||||||
|
):
|
||||||
|
output_nodes.append(node)
|
||||||
|
|
||||||
|
return input_nodes, output_nodes
|
|
@ -0,0 +1,122 @@
|
||||||
|
import time
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.fx
|
||||||
|
|
||||||
|
from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
|
||||||
|
from colossalai.fx import ColoTracer
|
||||||
|
from colossalai.fx.graph_module import ColoGraphModule
|
||||||
|
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
|
||||||
|
from colossalai.fx.profiler import MetaTensor
|
||||||
|
from tests.test_autochunk.evoformer.evoformer import evoformer_base
|
||||||
|
from tests.test_autochunk.openfold.evoformer import EvoformerBlock
|
||||||
|
|
||||||
|
|
||||||
|
def _benchmark_evoformer(model: torch.nn.Module, node, pair, title, chunk_size=None):
|
||||||
|
torch.cuda.reset_peak_memory_stats()
|
||||||
|
now_mem = torch.cuda.memory_allocated() / 1024**2
|
||||||
|
|
||||||
|
loop = 3
|
||||||
|
with torch.no_grad():
|
||||||
|
for _ in range(loop // 2 + 1):
|
||||||
|
if chunk_size:
|
||||||
|
model(node, pair, chunk_size)
|
||||||
|
else:
|
||||||
|
model(node, pair)
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
time1 = time.time()
|
||||||
|
for _ in range(loop):
|
||||||
|
if chunk_size:
|
||||||
|
model(node, pair, chunk_size)
|
||||||
|
else:
|
||||||
|
model(node, pair)
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
time2 = time.time()
|
||||||
|
|
||||||
|
new_max_mem = torch.cuda.max_memory_allocated() / 1024**2
|
||||||
|
print(
|
||||||
|
"%s: time %.4fs, mem %dMB"
|
||||||
|
% (title, (time2 - time1) / loop, new_max_mem - now_mem)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _build_autochunk(model, max_memory, node, pair):
|
||||||
|
# trace the module and replace codegen
|
||||||
|
graph = ColoTracer().trace(
|
||||||
|
model,
|
||||||
|
meta_args={
|
||||||
|
"node": node.to(torch.device("meta")),
|
||||||
|
"pair": pair.to(torch.device("meta")),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
gm_prop = torch.fx.symbolic_trace(model) # must use symbolic_trace
|
||||||
|
interp = MetaInfoProp(gm_prop)
|
||||||
|
interp.propagate(
|
||||||
|
MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0")
|
||||||
|
)
|
||||||
|
|
||||||
|
# propagate meta info again on the graph module (not strictly necessary)
|
||||||
|
gm = torch.fx.GraphModule(model, graph)
|
||||||
|
interp = MetaInfoProp(gm)
|
||||||
|
interp.propagate(
|
||||||
|
MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0")
|
||||||
|
)
|
||||||
|
|
||||||
|
# set code_gen
|
||||||
|
codegen = AutoChunkCodeGen(gm_prop, max_memory, print_mem=False)
|
||||||
|
graph.set_codegen(codegen)
|
||||||
|
gm = ColoGraphModule(model, graph)
|
||||||
|
gm.recompile()
|
||||||
|
|
||||||
|
# print
|
||||||
|
# code = graph.python_code("self").src
|
||||||
|
# print(code)
|
||||||
|
return gm
|
||||||
|
|
||||||
|
|
||||||
|
def _build_openfold():
|
||||||
|
model = EvoformerBlock(
|
||||||
|
c_m=256,
|
||||||
|
c_z=128,
|
||||||
|
c_hidden_msa_att=32,
|
||||||
|
c_hidden_opm=32,
|
||||||
|
c_hidden_mul=128,
|
||||||
|
c_hidden_pair_att=32,
|
||||||
|
no_heads_msa=8,
|
||||||
|
no_heads_pair=4,
|
||||||
|
transition_n=4,
|
||||||
|
msa_dropout=0.15,
|
||||||
|
pair_dropout=0.15,
|
||||||
|
inf=1e4,
|
||||||
|
eps=1e-4,
|
||||||
|
is_multimer=False,
|
||||||
|
).cuda()
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def benchmark_evoformer():
|
||||||
|
# init data and model
|
||||||
|
msa_len = 256
|
||||||
|
pair_len = 512
|
||||||
|
node = torch.randn(1, msa_len, pair_len, 256).cuda()
|
||||||
|
pair = torch.randn(1, pair_len, pair_len, 128).cuda()
|
||||||
|
model = evoformer_base().cuda()
|
||||||
|
|
||||||
|
# build autochunk model
|
||||||
|
# max_memory = 1000 # MB, fit memory mode
|
||||||
|
max_memory = None # min memory mode
|
||||||
|
autochunk = _build_autochunk(evoformer_base().cuda(), max_memory, node, pair)
|
||||||
|
|
||||||
|
# build openfold
|
||||||
|
chunk_size = 64
|
||||||
|
openfold = _build_openfold()
|
||||||
|
|
||||||
|
# benchmark
|
||||||
|
_benchmark_evoformer(model, node, pair, "base")
|
||||||
|
_benchmark_evoformer(openfold, node, pair, "openfold", chunk_size=chunk_size)
|
||||||
|
_benchmark_evoformer(autochunk, node, pair, "autochunk")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
benchmark_evoformer()
|
|
@ -0,0 +1,59 @@
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from .msa import MSAStack
|
||||||
|
from .ops import OutProductMean
|
||||||
|
from .triangle import PairStack
|
||||||
|
|
||||||
|
|
||||||
|
def print_memory(init_mem, text=None):
|
||||||
|
now_mem = torch.cuda.memory_allocated() / 1024 ** 2 - init_mem
|
||||||
|
max_mem = torch.cuda.max_memory_allocated() / 1024 ** 2 - init_mem
|
||||||
|
print("%s now:%.2f max:%.2f" % ("" if text is None else text, now_mem, max_mem))
|
||||||
|
torch.cuda.reset_peak_memory_stats()
|
||||||
|
|
||||||
|
|
||||||
|
class EvoformerBlock(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, d_node, d_pair):
|
||||||
|
super(EvoformerBlock, self).__init__()
|
||||||
|
|
||||||
|
self.msa_stack = MSAStack(d_node, d_pair, p_drop=0.15)
|
||||||
|
self.communication = OutProductMean(n_feat=d_node, n_feat_out=d_pair, n_feat_proj=32)
|
||||||
|
self.pair_stack = PairStack(d_pair=d_pair)
|
||||||
|
|
||||||
|
def forward(self, node, pair):
|
||||||
|
node = self.msa_stack(node, pair)
|
||||||
|
pair = pair + self.communication(node)
|
||||||
|
pair = self.pair_stack(pair)
|
||||||
|
return node, pair
|
||||||
|
|
||||||
|
|
||||||
|
class Evoformer(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, d_node, d_pair):
|
||||||
|
super(Evoformer, self).__init__()
|
||||||
|
|
||||||
|
self.blocks = nn.ModuleList()
|
||||||
|
for _ in range(1):
|
||||||
|
self.blocks.append(EvoformerBlock(d_node, d_pair))
|
||||||
|
|
||||||
|
def forward(self, node, pair):
|
||||||
|
for b in self.blocks:
|
||||||
|
node, pair = b(node, pair)
|
||||||
|
return node, pair
|
||||||
|
|
||||||
|
|
||||||
|
def evoformer_tiny():
|
||||||
|
return Evoformer(d_node=64, d_pair=32)
|
||||||
|
|
||||||
|
|
||||||
|
def evoformer_base():
|
||||||
|
return Evoformer(d_node=256, d_pair=128)
|
||||||
|
|
||||||
|
|
||||||
|
def evoformer_large():
|
||||||
|
return Evoformer(d_node=512, d_pair=256)
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ['Evoformer', 'evoformer_base', 'evoformer_large']
|
|
@ -0,0 +1,29 @@
|
||||||
|
import math
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
|
||||||
|
def glorot_uniform_af(x, gain=1.0):
|
||||||
|
"""
|
||||||
|
Initialize tensors the same way as PyTorch's xavier_initializer, but with a different dimension convention:
|
||||||
|
In PyTorch:
|
||||||
|
[feature_out, feature_in, n_head ...]
|
||||||
|
In Jax:
|
||||||
|
[... n_head, feature_in, feature_out]
|
||||||
|
However, the original AlphaFold2 code uses the Jax-style initializer for tensors shaped like:
|
||||||
|
[feature_in, n_head, feature_out]
|
||||||
|
|
||||||
|
In this function, we keep this feature to initialize [feature_in, n_head, ..., feature_out] tensors
|
||||||
|
"""
|
||||||
|
fan_in, fan_out = x.shape[-2:]
|
||||||
|
if len(x.shape) > 2:
|
||||||
|
receptive_field_size = np.prod(x.shape[:-2])
|
||||||
|
fan_in *= receptive_field_size
|
||||||
|
fan_out *= receptive_field_size
|
||||||
|
std = gain * math.sqrt(2.0 / float(fan_in + fan_out))
|
||||||
|
dev = math.sqrt(3.0) * std # Calculate uniform bounds from standard deviation
|
||||||
|
|
||||||
|
nn.init.uniform_(x, -dev, dev)
|
||||||
|
|
||||||
|
return x
|
|
@ -0,0 +1,19 @@
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
|
def bias_sigmod_ele(y, bias, z):
|
||||||
|
return torch.sigmoid(y + bias) * z
|
||||||
|
|
||||||
|
|
||||||
|
def bias_dropout_add(x: torch.Tensor, bias: torch.Tensor, dropmask: torch.Tensor,
|
||||||
|
residual: torch.Tensor, prob: float) -> torch.Tensor:
|
||||||
|
out = (x + bias) * F.dropout(dropmask, p=prob, training=False)
|
||||||
|
out = residual + out
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def bias_ele_dropout_residual(ab: torch.Tensor, b: torch.Tensor, g: torch.Tensor,
|
||||||
|
dropout_mask: torch.Tensor, Z_raw: torch.Tensor,
|
||||||
|
prob: float) -> torch.Tensor:
|
||||||
|
return Z_raw + F.dropout(dropout_mask, p=prob, training=True) * (g * (ab + b))
|
|
@ -0,0 +1,95 @@
|
||||||
|
import math
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from einops import rearrange
|
||||||
|
from torch.nn import LayerNorm
|
||||||
|
|
||||||
|
from .kernel import bias_dropout_add
|
||||||
|
from .ops import SelfAttention, Transition
|
||||||
|
|
||||||
|
|
||||||
|
class MSARowAttentionWithPairBias(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, d_node, d_pair, c=32, n_head=8, p_drop=0.15):
|
||||||
|
super(MSARowAttentionWithPairBias, self).__init__()
|
||||||
|
self.d_node = d_node
|
||||||
|
self.d_pair = d_pair
|
||||||
|
self.c = c
|
||||||
|
self.n_head = n_head
|
||||||
|
self.p_drop = p_drop
|
||||||
|
|
||||||
|
self.layernormM = LayerNorm(d_node)
|
||||||
|
self.layernormZ = LayerNorm(d_pair)
|
||||||
|
|
||||||
|
_init_weights = torch.nn.init.normal_(torch.zeros([n_head, d_pair]),
|
||||||
|
std=1.0 / math.sqrt(d_pair))
|
||||||
|
self.linear_b_weights = nn.parameter.Parameter(data=_init_weights, requires_grad=True)
|
||||||
|
|
||||||
|
self.attention = SelfAttention(qkv_dim=d_node,
|
||||||
|
c=c,
|
||||||
|
n_head=n_head,
|
||||||
|
out_dim=d_node,
|
||||||
|
gating=True,
|
||||||
|
last_bias_fuse=True)
|
||||||
|
|
||||||
|
self.out_bias = nn.parameter.Parameter(data=torch.zeros((d_node,)), requires_grad=True)
|
||||||
|
|
||||||
|
def forward(self, M_raw, Z):
|
||||||
|
## Input projections
|
||||||
|
M = self.layernormM(M_raw)
|
||||||
|
Z = self.layernormZ(Z)
|
||||||
|
b = F.linear(Z, self.linear_b_weights)
|
||||||
|
b = b.permute(0, 3, 1, 2)
|
||||||
|
# b = rearrange(b, 'b q k h -> b h q k')
|
||||||
|
|
||||||
|
M = self.attention(M, b)
|
||||||
|
dropout_mask = torch.ones_like(M[:, 0:1, :, :]).to(M.device).to(M.dtype)
|
||||||
|
|
||||||
|
return bias_dropout_add(M, self.out_bias, dropout_mask, M_raw, prob=self.p_drop)
|
||||||
|
|
||||||
|
|
||||||
|
class MSAColumnAttention(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, d_node, c=32, n_head=8):
|
||||||
|
super(MSAColumnAttention, self).__init__()
|
||||||
|
self.d_node = d_node
|
||||||
|
self.c = c
|
||||||
|
self.n_head = n_head
|
||||||
|
|
||||||
|
self.layernormM = LayerNorm(d_node)
|
||||||
|
self.attention = SelfAttention(qkv_dim=d_node,
|
||||||
|
c=c,
|
||||||
|
n_head=n_head,
|
||||||
|
out_dim=d_node,
|
||||||
|
gating=True)
|
||||||
|
|
||||||
|
def forward(self, M_raw):
|
||||||
|
M = M_raw.transpose(-2, -3)
|
||||||
|
M = self.layernormM(M)
|
||||||
|
|
||||||
|
M = self.attention(M)
|
||||||
|
|
||||||
|
M = M.transpose(-2, -3)
|
||||||
|
return M_raw + M
|
||||||
|
|
||||||
|
|
||||||
|
class MSAStack(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, d_node, d_pair, p_drop=0.15):
|
||||||
|
super(MSAStack, self).__init__()
|
||||||
|
|
||||||
|
self.MSARowAttentionWithPairBias = MSARowAttentionWithPairBias(d_node=d_node,
|
||||||
|
d_pair=d_pair,
|
||||||
|
p_drop=p_drop)
|
||||||
|
|
||||||
|
self.MSAColumnAttention = MSAColumnAttention(d_node=d_node)
|
||||||
|
self.MSATransition = Transition(d=d_node)
|
||||||
|
|
||||||
|
def forward(self, node, pair):
|
||||||
|
node = self.MSARowAttentionWithPairBias(node, pair)
|
||||||
|
node = self.MSAColumnAttention(node)
|
||||||
|
node = self.MSATransition(node)
|
||||||
|
|
||||||
|
return node
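
A shape-level usage sketch (assumed shapes, not committed code): the MSA stack consumes a node tensor [batch, n_seq, n_res, d_node] and a pair tensor [batch, n_res, n_res, d_pair] and returns an updated node tensor of the same shape. Because this module uses relative imports, the snippet is meant purely as an illustration of the expected interface.

if __name__ == "__main__":
    stack = MSAStack(d_node=64, d_pair=32)
    node = torch.randn(1, 8, 16, 64)   # [batch, n_seq, n_res, d_node]
    pair = torch.randn(1, 16, 16, 32)  # [batch, n_res, n_res, d_pair]
    out = stack(node, pair)
    assert out.shape == node.shape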
|
|
@ -0,0 +1,176 @@
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from einops import rearrange
|
||||||
|
from torch.nn import LayerNorm
|
||||||
|
|
||||||
|
from .initializer import glorot_uniform_af
|
||||||
|
from .kernel import bias_sigmod_ele
|
||||||
|
|
||||||
|
|
||||||
|
class DropoutRowwise(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, p):
|
||||||
|
super(DropoutRowwise, self).__init__()
|
||||||
|
self.p = p
|
||||||
|
self.dropout = nn.Dropout(p=p)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
dropout_mask = torch.ones_like(x[:, 0:1, :, :])
|
||||||
|
dropout_mask = self.dropout(dropout_mask)
|
||||||
|
return dropout_mask * x
|
||||||
|
|
||||||
|
|
||||||
|
class DropoutColumnwise(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, p):
|
||||||
|
super(DropoutColumnwise, self).__init__()
|
||||||
|
self.p = p
|
||||||
|
self.dropout = nn.Dropout(p=p)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
dropout_mask = torch.ones_like(x[:, :, 0:1, :])
|
||||||
|
dropout_mask = self.dropout(dropout_mask)
|
||||||
|
return dropout_mask * x
|
||||||
|
|
||||||
|
|
||||||
|
class Transition(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, d, n=4):
|
||||||
|
super(Transition, self).__init__()
|
||||||
|
self.norm = LayerNorm(d)
|
||||||
|
self.linear1 = Linear(d, n * d, initializer='relu')
|
||||||
|
self.linear2 = Linear(n * d, d, initializer='zeros')
|
||||||
|
|
||||||
|
def forward(self, src):
|
||||||
|
x = self.norm(src)
|
||||||
|
x = self.linear2(F.relu(self.linear1(x)))
|
||||||
|
return src + x
|
||||||
|
|
||||||
|
|
||||||
|
class OutProductMean(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, n_feat=64, n_feat_out=128, n_feat_proj=32):
|
||||||
|
super(OutProductMean, self).__init__()
|
||||||
|
|
||||||
|
self.layernormM = LayerNorm(n_feat)
|
||||||
|
self.linear_a = Linear(n_feat, n_feat_proj)
|
||||||
|
self.linear_b = Linear(n_feat, n_feat_proj)
|
||||||
|
|
||||||
|
self.o_linear = Linear(n_feat_proj * n_feat_proj,
|
||||||
|
n_feat_out,
|
||||||
|
initializer='zeros',    # 'zeros' is the key recognised by the Linear class in this file; 'zero' would silently fall through to the default init
|
||||||
|
use_bias=True)
|
||||||
|
|
||||||
|
def forward(self, M):
|
||||||
|
M = self.layernormM(M)
|
||||||
|
left_act = self.linear_a(M)
|
||||||
|
right_act = self.linear_b(M)
|
||||||
|
|
||||||
|
o = torch.einsum('bsid,bsje->bijde', left_act, right_act).contiguous()
|
||||||
|
# O = rearrange(O, 'b i j d e -> b i j (d e)')
|
||||||
|
o = o.reshape(o.shape[0], o.shape[1], o.shape[2], -1)
|
||||||
|
Z = self.o_linear(o)
|
||||||
|
|
||||||
|
return Z
|
||||||
|
|
||||||
|
|
||||||
|
class Linear(nn.Linear):
|
||||||
|
"""
|
||||||
|
A Linear layer with built-in nonstandard initializations. Called just
|
||||||
|
like torch.nn.Linear.
|
||||||
|
Implements the initializers in 1.11.4, plus some additional ones found
|
||||||
|
in the code.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
feature_in: int,
|
||||||
|
feature_out: int,
|
||||||
|
initializer: str = 'linear',
|
||||||
|
use_bias: bool = True,
|
||||||
|
bias_init: float = 0.,
|
||||||
|
):
|
||||||
|
super(Linear, self).__init__(feature_in, feature_out, bias=use_bias)
|
||||||
|
|
||||||
|
self.use_bias = use_bias
|
||||||
|
if initializer == 'linear':
|
||||||
|
glorot_uniform_af(self.weight, gain=1.0)
|
||||||
|
elif initializer == 'relu':
|
||||||
|
glorot_uniform_af(self.weight, gain=2.0)
|
||||||
|
elif initializer == 'zeros':
|
||||||
|
nn.init.zeros_(self.weight)
|
||||||
|
if self.use_bias:
|
||||||
|
with torch.no_grad():
|
||||||
|
self.bias.fill_(bias_init)
|
||||||
|
|
||||||
|
|
||||||
|
class SelfAttention(nn.Module):
|
||||||
|
"""
|
||||||
|
Multi-Head SelfAttention dealing with [batch_size1, batch_size2, len, dim] tensors
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, qkv_dim, c, n_head, out_dim, gating=True, last_bias_fuse=False):
|
||||||
|
super(SelfAttention, self).__init__()
|
||||||
|
self.qkv_dim = qkv_dim
|
||||||
|
self.c = c
|
||||||
|
self.n_head = n_head
|
||||||
|
self.out_dim = out_dim
|
||||||
|
self.gating = gating
|
||||||
|
self.last_bias_fuse = last_bias_fuse
|
||||||
|
|
||||||
|
self.scaling = self.c**(-0.5)
|
||||||
|
|
||||||
|
# self.to_qkv = Linear(qkv_dim, 3 * n_head * c, initializer='linear')
|
||||||
|
self.to_q = Linear(qkv_dim, n_head * c, initializer='linear', use_bias=False)
|
||||||
|
self.to_k = Linear(qkv_dim, n_head * c, initializer='linear', use_bias=False)
|
||||||
|
self.to_v = Linear(qkv_dim, n_head * c, initializer='linear', use_bias=False)
|
||||||
|
|
||||||
|
if gating:
|
||||||
|
self.gating_bias = nn.parameter.Parameter(data=torch.ones((n_head * c,)))
|
||||||
|
self.gating_linear = Linear(qkv_dim, n_head * c, initializer='zeros', use_bias=False)
|
||||||
|
|
||||||
|
self.o_linear = Linear(n_head * c,
|
||||||
|
out_dim,
|
||||||
|
initializer='zeros',
|
||||||
|
use_bias=(not last_bias_fuse))
|
||||||
|
|
||||||
|
def forward(self, in_data, nonbatched_bias=None):
|
||||||
|
"""
|
||||||
|
:param in_data: [batch_size1, batch_size2, len_qkv, qkv_dim]
|
||||||
|
:param bias: None or [batch_size1, batch_size2, n_head, len_q, len_kv]
|
||||||
|
:param nonbatched_bias: None or [batch_size1, n_head, len_q, len_kv]
|
||||||
|
"""
|
||||||
|
|
||||||
|
# qkv = self.to_qkv(in_data).chunk(3, dim=-1)
|
||||||
|
# q, k, v = map(lambda t: rearrange(t, 'b1 b2 n (h d) -> b1 b2 h n d', h=self.n_head), qkv)
|
||||||
|
|
||||||
|
q = self.to_q(in_data)
|
||||||
|
k = self.to_k(in_data)
|
||||||
|
v = self.to_v(in_data)
|
||||||
|
|
||||||
|
# q, k, v = map(lambda t: rearrange(t, 'b1 b2 n (h d) -> b1 b2 h n d', h=self.n_head),
|
||||||
|
# [q, k, v])
|
||||||
|
q, k, v = map(lambda t: t.view(t.shape[0], t.shape[1], t.shape[2], self.n_head, -1).permute(0, 1, 3, 2, 4),
|
||||||
|
[q, k, v])
|
||||||
|
|
||||||
|
q = q * self.scaling
|
||||||
|
|
||||||
|
logits = torch.matmul(q, k.transpose(-1, -2))
|
||||||
|
|
||||||
|
if nonbatched_bias is not None:
|
||||||
|
logits += nonbatched_bias.unsqueeze(1)
|
||||||
|
weights = torch.softmax(logits, dim=-1)
|
||||||
|
# weights = softmax(logits)
|
||||||
|
|
||||||
|
weighted_avg = torch.matmul(weights, v)
|
||||||
|
# weighted_avg = rearrange(weighted_avg, 'b1 b2 h n d -> b1 b2 n (h d)')
|
||||||
|
weighted_avg = weighted_avg.permute(0, 1, 3, 2, 4)
|
||||||
|
weighted_avg = weighted_avg.reshape(weighted_avg.shape[0], weighted_avg.shape[1], weighted_avg.shape[2], -1)
|
||||||
|
|
||||||
|
if self.gating:
|
||||||
|
gate_values = self.gating_linear(in_data)
|
||||||
|
weighted_avg = bias_sigmod_ele(gate_values, self.gating_bias, weighted_avg)
|
||||||
|
|
||||||
|
output = self.o_linear(weighted_avg)
|
||||||
|
return output
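
An illustrative sketch (assumed shapes, not part of the commit): gated multi-head self-attention over [batch1, batch2, len, dim] inputs, optionally biased by a [batch1, n_head, len, len] tensor that is broadcast over the second batch dimension.

if __name__ == "__main__":
    attn = SelfAttention(qkv_dim=64, c=32, n_head=8, out_dim=64, gating=True)
    x = torch.randn(2, 4, 16, 64)     # [batch1, batch2, len, qkv_dim]
    bias = torch.randn(2, 8, 16, 16)  # [batch1, n_head, len_q, len_kv]
    y = attn(x, nonbatched_bias=bias)
    assert y.shape == (2, 4, 16, 64)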
|
|
@ -0,0 +1,192 @@
|
||||||
|
import math
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from torch.nn import LayerNorm
|
||||||
|
|
||||||
|
from .kernel import bias_dropout_add, bias_ele_dropout_residual
|
||||||
|
from .ops import Linear, SelfAttention, Transition
|
||||||
|
|
||||||
|
|
||||||
|
def permute_final_dims(tensor, inds):
|
||||||
|
zero_index = -1 * len(inds)
|
||||||
|
first_inds = list(range(len(tensor.shape[:zero_index])))
|
||||||
|
return tensor.permute(first_inds + [zero_index + i for i in inds])
|
||||||
|
|
||||||
|
|
||||||
|
class TriangleMultiplicationOutgoing(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, d_pair, p_drop, c=128):
|
||||||
|
super(TriangleMultiplicationOutgoing, self).__init__()
|
||||||
|
self.d_pair = d_pair
|
||||||
|
self.c = c
|
||||||
|
|
||||||
|
self.layernorm1 = LayerNorm(d_pair)
|
||||||
|
self.left_projection = Linear(d_pair, c)
|
||||||
|
self.right_projection = Linear(d_pair, c)
|
||||||
|
self.left_gate = Linear(d_pair, c, initializer='zeros', bias_init=1.)
|
||||||
|
self.right_gate = Linear(d_pair, c, initializer='zeros', bias_init=1.)
|
||||||
|
|
||||||
|
self.output_gate = Linear(d_pair, d_pair, initializer='zeros', bias_init=1.)
|
||||||
|
self.layernorm2 = LayerNorm(c)
|
||||||
|
self.output_projection = Linear(c, d_pair, initializer='zeros', use_bias=False)    # ab has c channels after layernorm2, so project c -> d_pair
|
||||||
|
self.output_bias = nn.parameter.Parameter(data=torch.zeros((d_pair,)), requires_grad=True)
|
||||||
|
|
||||||
|
self.p_drop = p_drop
|
||||||
|
|
||||||
|
def forward(self, Z_raw):
|
||||||
|
Z = self.layernorm1(Z_raw)
|
||||||
|
left_proj_act = self.left_projection(Z)
|
||||||
|
right_proj_act = self.right_projection(Z)
|
||||||
|
|
||||||
|
left_proj_act = left_proj_act * torch.sigmoid(self.left_gate(Z))
|
||||||
|
right_proj_act = right_proj_act * torch.sigmoid(self.right_gate(Z))
|
||||||
|
|
||||||
|
g = torch.sigmoid(self.output_gate(Z))
|
||||||
|
# p = torch.matmul(
|
||||||
|
# permute_final_dims(left_proj_act, (2, 0, 1)),
|
||||||
|
# permute_final_dims(right_proj_act, (2, 1, 0)),
|
||||||
|
# )
|
||||||
|
# ab = permute_final_dims(p, (1, 2, 0))
|
||||||
|
|
||||||
|
ab = torch.einsum('bikd,bjkd->bijd', left_proj_act, right_proj_act)
|
||||||
|
ab = self.output_projection(self.layernorm2(ab))
|
||||||
|
dropout_mask = torch.ones_like(Z[:, 0:1, :, :]).to(Z.device).to(Z.dtype)
|
||||||
|
return bias_ele_dropout_residual(ab,
|
||||||
|
self.output_bias,
|
||||||
|
g,
|
||||||
|
dropout_mask,
|
||||||
|
Z_raw,
|
||||||
|
prob=self.p_drop)
|
||||||
|
|
||||||
|
|
||||||
|
class TriangleMultiplicationIncoming(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, d_pair, p_drop, c=128):
|
||||||
|
super(TriangleMultiplicationIncoming, self).__init__()
|
||||||
|
self.d_pair = d_pair
|
||||||
|
self.c = c
|
||||||
|
|
||||||
|
self.layernorm1 = LayerNorm(d_pair)
|
||||||
|
self.left_projection = Linear(d_pair, c)
|
||||||
|
self.right_projection = Linear(d_pair, c)
|
||||||
|
self.left_gate = Linear(d_pair, c, initializer='zeros', bias_init=1.)
|
||||||
|
self.right_gate = Linear(d_pair, c, initializer='zeros', bias_init=1.)
|
||||||
|
|
||||||
|
self.output_gate = Linear(d_pair, d_pair, initializer='zeros', bias_init=1.)
|
||||||
|
self.layernorm2 = LayerNorm(c)
|
||||||
|
self.output_projection = Linear(c, d_pair, initializer='zeros', use_bias=False)    # ab has c channels after layernorm2, so project c -> d_pair
|
||||||
|
self.output_bias = nn.parameter.Parameter(data=torch.zeros((d_pair,)), requires_grad=True)
|
||||||
|
|
||||||
|
self.p_drop = p_drop
|
||||||
|
|
||||||
|
def forward(self, Z_raw):
|
||||||
|
Z = self.layernorm1(Z_raw)
|
||||||
|
left_proj_act = self.left_projection(Z)
|
||||||
|
right_proj_act = self.right_projection(Z)
|
||||||
|
|
||||||
|
left_proj_act = left_proj_act * torch.sigmoid(self.left_gate(Z))
|
||||||
|
right_proj_act = right_proj_act * torch.sigmoid(self.right_gate(Z))
|
||||||
|
|
||||||
|
g = torch.sigmoid(self.output_gate(Z))
|
||||||
|
# p = torch.matmul(
|
||||||
|
# permute_final_dims(left_proj_act, (2, 1, 0)),
|
||||||
|
# permute_final_dims(right_proj_act, (2, 0, 1)),
|
||||||
|
# )
|
||||||
|
# ab = permute_final_dims(p, (1, 2, 0))
|
||||||
|
|
||||||
|
ab = torch.einsum('bkid,bkjd->bijd', left_proj_act, right_proj_act)
|
||||||
|
ab = self.output_projection(self.layernorm2(ab))
|
||||||
|
dropout_mask = torch.ones_like(Z[:, 0:1, :, :]).to(Z.device).to(Z.dtype)
|
||||||
|
return bias_ele_dropout_residual(ab,
|
||||||
|
self.output_bias,
|
||||||
|
g,
|
||||||
|
dropout_mask,
|
||||||
|
Z_raw,
|
||||||
|
prob=self.p_drop)
|
||||||
|
|
||||||
|
|
||||||
|
class TriangleAttentionStartingNode(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, d_pair, p_drop, c=32, n_head=4):
|
||||||
|
super(TriangleAttentionStartingNode, self).__init__()
|
||||||
|
self.d_pair = d_pair
|
||||||
|
self.c = c
|
||||||
|
self.n_head = n_head
|
||||||
|
self.p_drop = p_drop
|
||||||
|
|
||||||
|
self.layernorm1 = LayerNorm(d_pair)
|
||||||
|
_init_weights = torch.nn.init.normal_(torch.zeros([d_pair, n_head]),
|
||||||
|
std=1.0 / math.sqrt(d_pair))
|
||||||
|
self.linear_b_weights = nn.parameter.Parameter(data=_init_weights)
|
||||||
|
self.attention = SelfAttention(qkv_dim=d_pair,
|
||||||
|
c=c,
|
||||||
|
n_head=n_head,
|
||||||
|
out_dim=d_pair,
|
||||||
|
gating=True,
|
||||||
|
last_bias_fuse=True)
|
||||||
|
|
||||||
|
self.out_bias = nn.parameter.Parameter(data=torch.zeros((d_pair,)), requires_grad=True)
|
||||||
|
|
||||||
|
def forward(self, Z_raw):
|
||||||
|
Z = self.layernorm1(Z_raw)
|
||||||
|
b = torch.einsum('bqkc,ch->bhqk', Z, self.linear_b_weights)
|
||||||
|
|
||||||
|
Z = self.attention(Z, b)
|
||||||
|
|
||||||
|
dropout_mask = torch.ones_like(Z[:, 0:1, :, :]).to(Z.device).to(Z.dtype)
|
||||||
|
return bias_dropout_add(Z, self.out_bias, dropout_mask, Z_raw, prob=self.p_drop)
|
||||||
|
|
||||||
|
|
||||||
|
class TriangleAttentionEndingNode(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, d_pair, p_drop, c=32, n_head=4):
|
||||||
|
super(TriangleAttentionEndingNode, self).__init__()
|
||||||
|
self.d_pair = d_pair
|
||||||
|
self.c = c
|
||||||
|
self.n_head = n_head
|
||||||
|
self.p_drop = p_drop
|
||||||
|
|
||||||
|
self.layernorm1 = LayerNorm(d_pair)
|
||||||
|
_init_weights = torch.nn.init.normal_(torch.zeros([d_pair, n_head]),
|
||||||
|
std=1.0 / math.sqrt(d_pair))
|
||||||
|
self.linear_b_weights = nn.parameter.Parameter(data=_init_weights)
|
||||||
|
self.attention = SelfAttention(qkv_dim=d_pair,
|
||||||
|
c=c,
|
||||||
|
n_head=n_head,
|
||||||
|
out_dim=d_pair,
|
||||||
|
gating=True,
|
||||||
|
last_bias_fuse=True)
|
||||||
|
|
||||||
|
self.out_bias = nn.parameter.Parameter(data=torch.zeros((d_pair,)), requires_grad=True)
|
||||||
|
|
||||||
|
def forward(self, Z_raw):
|
||||||
|
Z = Z_raw.transpose(-2, -3)
|
||||||
|
Z = self.layernorm1(Z)
|
||||||
|
b = torch.einsum('bqkc,ch->bhqk', Z, self.linear_b_weights)
|
||||||
|
|
||||||
|
Z = self.attention(Z, b)
|
||||||
|
|
||||||
|
Z = Z.transpose(-2, -3)
|
||||||
|
dropout_mask = torch.ones_like(Z[:, :, 0:1, :]).to(Z.device).to(Z.dtype)
|
||||||
|
return bias_dropout_add(Z, self.out_bias, dropout_mask, Z_raw, prob=self.p_drop)
|
||||||
|
|
||||||
|
|
||||||
|
class PairStack(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, d_pair, p_drop=0.25):
|
||||||
|
super(PairStack, self).__init__()
|
||||||
|
|
||||||
|
self.TriangleMultiplicationOutgoing = TriangleMultiplicationOutgoing(d_pair, p_drop=p_drop)
|
||||||
|
self.TriangleMultiplicationIncoming = TriangleMultiplicationIncoming(d_pair, p_drop=p_drop)
|
||||||
|
self.TriangleAttentionStartingNode = TriangleAttentionStartingNode(d_pair, p_drop=p_drop)
|
||||||
|
self.TriangleAttentionEndingNode = TriangleAttentionEndingNode(d_pair, p_drop=p_drop)
|
||||||
|
self.PairTransition = Transition(d=d_pair)
|
||||||
|
|
||||||
|
def forward(self, pair):
|
||||||
|
pair = self.TriangleMultiplicationOutgoing(pair)
|
||||||
|
pair = self.TriangleMultiplicationIncoming(pair)
|
||||||
|
pair = self.TriangleAttentionStartingNode(pair)
|
||||||
|
pair = self.TriangleAttentionEndingNode(pair)
|
||||||
|
pair = self.PairTransition(pair)
|
||||||
|
return pair
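
A shape-level sketch (illustrative assumptions): the pair stack maps a [batch, n_res, n_res, d_pair] representation through the two triangular multiplicative updates, the two triangular attentions and a transition, preserving the shape. d_pair is chosen to match the default hidden width c=128 of the multiplicative updates.

if __name__ == "__main__":
    stack = PairStack(d_pair=128)
    pair = torch.randn(1, 16, 16, 128)  # [batch, n_res, n_res, d_pair]
    assert stack(pair).shape == pair.shape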
|
|
@ -0,0 +1,84 @@
|
||||||
|
# Copyright 2021 AlQuraishi Laboratory
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.utils.checkpoint
|
||||||
|
from typing import Any, Tuple, List, Callable, Optional
|
||||||
|
|
||||||
|
|
||||||
|
BLOCK_ARG = Any
|
||||||
|
BLOCK_ARGS = List[BLOCK_ARG]
|
||||||
|
|
||||||
|
|
||||||
|
def get_checkpoint_fn():
|
||||||
|
checkpoint = torch.utils.checkpoint.checkpoint
|
||||||
|
|
||||||
|
return checkpoint
|
||||||
|
|
||||||
|
|
||||||
|
@torch.jit.ignore
|
||||||
|
def checkpoint_blocks(
|
||||||
|
blocks: List[Callable],
|
||||||
|
args: BLOCK_ARGS,
|
||||||
|
blocks_per_ckpt: Optional[int],
|
||||||
|
) -> BLOCK_ARGS:
|
||||||
|
"""
|
||||||
|
Chunk a list of blocks and run each chunk with activation
|
||||||
|
checkpointing. We define a "block" as a callable whose only inputs are
|
||||||
|
the outputs of the previous block.
|
||||||
|
|
||||||
|
Implements Subsection 1.11.8
|
||||||
|
|
||||||
|
Args:
|
||||||
|
blocks:
|
||||||
|
List of blocks
|
||||||
|
args:
|
||||||
|
Tuple of arguments for the first block.
|
||||||
|
blocks_per_ckpt:
|
||||||
|
Size of each chunk. A higher value corresponds to fewer
|
||||||
|
checkpoints, and trades memory for speed. If None, no checkpointing
|
||||||
|
is performed.
|
||||||
|
Returns:
|
||||||
|
The output of the final block
|
||||||
|
"""
|
||||||
|
def wrap(a):
|
||||||
|
return (a,) if type(a) is not tuple else a
|
||||||
|
|
||||||
|
def exec(b, a):
|
||||||
|
for block in b:
|
||||||
|
a = wrap(block(*a))
|
||||||
|
return a
|
||||||
|
|
||||||
|
def chunker(s, e):
|
||||||
|
def exec_sliced(*a):
|
||||||
|
return exec(blocks[s:e], a)
|
||||||
|
|
||||||
|
return exec_sliced
|
||||||
|
|
||||||
|
# Avoids mishaps when the blocks take just one argument
|
||||||
|
args = wrap(args)
|
||||||
|
|
||||||
|
if blocks_per_ckpt is None:
|
||||||
|
return exec(blocks, args)
|
||||||
|
elif blocks_per_ckpt < 1 or blocks_per_ckpt > len(blocks):
|
||||||
|
raise ValueError("blocks_per_ckpt must be between 1 and len(blocks)")
|
||||||
|
|
||||||
|
checkpoint = get_checkpoint_fn()
|
||||||
|
|
||||||
|
for s in range(0, len(blocks), blocks_per_ckpt):
|
||||||
|
e = s + blocks_per_ckpt
|
||||||
|
args = checkpoint(chunker(s, e), *args)
|
||||||
|
args = wrap(args)
|
||||||
|
|
||||||
|
return args
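
A hedged usage sketch (not part of the file): three toy blocks that each take and return a single tensor, run once with one block per checkpoint segment and once without checkpointing; both paths should agree.

if __name__ == "__main__":
    blocks = [lambda t: t * 2, lambda t: t + 1, lambda t: t.relu()]
    x = torch.randn(4, 4, requires_grad=True)
    (ckpt_out,) = checkpoint_blocks(blocks, args=(x,), blocks_per_ckpt=1)
    (plain_out,) = checkpoint_blocks(blocks, args=(x,), blocks_per_ckpt=None)
    assert torch.allclose(ckpt_out, plain_out)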
|
|
@ -0,0 +1,78 @@
|
||||||
|
# Copyright 2021 AlQuraishi Laboratory
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from functools import partialmethod
|
||||||
|
from typing import Union, List
|
||||||
|
|
||||||
|
|
||||||
|
class Dropout(nn.Module):
|
||||||
|
"""
|
||||||
|
Implementation of dropout with the ability to share the dropout mask
|
||||||
|
along a particular dimension.
|
||||||
|
|
||||||
|
If not in training mode, this module computes the identity function.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, r: float, batch_dim: Union[int, List[int]]):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
r:
|
||||||
|
Dropout rate
|
||||||
|
batch_dim:
|
||||||
|
Dimension(s) along which the dropout mask is shared
|
||||||
|
"""
|
||||||
|
super(Dropout, self).__init__()
|
||||||
|
|
||||||
|
self.r = r
|
||||||
|
if type(batch_dim) == int:
|
||||||
|
batch_dim = [batch_dim]
|
||||||
|
self.batch_dim = batch_dim
|
||||||
|
self.dropout = nn.Dropout(self.r)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
x:
|
||||||
|
Tensor to which dropout is applied. Can have any shape
|
||||||
|
compatible with self.batch_dim
|
||||||
|
"""
|
||||||
|
shape = list(x.shape)
|
||||||
|
if self.batch_dim is not None:
|
||||||
|
for bd in self.batch_dim:
|
||||||
|
shape[bd] = 1
|
||||||
|
mask = x.new_ones(shape)
|
||||||
|
mask = self.dropout(mask)
|
||||||
|
x *= mask
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class DropoutRowwise(Dropout):
|
||||||
|
"""
|
||||||
|
Convenience class for rowwise dropout as described in subsection
|
||||||
|
1.11.6.
|
||||||
|
"""
|
||||||
|
|
||||||
|
__init__ = partialmethod(Dropout.__init__, batch_dim=-3)
|
||||||
|
|
||||||
|
|
||||||
|
class DropoutColumnwise(Dropout):
|
||||||
|
"""
|
||||||
|
Convenience class for columnwise dropout as described in subsection
|
||||||
|
1.11.6.
|
||||||
|
"""
|
||||||
|
|
||||||
|
__init__ = partialmethod(Dropout.__init__, batch_dim=-2)
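
A small verification sketch (illustrative only): rowwise dropout shares a single mask along dimension -3, so every slice along that dimension is scaled by the same pattern.

if __name__ == "__main__":
    drop = DropoutRowwise(0.5)
    x = torch.ones(1, 4, 8, 16)
    y = drop(x)
    # the mask has size 1 along dim -3, so all slices along that dim are identical
    assert torch.equal(y[:, 0], y[:, 1])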
|
|
@ -0,0 +1,431 @@
|
||||||
|
# Copyright 2021 AlQuraishi Laboratory
|
||||||
|
# Copyright 2021 DeepMind Technologies Limited
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import math
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from typing import Tuple, Optional
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
from .primitives import Linear, LayerNorm
|
||||||
|
from .dropout import DropoutRowwise, DropoutColumnwise
|
||||||
|
from .msa import (
|
||||||
|
MSARowAttentionWithPairBias,
|
||||||
|
MSAColumnAttention,
|
||||||
|
MSAColumnGlobalAttention,
|
||||||
|
)
|
||||||
|
from .outer_product_mean import OuterProductMean
|
||||||
|
from .pair_transition import PairTransition
|
||||||
|
from .triangular_attention import (
|
||||||
|
TriangleAttentionStartingNode,
|
||||||
|
TriangleAttentionEndingNode,
|
||||||
|
)
|
||||||
|
from .triangular_multiplicative_update import (
|
||||||
|
TriangleMultiplicationOutgoing,
|
||||||
|
TriangleMultiplicationIncoming,
|
||||||
|
)
|
||||||
|
from .checkpointing import checkpoint_blocks, get_checkpoint_fn
|
||||||
|
from .tensor_utils import chunk_layer
|
||||||
|
|
||||||
|
|
||||||
|
class MSATransition(nn.Module):
|
||||||
|
"""
|
||||||
|
Feed-forward network applied to MSA activations after attention.
|
||||||
|
|
||||||
|
Implements Algorithm 9
|
||||||
|
"""
|
||||||
|
def __init__(self, c_m, n):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
c_m:
|
||||||
|
MSA channel dimension
|
||||||
|
n:
|
||||||
|
Factor multiplied to c_m to obtain the hidden channel
|
||||||
|
dimension
|
||||||
|
"""
|
||||||
|
super(MSATransition, self).__init__()
|
||||||
|
|
||||||
|
self.c_m = c_m
|
||||||
|
self.n = n
|
||||||
|
|
||||||
|
self.layer_norm = LayerNorm(self.c_m)
|
||||||
|
self.linear_1 = Linear(self.c_m, self.n * self.c_m, init="relu")
|
||||||
|
self.relu = nn.ReLU()
|
||||||
|
self.linear_2 = Linear(self.n * self.c_m, self.c_m, init="final")
|
||||||
|
|
||||||
|
def _transition(self, m, mask):
|
||||||
|
m = self.linear_1(m)
|
||||||
|
m = self.relu(m)
|
||||||
|
m = self.linear_2(m) * mask
|
||||||
|
return m
|
||||||
|
|
||||||
|
@torch.jit.ignore
|
||||||
|
def _chunk(self,
|
||||||
|
m: torch.Tensor,
|
||||||
|
mask: torch.Tensor,
|
||||||
|
chunk_size: int,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
return chunk_layer(
|
||||||
|
self._transition,
|
||||||
|
{"m": m, "mask": mask},
|
||||||
|
chunk_size=chunk_size,
|
||||||
|
no_batch_dims=len(m.shape[:-2]),
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
m: torch.Tensor,
|
||||||
|
mask: Optional[torch.Tensor] = None,
|
||||||
|
chunk_size: Optional[int] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
m:
|
||||||
|
[*, N_seq, N_res, C_m] MSA activation
|
||||||
|
mask:
|
||||||
|
[*, N_seq, N_res, C_m] MSA mask
|
||||||
|
Returns:
|
||||||
|
m:
|
||||||
|
[*, N_seq, N_res, C_m] MSA activation update
|
||||||
|
"""
|
||||||
|
|
||||||
|
# DISCREPANCY: DeepMind forgets to apply the MSA mask here.
|
||||||
|
if mask is None:
|
||||||
|
mask = m.new_ones(m.shape[:-1])
|
||||||
|
|
||||||
|
# [*, N_seq, N_res, 1]
|
||||||
|
mask = mask.unsqueeze(-1)
|
||||||
|
|
||||||
|
m = self.layer_norm(m)
|
||||||
|
|
||||||
|
if chunk_size is not None:
|
||||||
|
m = self._chunk(m, mask, chunk_size)
|
||||||
|
else:
|
||||||
|
m = self._transition(m, mask)
|
||||||
|
|
||||||
|
return m
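
A shape sketch (illustrative, assuming the chunk_layer helper from .tensor_utils behaves as in OpenFold): the transition is a per-position MLP, so the output keeps the [*, N_seq, N_res, C_m] shape and the chunked and unchunked paths agree.

if __name__ == "__main__":
    mt = MSATransition(c_m=32, n=4)
    m = torch.randn(2, 8, 16, 32)
    full = mt(m)
    chunked = mt(m, chunk_size=4)
    assert full.shape == m.shape
    assert torch.allclose(full, chunked, atol=1e-6)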
|
||||||
|
|
||||||
|
|
||||||
|
class EvoformerBlockCore(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
c_m: int,
|
||||||
|
c_z: int,
|
||||||
|
c_hidden_opm: int,
|
||||||
|
c_hidden_mul: int,
|
||||||
|
c_hidden_pair_att: int,
|
||||||
|
no_heads_msa: int,
|
||||||
|
no_heads_pair: int,
|
||||||
|
transition_n: int,
|
||||||
|
pair_dropout: float,
|
||||||
|
inf: float,
|
||||||
|
eps: float,
|
||||||
|
_is_extra_msa_stack: bool = False,
|
||||||
|
is_multimer: bool = False,
|
||||||
|
):
|
||||||
|
super(EvoformerBlockCore, self).__init__()
|
||||||
|
self.is_multimer = is_multimer
|
||||||
|
self.msa_transition = MSATransition(
|
||||||
|
c_m=c_m,
|
||||||
|
n=transition_n,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.outer_product_mean = OuterProductMean(
|
||||||
|
c_m,
|
||||||
|
c_z,
|
||||||
|
c_hidden_opm,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.tri_mul_out = TriangleMultiplicationOutgoing(
|
||||||
|
c_z,
|
||||||
|
c_hidden_mul,
|
||||||
|
)
|
||||||
|
self.tri_mul_in = TriangleMultiplicationIncoming(
|
||||||
|
c_z,
|
||||||
|
c_hidden_mul,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.tri_att_start = TriangleAttentionStartingNode(
|
||||||
|
c_z,
|
||||||
|
c_hidden_pair_att,
|
||||||
|
no_heads_pair,
|
||||||
|
inf=inf,
|
||||||
|
)
|
||||||
|
self.tri_att_end = TriangleAttentionEndingNode(
|
||||||
|
c_z,
|
||||||
|
c_hidden_pair_att,
|
||||||
|
no_heads_pair,
|
||||||
|
inf=inf,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.pair_transition = PairTransition(
|
||||||
|
c_z,
|
||||||
|
transition_n,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.ps_dropout_row_layer = DropoutRowwise(pair_dropout)
|
||||||
|
self.ps_dropout_col_layer = DropoutColumnwise(pair_dropout)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
m: torch.Tensor,
|
||||||
|
z: torch.Tensor,
|
||||||
|
chunk_size: Optional[int] = None,
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
# DeepMind doesn't mask these transitions in the source, so _mask_trans
|
||||||
|
# should be disabled to better approximate the exact activations of
|
||||||
|
# the original.
|
||||||
|
|
||||||
|
m = m + self.msa_transition(
|
||||||
|
m, chunk_size=chunk_size
|
||||||
|
)
|
||||||
|
z = z + self.outer_product_mean(
|
||||||
|
m, chunk_size=chunk_size
|
||||||
|
)
|
||||||
|
z = z + self.ps_dropout_row_layer(self.tri_mul_out(z))
|
||||||
|
z = z + self.ps_dropout_row_layer(self.tri_mul_in(z))
|
||||||
|
z = z + self.ps_dropout_row_layer(
|
||||||
|
self.tri_att_start(z, chunk_size=chunk_size)
|
||||||
|
)
|
||||||
|
z = z + self.ps_dropout_col_layer(
|
||||||
|
self.tri_att_end(z, chunk_size=chunk_size)
|
||||||
|
)
|
||||||
|
z = z + self.pair_transition(
|
||||||
|
z, chunk_size=chunk_size
|
||||||
|
)
|
||||||
|
|
||||||
|
return m, z
|
||||||
|
|
||||||
|
|
||||||
|
class EvoformerBlock(nn.Module):
|
||||||
|
def __init__(self,
|
||||||
|
c_m: int,
|
||||||
|
c_z: int,
|
||||||
|
c_hidden_msa_att: int,
|
||||||
|
c_hidden_opm: int,
|
||||||
|
c_hidden_mul: int,
|
||||||
|
c_hidden_pair_att: int,
|
||||||
|
no_heads_msa: int,
|
||||||
|
no_heads_pair: int,
|
||||||
|
transition_n: int,
|
||||||
|
msa_dropout: float,
|
||||||
|
pair_dropout: float,
|
||||||
|
inf: float,
|
||||||
|
eps: float,
|
||||||
|
is_multimer: bool,
|
||||||
|
):
|
||||||
|
super(EvoformerBlock, self).__init__()
|
||||||
|
|
||||||
|
self.msa_att_row = MSARowAttentionWithPairBias(
|
||||||
|
c_m=c_m,
|
||||||
|
c_z=c_z,
|
||||||
|
c_hidden=c_hidden_msa_att,
|
||||||
|
no_heads=no_heads_msa,
|
||||||
|
inf=inf,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.msa_att_col = MSAColumnAttention(
|
||||||
|
c_m,
|
||||||
|
c_hidden_msa_att,
|
||||||
|
no_heads_msa,
|
||||||
|
inf=inf,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.msa_dropout_layer = DropoutRowwise(msa_dropout)
|
||||||
|
|
||||||
|
self.core = EvoformerBlockCore(
|
||||||
|
c_m=c_m,
|
||||||
|
c_z=c_z,
|
||||||
|
c_hidden_opm=c_hidden_opm,
|
||||||
|
c_hidden_mul=c_hidden_mul,
|
||||||
|
c_hidden_pair_att=c_hidden_pair_att,
|
||||||
|
no_heads_msa=no_heads_msa,
|
||||||
|
no_heads_pair=no_heads_pair,
|
||||||
|
transition_n=transition_n,
|
||||||
|
pair_dropout=pair_dropout,
|
||||||
|
inf=inf,
|
||||||
|
eps=eps,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.outer_product_mean = OuterProductMean(
|
||||||
|
c_m,
|
||||||
|
c_z,
|
||||||
|
c_hidden_opm,
|
||||||
|
)
|
||||||
|
self.is_multimer = is_multimer
|
||||||
|
|
||||||
|
def forward(self,
m: torch.Tensor,
z: torch.Tensor,
chunk_size: Optional[int] = None,
msa_mask: Optional[torch.Tensor] = None,
pair_mask: Optional[torch.Tensor] = None,
_mask_trans: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
# msa_mask, pair_mask and _mask_trans are accepted so that the partials built in
# EvoformerStack.forward can call this block; this slimmed-down block only forwards
# the MSA mask to its two attention modules and ignores the rest.
m = m + self.msa_dropout_layer(
self.msa_att_row(m, z=z, mask=msa_mask, chunk_size=chunk_size)
)
m = m + self.msa_att_col(m, mask=msa_mask, chunk_size=chunk_size)
|
||||||
|
m, z = self.core(
|
||||||
|
m,
|
||||||
|
z,
|
||||||
|
chunk_size=chunk_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
return m, z
|
||||||
|
|
||||||
|
|
||||||
|
class EvoformerStack(nn.Module):
|
||||||
|
"""
|
||||||
|
Main Evoformer trunk.
|
||||||
|
|
||||||
|
Implements Algorithm 6.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
c_m: int,
|
||||||
|
c_z: int,
|
||||||
|
c_hidden_msa_att: int,
|
||||||
|
c_hidden_opm: int,
|
||||||
|
c_hidden_mul: int,
|
||||||
|
c_hidden_pair_att: int,
|
||||||
|
c_s: int,
|
||||||
|
no_heads_msa: int,
|
||||||
|
no_heads_pair: int,
|
||||||
|
no_blocks: int,
|
||||||
|
transition_n: int,
|
||||||
|
msa_dropout: float,
|
||||||
|
pair_dropout: float,
|
||||||
|
blocks_per_ckpt: int,
|
||||||
|
inf: float,
|
||||||
|
eps: float,
|
||||||
|
clear_cache_between_blocks: bool = False,
|
||||||
|
is_multimer: bool = False,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
c_m:
|
||||||
|
MSA channel dimension
|
||||||
|
c_z:
|
||||||
|
Pair channel dimension
|
||||||
|
c_hidden_msa_att:
|
||||||
|
Hidden dimension in MSA attention
|
||||||
|
c_hidden_opm:
|
||||||
|
Hidden dimension in outer product mean module
|
||||||
|
c_hidden_mul:
|
||||||
|
Hidden dimension in multiplicative updates
|
||||||
|
c_hidden_pair_att:
|
||||||
|
Hidden dimension in triangular attention
|
||||||
|
c_s:
|
||||||
|
Channel dimension of the output "single" embedding
|
||||||
|
no_heads_msa:
|
||||||
|
Number of heads used for MSA attention
|
||||||
|
no_heads_pair:
|
||||||
|
Number of heads used for pair attention
|
||||||
|
no_blocks:
|
||||||
|
Number of Evoformer blocks in the stack
|
||||||
|
transition_n:
|
||||||
|
Factor by which to multiply c_m to obtain the MSATransition
|
||||||
|
hidden dimension
|
||||||
|
msa_dropout:
|
||||||
|
Dropout rate for MSA activations
|
||||||
|
pair_dropout:
|
||||||
|
Dropout used for pair activations
|
||||||
|
blocks_per_ckpt:
|
||||||
|
Number of Evoformer blocks in each activation checkpoint
|
||||||
|
clear_cache_between_blocks:
|
||||||
|
Whether to clear CUDA's GPU memory cache between blocks of the
|
||||||
|
stack. Slows down each block but can reduce fragmentation
|
||||||
|
"""
|
||||||
|
super(EvoformerStack, self).__init__()
|
||||||
|
|
||||||
|
self.blocks_per_ckpt = blocks_per_ckpt
|
||||||
|
self.clear_cache_between_blocks = clear_cache_between_blocks
|
||||||
|
|
||||||
|
self.blocks = nn.ModuleList()
|
||||||
|
|
||||||
|
for _ in range(no_blocks):
|
||||||
|
block = EvoformerBlock(
|
||||||
|
c_m=c_m,
|
||||||
|
c_z=c_z,
|
||||||
|
c_hidden_msa_att=c_hidden_msa_att,
|
||||||
|
c_hidden_opm=c_hidden_opm,
|
||||||
|
c_hidden_mul=c_hidden_mul,
|
||||||
|
c_hidden_pair_att=c_hidden_pair_att,
|
||||||
|
no_heads_msa=no_heads_msa,
|
||||||
|
no_heads_pair=no_heads_pair,
|
||||||
|
transition_n=transition_n,
|
||||||
|
msa_dropout=msa_dropout,
|
||||||
|
pair_dropout=pair_dropout,
|
||||||
|
inf=inf,
|
||||||
|
eps=eps,
|
||||||
|
is_multimer=is_multimer,
|
||||||
|
)
|
||||||
|
self.blocks.append(block)
|
||||||
|
|
||||||
|
self.linear = Linear(c_m, c_s)
|
||||||
|
|
||||||
|
def forward(self,
|
||||||
|
m: torch.Tensor,
|
||||||
|
z: torch.Tensor,
|
||||||
|
msa_mask: torch.Tensor,
|
||||||
|
pair_mask: torch.Tensor,
|
||||||
|
chunk_size: int,
|
||||||
|
_mask_trans: bool = True,
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
m:
|
||||||
|
[*, N_seq, N_res, C_m] MSA embedding
|
||||||
|
z:
|
||||||
|
[*, N_res, N_res, C_z] pair embedding
|
||||||
|
msa_mask:
|
||||||
|
[*, N_seq, N_res] MSA mask
|
||||||
|
pair_mask:
|
||||||
|
[*, N_res, N_res] pair mask
|
||||||
|
Returns:
|
||||||
|
m:
|
||||||
|
[*, N_seq, N_res, C_m] MSA embedding
|
||||||
|
z:
|
||||||
|
[*, N_res, N_res, C_z] pair embedding
|
||||||
|
s:
|
||||||
|
[*, N_res, C_s] single embedding (or None if extra MSA stack)
|
||||||
|
"""
|
||||||
|
blocks = [
|
||||||
|
partial(
|
||||||
|
b,
|
||||||
|
msa_mask=msa_mask,
|
||||||
|
pair_mask=pair_mask,
|
||||||
|
chunk_size=chunk_size,
|
||||||
|
_mask_trans=_mask_trans,
|
||||||
|
)
|
||||||
|
for b in self.blocks
|
||||||
|
]
|
||||||
|
|
||||||
|
if self.clear_cache_between_blocks:
|
||||||
|
def block_with_cache_clear(block, *args):
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
return block(*args)
|
||||||
|
|
||||||
|
blocks = [partial(block_with_cache_clear, b) for b in blocks]
|
||||||
|
|
||||||
|
m, z = checkpoint_blocks(
|
||||||
|
blocks,
|
||||||
|
args=(m, z),
|
||||||
|
blocks_per_ckpt=self.blocks_per_ckpt if self.training else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
s = self.linear(m[..., 0, :, :])
|
||||||
|
|
||||||
|
return m, z, s
|
|
@ -0,0 +1,331 @@
|
||||||
|
# Copyright 2021 AlQuraishi Laboratory
|
||||||
|
# Copyright 2021 DeepMind Technologies Limited
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import math
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from typing import Optional, List, Tuple
|
||||||
|
|
||||||
|
from .primitives import (
|
||||||
|
Linear,
|
||||||
|
LayerNorm,
|
||||||
|
Attention,
|
||||||
|
GlobalAttention,
|
||||||
|
_attention_chunked_trainable,
|
||||||
|
)
|
||||||
|
from .checkpointing import get_checkpoint_fn
|
||||||
|
from .tensor_utils import (
|
||||||
|
chunk_layer,
|
||||||
|
permute_final_dims,
|
||||||
|
flatten_final_dims,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MSAAttention(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
c_in,
|
||||||
|
c_hidden,
|
||||||
|
no_heads,
|
||||||
|
pair_bias=False,
|
||||||
|
c_z=None,
|
||||||
|
inf=1e9,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
c_in:
|
||||||
|
Input channel dimension
|
||||||
|
c_hidden:
|
||||||
|
Per-head hidden channel dimension
|
||||||
|
no_heads:
|
||||||
|
Number of attention heads
|
||||||
|
pair_bias:
|
||||||
|
Whether to use pair embedding bias
|
||||||
|
c_z:
|
||||||
|
Pair embedding channel dimension. Ignored unless pair_bias
|
||||||
|
is true
|
||||||
|
inf:
|
||||||
|
A large number to be used in computing the attention mask
|
||||||
|
"""
|
||||||
|
super(MSAAttention, self).__init__()
|
||||||
|
|
||||||
|
self.c_in = c_in
|
||||||
|
self.c_hidden = c_hidden
|
||||||
|
self.no_heads = no_heads
|
||||||
|
self.pair_bias = pair_bias
|
||||||
|
self.c_z = c_z
|
||||||
|
self.inf = inf
|
||||||
|
|
||||||
|
self.layer_norm_m = LayerNorm(self.c_in)
|
||||||
|
|
||||||
|
self.layer_norm_z = None
|
||||||
|
self.linear_z = None
|
||||||
|
if self.pair_bias:
|
||||||
|
self.layer_norm_z = LayerNorm(self.c_z)
|
||||||
|
self.linear_z = Linear(
|
||||||
|
self.c_z, self.no_heads, bias=False, init="normal"
|
||||||
|
)
|
||||||
|
|
||||||
|
self.mha = Attention(
|
||||||
|
self.c_in, self.c_in, self.c_in, self.c_hidden, self.no_heads
|
||||||
|
)
|
||||||
|
|
||||||
|
@torch.jit.ignore
|
||||||
|
def _chunk(self,
|
||||||
|
m: torch.Tensor,
|
||||||
|
biases: List[torch.Tensor],
|
||||||
|
chunk_size: int,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
return chunk_layer(
|
||||||
|
self.mha,
|
||||||
|
{"q_x": m, "kv_x": m, "biases": biases},
|
||||||
|
chunk_size=chunk_size,
|
||||||
|
no_batch_dims=len(m.shape[:-2]),
|
||||||
|
)
|
||||||
|
|
||||||
|
def _prep_inputs(self,
|
||||||
|
m: torch.Tensor,
|
||||||
|
z: Optional[torch.Tensor],
|
||||||
|
mask: Optional[torch.Tensor]
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
|
# [*, N_seq, N_res, C_m]
|
||||||
|
m = self.layer_norm_m(m)
|
||||||
|
|
||||||
|
n_seq, n_res = m.shape[-3:-1]
|
||||||
|
if mask is None:
|
||||||
|
# [*, N_seq, N_res]
|
||||||
|
mask = m.new_ones(
|
||||||
|
m.shape[:-3] + (n_seq, n_res),
|
||||||
|
)
|
||||||
|
|
||||||
|
# [*, N_seq, 1, 1, N_res]
|
||||||
|
mask_bias = (self.inf * (mask - 1))[..., :, None, None, :]
|
||||||
|
|
||||||
|
# This step simply returns a larger view of the bias, and does not
|
||||||
|
# consume additional memory.
|
||||||
|
# [*, N_seq, no_heads, N_res, N_res]
|
||||||
|
#bias = bias.expand(
|
||||||
|
# ((-1,) * len(bias.shape[:-4])) + (-1, self.no_heads, n_res, -1)
|
||||||
|
#)
|
||||||
|
|
||||||
|
if (self.pair_bias and
|
||||||
|
z is not None and # For the
|
||||||
|
self.layer_norm_z is not None and # benefit of
|
||||||
|
self.linear_z is not None # TorchScript
|
||||||
|
):
|
||||||
|
# [*, N_res, N_res, C_z]
|
||||||
|
z = self.layer_norm_z(z)
|
||||||
|
|
||||||
|
# [*, N_res, N_res, no_heads]
|
||||||
|
z = self.linear_z(z)
|
||||||
|
|
||||||
|
# [*, 1, no_heads, N_res, N_res]
|
||||||
|
z = permute_final_dims(z, (2, 0, 1)).unsqueeze(-4)
|
||||||
|
|
||||||
|
return m, mask_bias, z
|
||||||
|
|
||||||
|
|
||||||
|
def forward(self,
|
||||||
|
m: torch.Tensor,
|
||||||
|
z: Optional[torch.Tensor] = None,
|
||||||
|
mask: Optional[torch.Tensor] = None,
|
||||||
|
chunk_size: Optional[int] = None,
|
||||||
|
_chunk_logits: Optional[int] = None,
|
||||||
|
_checkpoint_chunks: Optional[bool] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
m:
|
||||||
|
[*, N_seq, N_res, C_m] MSA embedding
|
||||||
|
z:
|
||||||
|
[*, N_res, N_res, C_z] pair embedding. Required only if
|
||||||
|
pair_bias is True
|
||||||
|
mask:
|
||||||
|
[*, N_seq, N_res] MSA mask
|
||||||
|
chunk_size:
|
||||||
|
Size of chunks into which the inputs are split along their
|
||||||
|
batch dimensions. A low value decreases memory overhead at the
|
||||||
|
cost of slower execution. Chunking is not performed by default.
|
||||||
|
|
||||||
|
"""
|
||||||
|
m, mask_bias, z = self._prep_inputs(m, z, mask)
|
||||||
|
|
||||||
|
biases = [mask_bias]
|
||||||
|
if z is not None:
|
||||||
|
biases.append(z)
|
||||||
|
|
||||||
|
if chunk_size is not None:
|
||||||
|
m = self._chunk(m, biases, chunk_size)
|
||||||
|
else:
|
||||||
|
m = self.mha(
|
||||||
|
q_x=m,
|
||||||
|
kv_x=m,
|
||||||
|
biases=biases
|
||||||
|
)
|
||||||
|
|
||||||
|
return m
|
||||||
|
|
||||||
|
|
||||||
|
class MSARowAttentionWithPairBias(MSAAttention):
|
||||||
|
"""
|
||||||
|
Implements Algorithm 7.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, c_m, c_z, c_hidden, no_heads, inf=1e9):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
c_m:
|
||||||
|
Input channel dimension
|
||||||
|
c_z:
|
||||||
|
Pair embedding channel dimension
|
||||||
|
c_hidden:
|
||||||
|
Per-head hidden channel dimension
|
||||||
|
no_heads:
|
||||||
|
Number of attention heads
|
||||||
|
inf:
|
||||||
|
Large number used to construct attention masks
|
||||||
|
"""
|
||||||
|
super(MSARowAttentionWithPairBias, self).__init__(
|
||||||
|
c_m,
|
||||||
|
c_hidden,
|
||||||
|
no_heads,
|
||||||
|
pair_bias=True,
|
||||||
|
c_z=c_z,
|
||||||
|
inf=inf,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MSAColumnAttention(nn.Module):
|
||||||
|
"""
|
||||||
|
Implements Algorithm 8.
|
||||||
|
|
||||||
|
By rights, this should also be a subclass of MSAAttention. Alas,
|
||||||
|
most inheritance isn't supported by TorchScript.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, c_m, c_hidden, no_heads, inf=1e9):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
c_m:
|
||||||
|
MSA channel dimension
|
||||||
|
c_hidden:
|
||||||
|
Per-head hidden channel dimension
|
||||||
|
no_heads:
|
||||||
|
Number of attention heads
|
||||||
|
inf:
|
||||||
|
Large number used to construct attention masks
|
||||||
|
"""
|
||||||
|
super(MSAColumnAttention, self).__init__()
|
||||||
|
|
||||||
|
self.c_m = c_m
|
||||||
|
self.c_hidden = c_hidden
|
||||||
|
self.no_heads = no_heads
|
||||||
|
self.inf = inf
|
||||||
|
|
||||||
|
self._msa_att = MSAAttention(
|
||||||
|
c_in=c_m,
|
||||||
|
c_hidden=c_hidden,
|
||||||
|
no_heads=no_heads,
|
||||||
|
pair_bias=False,
|
||||||
|
c_z=None,
|
||||||
|
inf=inf,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self,
|
||||||
|
m: torch.Tensor,
|
||||||
|
mask: Optional[torch.Tensor] = None,
|
||||||
|
chunk_size: Optional[int] = None
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
m:
|
||||||
|
[*, N_seq, N_res, C_m] MSA embedding
|
||||||
|
mask:
|
||||||
|
[*, N_seq, N_res] MSA mask
|
||||||
|
chunk_size:
|
||||||
|
Size of chunks into which the inputs are split along their
|
||||||
|
batch dimensions. A low value decreases memory overhead at the
|
||||||
|
cost of slower execution. Chunking is not performed by default.
|
||||||
|
"""
|
||||||
|
# [*, N_res, N_seq, C_in]
|
||||||
|
m = m.transpose(-2, -3)
|
||||||
|
|
||||||
|
m = self._msa_att(m, chunk_size=chunk_size)
|
||||||
|
|
||||||
|
# [*, N_seq, N_res, C_in]
|
||||||
|
m = m.transpose(-2, -3)
|
||||||
|
|
||||||
|
return m
|
||||||
|
|
||||||
|
|
||||||
|
class MSAColumnGlobalAttention(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self, c_in, c_hidden, no_heads, inf=1e9, eps=1e-10,
|
||||||
|
):
|
||||||
|
super(MSAColumnGlobalAttention, self).__init__()
|
||||||
|
|
||||||
|
self.c_in = c_in
|
||||||
|
self.c_hidden = c_hidden
|
||||||
|
self.no_heads = no_heads
|
||||||
|
self.inf = inf
|
||||||
|
self.eps = eps
|
||||||
|
|
||||||
|
self.layer_norm_m = nn.LayerNorm(c_in)
|
||||||
|
|
||||||
|
self.global_attention = GlobalAttention(
|
||||||
|
c_in=c_in,
|
||||||
|
c_hidden=c_hidden,
|
||||||
|
no_heads=no_heads,
|
||||||
|
inf=inf,
|
||||||
|
eps=eps,
|
||||||
|
)
|
||||||
|
|
||||||
|
@torch.jit.ignore
|
||||||
|
def _chunk(self,
|
||||||
|
m: torch.Tensor,
|
||||||
|
chunk_size: int,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
mha_input = {
|
||||||
|
"m": m,
|
||||||
|
}
|
||||||
|
return chunk_layer(
|
||||||
|
self.global_attention,
|
||||||
|
mha_input,
|
||||||
|
chunk_size=chunk_size,
|
||||||
|
no_batch_dims=len(m.shape[:-2]),
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
m: torch.Tensor,
|
||||||
|
chunk_size: Optional[int] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
n_seq, n_res, c_in = m.shape[-3:]
|
||||||
|
|
||||||
|
# [*, N_res, N_seq, C_in]
|
||||||
|
m = m.transpose(-2, -3)
|
||||||
|
|
||||||
|
# [*, N_res, N_seq, C_in]
|
||||||
|
m = self.layer_norm_m(m)
|
||||||
|
|
||||||
|
if chunk_size is not None:
|
||||||
|
m = self._chunk(m, chunk_size)
|
||||||
|
else:
|
||||||
|
m = self.global_attention(m=m)
|
||||||
|
|
||||||
|
# [*, N_seq, N_res, C_in]
|
||||||
|
m = m.transpose(-2, -3)
|
||||||
|
|
||||||
|
return m
|
|
@ -0,0 +1,129 @@
|
||||||
|
# Copyright 2021 AlQuraishi Laboratory
|
||||||
|
# Copyright 2021 DeepMind Technologies Limited
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from functools import partial
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from .primitives import Linear
|
||||||
|
from .tensor_utils import chunk_layer
|
||||||
|
|
||||||
|
|
||||||
|
class OuterProductMean(nn.Module):
|
||||||
|
"""
|
||||||
|
Implements Algorithm 10.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, c_m, c_z, c_hidden, eps=1e-3):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
c_m:
|
||||||
|
MSA embedding channel dimension
|
||||||
|
c_z:
|
||||||
|
Pair embedding channel dimension
|
||||||
|
c_hidden:
|
||||||
|
Hidden channel dimension
|
||||||
|
"""
|
||||||
|
super(OuterProductMean, self).__init__()
|
||||||
|
|
||||||
|
self.c_m = c_m
|
||||||
|
self.c_z = c_z
|
||||||
|
self.c_hidden = c_hidden
|
||||||
|
self.eps = eps
|
||||||
|
|
||||||
|
self.layer_norm = nn.LayerNorm(c_m)
|
||||||
|
self.linear_1 = Linear(c_m, c_hidden)
|
||||||
|
self.linear_2 = Linear(c_m, c_hidden)
|
||||||
|
self.linear_out = Linear(c_hidden ** 2, c_z, init="final")
|
||||||
|
|
||||||
|
def _opm(self, a, b):
|
||||||
|
# [*, N_res, N_res, C, C]
|
||||||
|
outer = torch.einsum("...bac,...dae->...bdce", a, b)
|
||||||
|
|
||||||
|
# [*, N_res, N_res, C * C]
|
||||||
|
outer = outer.reshape(outer.shape[:-2] + (-1,))
|
||||||
|
|
||||||
|
# [*, N_res, N_res, C_z]
|
||||||
|
outer = self.linear_out(outer)
|
||||||
|
|
||||||
|
return outer
|
||||||
|
|
||||||
|
@torch.jit.ignore
|
||||||
|
def _chunk(self,
|
||||||
|
a: torch.Tensor,
|
||||||
|
b: torch.Tensor,
|
||||||
|
chunk_size: int
|
||||||
|
) -> torch.Tensor:
|
||||||
|
# Since the "batch dim" in this case is not a true batch dimension
|
||||||
|
# (in that the shape of the output depends on it), we need to
|
||||||
|
# iterate over it ourselves
|
||||||
|
a_reshape = a.reshape((-1,) + a.shape[-3:])
|
||||||
|
b_reshape = b.reshape((-1,) + b.shape[-3:])
|
||||||
|
out = []
|
||||||
|
for a_prime, b_prime in zip(a_reshape, b_reshape):
|
||||||
|
outer = chunk_layer(
|
||||||
|
partial(self._opm, b=b_prime),
|
||||||
|
{"a": a_prime},
|
||||||
|
chunk_size=chunk_size,
|
||||||
|
no_batch_dims=1,
|
||||||
|
)
|
||||||
|
out.append(outer)
|
||||||
|
outer = torch.stack(out, dim=0)
|
||||||
|
outer = outer.reshape(a.shape[:-3] + outer.shape[1:])
|
||||||
|
|
||||||
|
return outer
|
||||||
|
|
||||||
|
def forward(self,
|
||||||
|
m: torch.Tensor,
|
||||||
|
mask: Optional[torch.Tensor] = None,
|
||||||
|
chunk_size: Optional[int] = None
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
m:
|
||||||
|
[*, N_seq, N_res, C_m] MSA embedding
|
||||||
|
mask:
|
||||||
|
[*, N_seq, N_res] MSA mask
|
||||||
|
Returns:
|
||||||
|
[*, N_res, N_res, C_z] pair embedding update
|
||||||
|
"""
|
||||||
|
if mask is None:
|
||||||
|
mask = m.new_ones(m.shape[:-1])
|
||||||
|
|
||||||
|
# [*, N_seq, N_res, C_m]
|
||||||
|
m = self.layer_norm(m)
|
||||||
|
|
||||||
|
# [*, N_seq, N_res, C]
|
||||||
|
mask = mask.unsqueeze(-1)
|
||||||
|
a = self.linear_1(m) * mask
|
||||||
|
b = self.linear_2(m) * mask
|
||||||
|
|
||||||
|
a = a.transpose(-2, -3)
|
||||||
|
b = b.transpose(-2, -3)
|
||||||
|
|
||||||
|
if chunk_size is not None:
|
||||||
|
outer = self._chunk(a, b, chunk_size)
|
||||||
|
else:
|
||||||
|
outer = self._opm(a, b)
|
||||||
|
|
||||||
|
# [*, N_res, N_res, 1]
|
||||||
|
norm = torch.einsum("...abc,...adc->...bdc", mask, mask)
|
||||||
|
|
||||||
|
# [*, N_res, N_res, C_z]
|
||||||
|
outer = outer / (self.eps + norm)
|
||||||
|
|
||||||
|
return outer
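
A shape sketch (illustrative only): the outer product mean turns an MSA embedding [*, N_seq, N_res, C_m] into a pair update [*, N_res, N_res, C_z], normalised by the number of unmasked sequences.

if __name__ == "__main__":
    opm = OuterProductMean(c_m=32, c_z=64, c_hidden=16)
    m = torch.randn(2, 8, 16, 32)      # [*, N_seq, N_res, C_m]
    z = opm(m)
    assert z.shape == (2, 16, 16, 64)  # [*, N_res, N_res, C_z]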
|
|
@ -0,0 +1,99 @@
|
||||||
|
# Copyright 2021 AlQuraishi Laboratory
|
||||||
|
# Copyright 2021 DeepMind Technologies Limited
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from .primitives import Linear, LayerNorm
|
||||||
|
from .tensor_utils import chunk_layer
|
||||||
|
|
||||||
|
|
||||||
|
class PairTransition(nn.Module):
|
||||||
|
"""
|
||||||
|
Implements Algorithm 15.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, c_z, n):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
c_z:
|
||||||
|
Pair transition channel dimension
|
||||||
|
n:
|
||||||
|
Factor by which c_z is multiplied to obtain hidden channel
|
||||||
|
dimension
|
||||||
|
"""
|
||||||
|
super(PairTransition, self).__init__()
|
||||||
|
|
||||||
|
self.c_z = c_z
|
||||||
|
self.n = n
|
||||||
|
|
||||||
|
self.layer_norm = LayerNorm(self.c_z)
|
||||||
|
self.linear_1 = Linear(self.c_z, self.n * self.c_z, init="relu")
|
||||||
|
self.relu = nn.ReLU()
|
||||||
|
self.linear_2 = Linear(self.n * self.c_z, c_z, init="final")
|
||||||
|
|
||||||
|
def _transition(self, z, mask):
|
||||||
|
# [*, N_res, N_res, C_hidden]
|
||||||
|
z = self.linear_1(z)
|
||||||
|
z = self.relu(z)
|
||||||
|
|
||||||
|
# [*, N_res, N_res, C_z]
|
||||||
|
z = self.linear_2(z) * mask
|
||||||
|
|
||||||
|
return z
|
||||||
|
|
||||||
|
@torch.jit.ignore
|
||||||
|
def _chunk(self,
|
||||||
|
z: torch.Tensor,
|
||||||
|
mask: torch.Tensor,
|
||||||
|
chunk_size: int,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
return chunk_layer(
|
||||||
|
self._transition,
|
||||||
|
{"z": z, "mask": mask},
|
||||||
|
chunk_size=chunk_size,
|
||||||
|
no_batch_dims=len(z.shape[:-2]),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def forward(self,
|
||||||
|
z: torch.Tensor,
|
||||||
|
mask: Optional[torch.Tensor] = None,
|
||||||
|
chunk_size: Optional[int] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
z:
|
||||||
|
[*, N_res, N_res, C_z] pair embedding
|
||||||
|
Returns:
|
||||||
|
[*, N_res, N_res, C_z] pair embedding update
|
||||||
|
"""
|
||||||
|
# DISCREPANCY: DeepMind forgets to apply the mask in this module.
|
||||||
|
if mask is None:
|
||||||
|
mask = z.new_ones(z.shape[:-1])
|
||||||
|
|
||||||
|
# [*, N_res, N_res, 1]
|
||||||
|
mask = mask.unsqueeze(-1)
|
||||||
|
|
||||||
|
# [*, N_res, N_res, C_z]
|
||||||
|
z = self.layer_norm(z)
|
||||||
|
|
||||||
|
if chunk_size is not None:
|
||||||
|
z = self._chunk(z, mask, chunk_size)
|
||||||
|
else:
|
||||||
|
z = self._transition(z=z, mask=mask)
|
||||||
|
|
||||||
|
return z
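
A shape sketch (illustrative only): the pair transition returns an update with the same [*, N_res, N_res, C_z] shape as its input; the residual add is left to the caller.

if __name__ == "__main__":
    pt = PairTransition(c_z=32, n=2)
    z = torch.randn(1, 16, 16, 32)
    assert pt(z).shape == z.shape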
|
|
@ -0,0 +1,529 @@
|
||||||
|
# Copyright 2021 AlQuraishi Laboratory
|
||||||
|
# Copyright 2021 DeepMind Technologies Limited
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from functools import partial
|
||||||
|
import math
|
||||||
|
from typing import Optional, Callable, List, Tuple, Sequence
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from .checkpointing import get_checkpoint_fn
|
||||||
|
from .tensor_utils import (
|
||||||
|
permute_final_dims,
|
||||||
|
flatten_final_dims,
|
||||||
|
_chunk_slice,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _prod(nums):
|
||||||
|
out = 1
|
||||||
|
for n in nums:
|
||||||
|
out = out * n
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def _calculate_fan(linear_weight_shape, fan="fan_in"):
|
||||||
|
fan_out, fan_in = linear_weight_shape
|
||||||
|
|
||||||
|
if fan == "fan_in":
|
||||||
|
f = fan_in
|
||||||
|
elif fan == "fan_out":
|
||||||
|
f = fan_out
|
||||||
|
elif fan == "fan_avg":
|
||||||
|
f = (fan_in + fan_out) / 2
|
||||||
|
else:
|
||||||
|
raise ValueError("Invalid fan option")
|
||||||
|
|
||||||
|
return f
|
||||||
|
|
||||||
|
|
||||||
|
def glorot_uniform_init_(weights):
|
||||||
|
nn.init.xavier_uniform_(weights, gain=1)
|
||||||
|
|
||||||
|
|
||||||
|
def final_init_(weights):
|
||||||
|
with torch.no_grad():
|
||||||
|
weights.fill_(0.0)
|
||||||
|
|
||||||
|
|
||||||
|
def gating_init_(weights):
|
||||||
|
with torch.no_grad():
|
||||||
|
weights.fill_(0.0)
|
||||||
|
|
||||||
|
|
||||||
|
def normal_init_(weights):
|
||||||
|
torch.nn.init.kaiming_normal_(weights, nonlinearity="linear")
|
||||||
|
|
||||||
|
|
||||||
|
def ipa_point_weights_init_(weights):
|
||||||
|
with torch.no_grad():
|
||||||
|
softplus_inverse_1 = 0.541324854612918
|
||||||
|
weights.fill_(softplus_inverse_1)
|
||||||
|
|
||||||
|
|
||||||
|
class Linear(nn.Linear):
|
||||||
|
"""
|
||||||
|
A Linear layer with built-in nonstandard initializations. Called just
|
||||||
|
like torch.nn.Linear.
|
||||||
|
|
||||||
|
Implements the initializers in 1.11.4, plus some additional ones found
|
||||||
|
in the code.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_dim: int,
|
||||||
|
out_dim: int,
|
||||||
|
bias: bool = True,
|
||||||
|
init: str = "default",
|
||||||
|
init_fn: Optional[Callable[[torch.Tensor, torch.Tensor], None]] = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
in_dim:
|
||||||
|
The final dimension of inputs to the layer
|
||||||
|
out_dim:
|
||||||
|
The final dimension of layer outputs
|
||||||
|
bias:
|
||||||
|
Whether to learn an additive bias. True by default
|
||||||
|
init:
|
||||||
|
The initializer to use. Choose from:
|
||||||
|
|
||||||
|
"default": LeCun fan-in truncated normal initialization
|
||||||
|
"relu": He initialization w/ truncated normal distribution
|
||||||
|
"glorot": Fan-average Glorot uniform initialization
|
||||||
|
"gating": Weights=0, Bias=1
|
||||||
|
"normal": Normal initialization with std=1/sqrt(fan_in)
|
||||||
|
"final": Weights=0, Bias=0
|
||||||
|
|
||||||
|
Overridden by init_fn if the latter is not None.
|
||||||
|
init_fn:
|
||||||
|
A custom initializer taking weight and bias as inputs.
|
||||||
|
Overrides init if not None.
|
||||||
|
"""
|
||||||
|
super(Linear, self).__init__(in_dim, out_dim, bias=bias)
|
||||||
|
|
||||||
|
if bias:
|
||||||
|
with torch.no_grad():
|
||||||
|
self.bias.fill_(0)
|
||||||
|
|
||||||
|
if init_fn is not None:
|
||||||
|
init_fn(self.weight, self.bias)
|
||||||
|
else:
|
||||||
|
if init == "default":
|
||||||
|
normal_init_(self.weight)
|
||||||
|
elif init == "relu":
|
||||||
|
normal_init_(self.weight)
|
||||||
|
elif init == "glorot":
|
||||||
|
glorot_uniform_init_(self.weight)
|
||||||
|
elif init == "gating":
|
||||||
|
gating_init_(self.weight)
|
||||||
|
if bias:
|
||||||
|
with torch.no_grad():
|
||||||
|
self.bias.fill_(1.0)
|
||||||
|
elif init == "normal":
|
||||||
|
normal_init_(self.weight)
|
||||||
|
elif init == "final":
|
||||||
|
final_init_(self.weight)
|
||||||
|
else:
|
||||||
|
raise ValueError("Invalid init string.")
|
||||||
|
|
||||||
|
|
||||||
|
class LayerNorm(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, c_in, eps=1e-5):
|
||||||
|
super(LayerNorm, self).__init__()
|
||||||
|
|
||||||
|
self.c_in = (c_in,)
|
||||||
|
self.eps = eps
|
||||||
|
|
||||||
|
self.weight = nn.Parameter(torch.ones(c_in))
|
||||||
|
self.bias = nn.Parameter(torch.zeros(c_in))
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
out = nn.functional.layer_norm(
|
||||||
|
x,
|
||||||
|
self.c_in,
|
||||||
|
self.weight,
|
||||||
|
self.bias,
|
||||||
|
self.eps,
|
||||||
|
)
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
@torch.jit.ignore
|
||||||
|
def softmax(t: torch.Tensor, dim: int = -1) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Softmax, but without automatic casting to fp32 when the input is of
|
||||||
|
type bfloat16
|
||||||
|
"""
|
||||||
|
s = torch.nn.functional.softmax(t, dim=dim)
|
||||||
|
|
||||||
|
return s
|
||||||
|
|
||||||
|
|
||||||
|
#@torch.jit.script
|
||||||
|
def _attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
|
||||||
|
biases: List[torch.Tensor]) -> torch.Tensor:
|
||||||
|
# [*, H, Q, C_hidden]
|
||||||
|
query = permute_final_dims(query, (1, 0, 2))
|
||||||
|
|
||||||
|
# [*, H, C_hidden, K]
|
||||||
|
key = permute_final_dims(key, (1, 2, 0))
|
||||||
|
|
||||||
|
# [*, H, V, C_hidden]
|
||||||
|
value = permute_final_dims(value, (1, 0, 2))
|
||||||
|
|
||||||
|
# [*, H, Q, K]
|
||||||
|
a = torch.matmul(query, key)
|
||||||
|
|
||||||
|
for b in biases:
|
||||||
|
a += b
|
||||||
|
|
||||||
|
a = softmax(a, -1)
|
||||||
|
|
||||||
|
# [*, H, Q, C_hidden]
|
||||||
|
a = torch.matmul(a, value)
|
||||||
|
|
||||||
|
# [*, Q, H, C_hidden]
|
||||||
|
a = a.transpose(-2, -3)
|
||||||
|
|
||||||
|
return a
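# Illustrative sketch (hypothetical `_example_attention_shapes` helper, not part of the
# original file): _attention expects query/key/value in the [*, Q/K, H, C_hidden] layout
# produced by Attention._prep_qkv and biases that broadcast to [*, H, Q, K]; the result
# comes back in [*, Q, H, C_hidden].
def _example_attention_shapes():
    batch, q_len, kv_len, heads, c_hidden = 2, 5, 7, 4, 8
    q = torch.randn(batch, q_len, heads, c_hidden)
    k = torch.randn(batch, kv_len, heads, c_hidden)
    v = torch.randn(batch, kv_len, heads, c_hidden)
    bias = torch.zeros(batch, heads, q_len, kv_len)
    out = _attention(q, k, v, [bias])
    assert out.shape == (batch, q_len, heads, c_hidden)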
|
||||||
|
|
||||||
|
|
||||||
|
@torch.jit.ignore
|
||||||
|
def _attention_chunked_trainable(
|
||||||
|
query,
|
||||||
|
key,
|
||||||
|
value,
|
||||||
|
biases,
|
||||||
|
chunk_size,
|
||||||
|
chunk_dim,
|
||||||
|
checkpoint,
|
||||||
|
):
|
||||||
|
if (checkpoint and len(biases) > 2):
|
||||||
|
raise ValueError("Checkpointed version permits only permits two bias terms")
|
||||||
|
|
||||||
|
def _checkpointable_attention(q, k, v, b1, b2):
|
||||||
|
bs = [b for b in [b1, b2] if b is not None]
|
||||||
|
return _attention(q, k, v, bs)
|
||||||
|
|
||||||
|
o_chunks = []
|
||||||
|
checkpoint_fn = get_checkpoint_fn()
|
||||||
|
count = query.shape[chunk_dim]
|
||||||
|
for start in range(0, count, chunk_size):
|
||||||
|
end = start + chunk_size
|
||||||
|
idx = [slice(None)] * len(query.shape)
|
||||||
|
idx[chunk_dim] = slice(start, end)
|
||||||
|
idx_tup = tuple(idx)
|
||||||
|
q_chunk = query[idx_tup]
|
||||||
|
k_chunk = key[idx_tup]
|
||||||
|
v_chunk = value[idx_tup]
|
||||||
|
|
||||||
|
def _slice_bias(b):
|
||||||
|
idx[chunk_dim] = (slice(start, end) if b.shape[chunk_dim] != 1 else slice(None))
|
||||||
|
return b[tuple(idx)]
|
||||||
|
|
||||||
|
if (checkpoint):
|
||||||
|
bias_1_chunk, bias_2_chunk = [
|
||||||
|
_slice_bias(b) if b is not None else None for b in (biases + [None, None])[:2]
|
||||||
|
]
|
||||||
|
|
||||||
|
o_chunk = checkpoint_fn(_checkpointable_attention, q_chunk, k_chunk, v_chunk,
|
||||||
|
bias_1_chunk, bias_2_chunk)
|
||||||
|
else:
|
||||||
|
bias_chunks = [_slice_bias(b) for b in biases]
|
||||||
|
|
||||||
|
o_chunk = _attention(q_chunk, k_chunk, v_chunk, bias_chunks)
|
||||||
|
|
||||||
|
o_chunks.append(o_chunk)
|
||||||
|
|
||||||
|
o = torch.cat(o_chunks, dim=chunk_dim)
|
||||||
|
return o
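# Illustrative sketch (hypothetical `_example_chunked_trainable_attention` helper, not
# part of the original file): chunking query/key/value along a leading batch-like
# dimension and concatenating the per-chunk results reproduces the full output, because
# attention is independent across that dimension. Biases with size 1 on the chunked
# dimension are reused for every chunk.
def _example_chunked_trainable_attention():
    q = torch.randn(4, 10, 2, 8)      # [N, Q, H, C_hidden]
    k = torch.randn(4, 12, 2, 8)      # [N, K, H, C_hidden]
    v = torch.randn(4, 12, 2, 8)
    bias = torch.zeros(1, 2, 10, 12)  # broadcasts over the chunked dimension
    full = _attention(q, k, v, [bias])
    chunked = _attention_chunked_trainable(
        q, k, v, [bias], chunk_size=2, chunk_dim=0, checkpoint=False
    )
    assert torch.allclose(full, chunked, atol=1e-5)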
|
||||||
|
|
||||||
|
|
||||||
|
class Attention(nn.Module):
|
||||||
|
"""
|
||||||
|
Standard multi-head attention using AlphaFold's default layer
|
||||||
|
initialization. Allows multiple bias vectors.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
c_q: int,
|
||||||
|
c_k: int,
|
||||||
|
c_v: int,
|
||||||
|
c_hidden: int,
|
||||||
|
no_heads: int,
|
||||||
|
gating: bool = True,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
c_q:
|
||||||
|
Input dimension of query data
|
||||||
|
c_k:
|
||||||
|
Input dimension of key data
|
||||||
|
c_v:
|
||||||
|
Input dimension of value data
|
||||||
|
c_hidden:
|
||||||
|
Per-head hidden dimension
|
||||||
|
no_heads:
|
||||||
|
Number of attention heads
|
||||||
|
gating:
|
||||||
|
Whether the output should be gated using query data
|
||||||
|
"""
|
||||||
|
super(Attention, self).__init__()
|
||||||
|
|
||||||
|
self.c_q = c_q
|
||||||
|
self.c_k = c_k
|
||||||
|
self.c_v = c_v
|
||||||
|
self.c_hidden = c_hidden
|
||||||
|
self.no_heads = no_heads
|
||||||
|
self.gating = gating
|
||||||
|
|
||||||
|
# DISCREPANCY: c_hidden is not the per-head channel dimension, as
|
||||||
|
# stated in the supplement, but the overall channel dimension.
|
||||||
|
|
||||||
|
self.linear_q = Linear(self.c_q, self.c_hidden * self.no_heads, bias=False, init="glorot")
|
||||||
|
self.linear_k = Linear(self.c_k, self.c_hidden * self.no_heads, bias=False, init="glorot")
|
||||||
|
self.linear_v = Linear(self.c_v, self.c_hidden * self.no_heads, bias=False, init="glorot")
|
||||||
|
self.linear_o = Linear(self.c_hidden * self.no_heads, self.c_q, init="final")
|
||||||
|
|
||||||
|
self.linear_g = None
|
||||||
|
if self.gating:
|
||||||
|
self.linear_g = Linear(self.c_q, self.c_hidden * self.no_heads, init="gating")
|
||||||
|
|
||||||
|
self.sigmoid = nn.Sigmoid()
|
||||||
|
|
||||||
|
def _prep_qkv(self, q_x: torch.Tensor,
|
||||||
|
kv_x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
|
# [*, Q/K/V, H * C_hidden]
|
||||||
|
q = self.linear_q(q_x)
|
||||||
|
k = self.linear_k(kv_x)
|
||||||
|
v = self.linear_v(kv_x)
|
||||||
|
|
||||||
|
# [*, Q/K, H, C_hidden]
|
||||||
|
q = q.view(q.shape[:-1] + (self.no_heads, -1))
|
||||||
|
k = k.view(k.shape[:-1] + (self.no_heads, -1))
|
||||||
|
v = v.view(v.shape[:-1] + (self.no_heads, -1))
|
||||||
|
|
||||||
|
q /= math.sqrt(self.c_hidden)
|
||||||
|
|
||||||
|
return q, k, v
|
||||||
|
|
||||||
|
def _wrap_up(self, o: torch.Tensor, q_x: torch.Tensor) -> torch.Tensor:
|
||||||
|
if (self.linear_g is not None):
|
||||||
|
g = self.sigmoid(self.linear_g(q_x))
|
||||||
|
|
||||||
|
# [*, Q, H, C_hidden]
|
||||||
|
g = g.view(g.shape[:-1] + (self.no_heads, -1))
|
||||||
|
o = o * g
|
||||||
|
|
||||||
|
# [*, Q, H * C_hidden]
|
||||||
|
o = flatten_final_dims(o, 2)
|
||||||
|
|
||||||
|
# [*, Q, C_q]
|
||||||
|
o = self.linear_o(o)
|
||||||
|
|
||||||
|
return o
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
q_x: torch.Tensor,
|
||||||
|
kv_x: torch.Tensor,
|
||||||
|
biases: Optional[List[torch.Tensor]] = None,
|
||||||
|
use_lma: bool = False,
|
||||||
|
q_chunk_size: Optional[int] = None,
|
||||||
|
kv_chunk_size: Optional[int] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
q_x:
|
||||||
|
[*, Q, C_q] query data
|
||||||
|
kv_x:
|
||||||
|
[*, K, C_k] key data
|
||||||
|
biases:
|
||||||
|
List of biases that broadcast to [*, H, Q, K]
|
||||||
|
use_lma:
|
||||||
|
Whether to use low-memory attention
|
||||||
|
q_chunk_size:
|
||||||
|
Query chunk size (for LMA)
|
||||||
|
kv_chunk_size:
|
||||||
|
Key/Value chunk size (for LMA)
|
||||||
|
Returns
|
||||||
|
[*, Q, C_q] attention update
|
||||||
|
"""
|
||||||
|
if (biases is None):
|
||||||
|
biases = []
|
||||||
|
if (use_lma and (q_chunk_size is None or kv_chunk_size is None)):
|
||||||
|
raise ValueError("If use_lma is specified, q_chunk_size and kv_chunk_size must "
|
||||||
|
"be provided")
|
||||||
|
|
||||||
|
q, k, v = self._prep_qkv(q_x, kv_x)
|
||||||
|
|
||||||
|
if (use_lma):
|
||||||
|
biases = [b.expand(b.shape[:-2] + (q_x.shape[-2],) + (kv_x.shape[-2],)) for b in biases]
|
||||||
|
|
||||||
|
o = _lma(q, k, v, biases, q_chunk_size, kv_chunk_size)
|
||||||
|
else:
|
||||||
|
o = _attention(q, k, v, biases)
|
||||||
|
|
||||||
|
o = self._wrap_up(o, q_x)
|
||||||
|
|
||||||
|
return o
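# Illustrative sketch (hypothetical `_example_attention_usage` helper, not part of the
# original file): a typical call converts a 0/1 mask into a large negative additive
# bias, similar to how the downstream modules in this patch build their mask biases.
# Sizes and the `inf` constant are arbitrary example values.
def _example_attention_usage():
    inf = 1e9
    attn = Attention(c_q=16, c_k=16, c_v=16, c_hidden=8, no_heads=4)
    x = torch.randn(2, 10, 16)                          # [*, Q, C_q]; self-attention, so K == Q
    mask = torch.ones(2, 10)                            # [*, K]
    mask_bias = (inf * (mask - 1))[..., None, None, :]  # [*, 1, 1, K], broadcasts to [*, H, Q, K]
    out = attn(q_x=x, kv_x=x, biases=[mask_bias])
    assert out.shape == (2, 10, 16)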
|
||||||
|
|
||||||
|
|
||||||
|
class GlobalAttention(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, c_in, c_hidden, no_heads, inf, eps):
|
||||||
|
super(GlobalAttention, self).__init__()
|
||||||
|
|
||||||
|
self.c_in = c_in
|
||||||
|
self.c_hidden = c_hidden
|
||||||
|
self.no_heads = no_heads
|
||||||
|
self.inf = inf
|
||||||
|
self.eps = eps
|
||||||
|
|
||||||
|
self.linear_q = Linear(c_in, c_hidden * no_heads, bias=False, init="glorot")
|
||||||
|
|
||||||
|
self.linear_k = Linear(
|
||||||
|
c_in,
|
||||||
|
c_hidden,
|
||||||
|
bias=False,
|
||||||
|
init="glorot",
|
||||||
|
)
|
||||||
|
self.linear_v = Linear(
|
||||||
|
c_in,
|
||||||
|
c_hidden,
|
||||||
|
bias=False,
|
||||||
|
init="glorot",
|
||||||
|
)
|
||||||
|
self.linear_g = Linear(c_in, c_hidden * no_heads, init="gating")
|
||||||
|
self.linear_o = Linear(c_hidden * no_heads, c_in, init="final")
|
||||||
|
|
||||||
|
self.sigmoid = nn.Sigmoid()
|
||||||
|
|
||||||
|
def forward(self, m: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
|
||||||
|
# [*, N_res, C_in]
|
||||||
|
q = torch.sum(m * mask.unsqueeze(-1),
|
||||||
|
dim=-2) / (torch.sum(mask, dim=-1)[..., None] + self.eps)
|
||||||
|
|
||||||
|
# [*, N_res, H * C_hidden]
|
||||||
|
q = self.linear_q(q)
|
||||||
|
q *= (self.c_hidden**(-0.5))
|
||||||
|
|
||||||
|
# [*, N_res, H, C_hidden]
|
||||||
|
q = q.view(q.shape[:-1] + (self.no_heads, -1))
|
||||||
|
|
||||||
|
# [*, N_res, N_seq, C_hidden]
|
||||||
|
k = self.linear_k(m)
|
||||||
|
v = self.linear_v(m)
|
||||||
|
|
||||||
|
# [*, N_res, H, N_seq]
|
||||||
|
a = torch.matmul(
|
||||||
|
q,
|
||||||
|
k.transpose(-1, -2), # [*, N_res, C_hidden, N_seq]
|
||||||
|
)
|
||||||
|
bias = (self.inf * (mask - 1))[..., :, None, :]
|
||||||
|
a += bias
|
||||||
|
a = softmax(a)
|
||||||
|
|
||||||
|
# [*, N_res, H, C_hidden]
|
||||||
|
o = torch.matmul(
|
||||||
|
a,
|
||||||
|
v,
|
||||||
|
)
|
||||||
|
|
||||||
|
# [*, N_res, N_seq, C_hidden]
|
||||||
|
g = self.sigmoid(self.linear_g(m))
|
||||||
|
|
||||||
|
# [*, N_res, N_seq, H, C_hidden]
|
||||||
|
g = g.view(g.shape[:-1] + (self.no_heads, -1))
|
||||||
|
|
||||||
|
# [*, N_res, N_seq, H, C_hidden]
|
||||||
|
o = o.unsqueeze(-3) * g
|
||||||
|
|
||||||
|
# [*, N_res, N_seq, H * C_hidden]
|
||||||
|
o = o.reshape(o.shape[:-2] + (-1,))
|
||||||
|
|
||||||
|
# [*, N_res, N_seq, C_in]
|
||||||
|
m = self.linear_o(o)
|
||||||
|
|
||||||
|
return m
|
||||||
|
|
||||||
|
|
||||||
|
def _lma(
|
||||||
|
q: torch.Tensor,
|
||||||
|
k: torch.Tensor,
|
||||||
|
v: torch.Tensor,
|
||||||
|
biases: List[torch.Tensor],
|
||||||
|
q_chunk_size: int,
|
||||||
|
kv_chunk_size: int,
|
||||||
|
):
|
||||||
|
no_q, no_kv = q.shape[-3], k.shape[-3]
|
||||||
|
|
||||||
|
# [*, Q, H, C_hidden]
|
||||||
|
o = q.new_zeros(q.shape)
|
||||||
|
for q_s in range(0, no_q, q_chunk_size):
|
||||||
|
q_chunk = q[..., q_s:q_s + q_chunk_size, :, :]
|
||||||
|
large_bias_chunks = [b[..., q_s:q_s + q_chunk_size, :] for b in biases]
|
||||||
|
|
||||||
|
maxes = []
|
||||||
|
weights = []
|
||||||
|
values = []
|
||||||
|
for kv_s in range(0, no_kv, kv_chunk_size):
|
||||||
|
k_chunk = k[..., kv_s:kv_s + kv_chunk_size, :, :]
|
||||||
|
v_chunk = v[..., kv_s:kv_s + kv_chunk_size, :, :]
|
||||||
|
small_bias_chunks = [b[..., kv_s:kv_s + kv_chunk_size] for b in large_bias_chunks]
|
||||||
|
|
||||||
|
a = torch.einsum(
|
||||||
|
"...qhd,...khd->...hqk",
|
||||||
|
q_chunk,
|
||||||
|
k_chunk,
|
||||||
|
)
|
||||||
|
|
||||||
|
for b in small_bias_chunks:
|
||||||
|
a += b
|
||||||
|
|
||||||
|
a = a.transpose(-2, -3)
|
||||||
|
|
||||||
|
max_a = torch.max(a, dim=-1, keepdim=True)[0]
|
||||||
|
exp_a = torch.exp(a - max_a)
|
||||||
|
exp_v = torch.einsum("...vhf,...qhv->...qhf", v_chunk, exp_a)
|
||||||
|
|
||||||
|
maxes.append(max_a.detach().squeeze(-1))
|
||||||
|
weights.append(torch.sum(exp_a, dim=-1))
|
||||||
|
values.append(exp_v)
|
||||||
|
|
||||||
|
chunk_max = torch.stack(maxes, dim=-3)
|
||||||
|
chunk_weights = torch.stack(weights, dim=-3)
|
||||||
|
chunk_values = torch.stack(values, dim=-4)
|
||||||
|
|
||||||
|
global_max = torch.max(chunk_max, dim=-3, keepdim=True)[0]
|
||||||
|
max_diffs = torch.exp(chunk_max - global_max)
|
||||||
|
chunk_values *= max_diffs.unsqueeze(-1)
|
||||||
|
chunk_weights *= max_diffs
|
||||||
|
|
||||||
|
all_values = torch.sum(chunk_values, dim=-4)
|
||||||
|
all_weights = torch.sum(chunk_weights.unsqueeze(-1), dim=-4)
|
||||||
|
|
||||||
|
q_chunk_out = all_values / all_weights
|
||||||
|
|
||||||
|
o[..., q_s:q_s + q_chunk_size, :, :] = q_chunk_out
|
||||||
|
|
||||||
|
return o
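# Illustrative sketch (hypothetical `_example_lma_matches_attention` helper, not part of
# the original file): _lma streams over key/value chunks with a running max and weight
# (an online softmax), so it should agree with the dense _attention path up to
# floating-point error.
def _example_lma_matches_attention():
    q = torch.randn(1, 16, 2, 4)      # [*, Q, H, C_hidden]
    k = torch.randn(1, 24, 2, 4)      # [*, K, H, C_hidden]
    v = torch.randn(1, 24, 2, 4)
    bias = torch.zeros(1, 2, 16, 24)  # [*, H, Q, K]
    dense = _attention(q, k, v, [bias])
    streamed = _lma(q, k, v, [bias], q_chunk_size=4, kv_chunk_size=8)
    assert torch.allclose(dense, streamed, atol=1e-5)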
|
|
@ -0,0 +1,408 @@
|
||||||
|
# Copyright 2021 AlQuraishi Laboratory
|
||||||
|
# Copyright 2021 DeepMind Technologies Limited
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from functools import partial
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from typing import Tuple, List, Callable, Any, Dict, Sequence, Optional
|
||||||
|
|
||||||
|
|
||||||
|
def permute_final_dims(tensor: torch.Tensor, inds: List[int]):
|
||||||
|
zero_index = -1 * len(inds)
|
||||||
|
first_inds = list(range(len(tensor.shape[:zero_index])))
|
||||||
|
return tensor.permute(first_inds + [zero_index + i for i in inds])
|
||||||
|
|
||||||
|
|
||||||
|
def flatten_final_dims(t: torch.Tensor, no_dims: int):
|
||||||
|
return t.reshape(t.shape[:-no_dims] + (-1,))
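# Illustrative sketch (hypothetical `_example_final_dim_helpers` helper, not part of the
# original file): permute_final_dims reorders only the trailing len(inds) dimensions
# (indices are relative to that trailing block), and flatten_final_dims merges the last
# no_dims dimensions into one.
def _example_final_dim_helpers():
    t = torch.randn(2, 3, 4, 5)
    assert permute_final_dims(t, (2, 0, 1)).shape == (2, 5, 3, 4)
    assert flatten_final_dims(t, 2).shape == (2, 3, 20)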
|
||||||
|
|
||||||
|
|
||||||
|
def masked_mean(mask, value, dim, eps=1e-4):
|
||||||
|
mask = mask.expand(*value.shape)
|
||||||
|
return torch.sum(mask * value, dim=dim) / (eps + torch.sum(mask, dim=dim))
|
||||||
|
|
||||||
|
|
||||||
|
def pts_to_distogram(pts, min_bin=2.3125, max_bin=21.6875, no_bins=64):
|
||||||
|
boundaries = torch.linspace(
|
||||||
|
min_bin, max_bin, no_bins - 1, device=pts.device
|
||||||
|
)
|
||||||
|
dists = torch.sqrt(
|
||||||
|
torch.sum((pts.unsqueeze(-2) - pts.unsqueeze(-3)) ** 2, dim=-1)
|
||||||
|
)
|
||||||
|
return torch.bucketize(dists, boundaries)
|
||||||
|
|
||||||
|
|
||||||
|
def dict_multimap(fn, dicts):
|
||||||
|
first = dicts[0]
|
||||||
|
new_dict = {}
|
||||||
|
for k, v in first.items():
|
||||||
|
all_v = [d[k] for d in dicts]
|
||||||
|
if type(v) is dict:
|
||||||
|
new_dict[k] = dict_multimap(fn, all_v)
|
||||||
|
else:
|
||||||
|
new_dict[k] = fn(all_v)
|
||||||
|
|
||||||
|
return new_dict
|
||||||
|
|
||||||
|
|
||||||
|
def one_hot(x, v_bins):
|
||||||
|
reshaped_bins = v_bins.view(((1,) * len(x.shape)) + (len(v_bins),))
|
||||||
|
diffs = x[..., None] - reshaped_bins
|
||||||
|
am = torch.argmin(torch.abs(diffs), dim=-1)
|
||||||
|
return nn.functional.one_hot(am, num_classes=len(v_bins)).float()
|
||||||
|
|
||||||
|
|
||||||
|
def batched_gather(data, inds, dim=0, no_batch_dims=0):
|
||||||
|
ranges = []
|
||||||
|
for i, s in enumerate(data.shape[:no_batch_dims]):
|
||||||
|
r = torch.arange(s)
|
||||||
|
r = r.view(*(*((1,) * i), -1, *((1,) * (len(inds.shape) - i - 1))))
|
||||||
|
ranges.append(r)
|
||||||
|
|
||||||
|
remaining_dims = [
|
||||||
|
slice(None) for _ in range(len(data.shape) - no_batch_dims)
|
||||||
|
]
|
||||||
|
remaining_dims[dim - no_batch_dims if dim >= 0 else dim] = inds
|
||||||
|
ranges.extend(remaining_dims)
|
||||||
|
return data[ranges]
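# Illustrative sketch (hypothetical `_example_batched_gather` helper, not part of the
# original file): batched_gather indexes `data` along `dim` with a per-batch index
# tensor, broadcasting over the leading no_batch_dims batch dimensions.
def _example_batched_gather():
    data = torch.arange(24).view(2, 3, 4)  # [batch, N, C]
    inds = torch.tensor([[2, 0], [1, 1]])  # [batch, K], indices into dim 1
    out = batched_gather(data, inds, dim=1, no_batch_dims=1)
    assert out.shape == (2, 2, 4)
    assert torch.equal(out[0, 0], data[0, 2])
    assert torch.equal(out[1, 1], data[1, 1])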
|
||||||
|
|
||||||
|
|
||||||
|
# With tree_map, a poor man's JAX tree_map
|
||||||
|
def dict_map(fn, dic, leaf_type):
|
||||||
|
new_dict = {}
|
||||||
|
for k, v in dic.items():
|
||||||
|
if type(v) is dict:
|
||||||
|
new_dict[k] = dict_map(fn, v, leaf_type)
|
||||||
|
else:
|
||||||
|
new_dict[k] = tree_map(fn, v, leaf_type)
|
||||||
|
|
||||||
|
return new_dict
|
||||||
|
|
||||||
|
|
||||||
|
def tree_map(fn, tree, leaf_type):
|
||||||
|
if isinstance(tree, dict):
|
||||||
|
return dict_map(fn, tree, leaf_type)
|
||||||
|
elif isinstance(tree, list):
|
||||||
|
return [tree_map(fn, x, leaf_type) for x in tree]
|
||||||
|
elif isinstance(tree, tuple):
|
||||||
|
return tuple([tree_map(fn, x, leaf_type) for x in tree])
|
||||||
|
elif isinstance(tree, leaf_type):
|
||||||
|
return fn(tree)
|
||||||
|
else:
|
||||||
|
print(type(tree))
|
||||||
|
raise ValueError("Not supported")
|
||||||
|
|
||||||
|
|
||||||
|
tensor_tree_map = partial(tree_map, leaf_type=torch.Tensor)
|
||||||
|
|
||||||
|
def _fetch_dims(tree):
|
||||||
|
shapes = []
|
||||||
|
tree_type = type(tree)
|
||||||
|
if tree_type is dict:
|
||||||
|
for v in tree.values():
|
||||||
|
shapes.extend(_fetch_dims(v))
|
||||||
|
elif tree_type is list or tree_type is tuple:
|
||||||
|
for t in tree:
|
||||||
|
shapes.extend(_fetch_dims(t))
|
||||||
|
elif tree_type is torch.Tensor:
|
||||||
|
shapes.append(tree.shape)
|
||||||
|
else:
|
||||||
|
raise ValueError("Not supported")
|
||||||
|
|
||||||
|
return shapes
|
||||||
|
|
||||||
|
|
||||||
|
@torch.jit.ignore
|
||||||
|
def _flat_idx_to_idx(
|
||||||
|
flat_idx: int,
|
||||||
|
dims: Tuple[int],
|
||||||
|
) -> Tuple[int]:
|
||||||
|
idx = []
|
||||||
|
for d in reversed(dims):
|
||||||
|
idx.append(flat_idx % d)
|
||||||
|
flat_idx = flat_idx // d
|
||||||
|
|
||||||
|
return tuple(reversed(idx))
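# Illustrative sketch (hypothetical `_example_flat_idx_to_idx` helper, not part of the
# original file): the flat (row-major) index 7 in a (3, 4) grid corresponds to
# coordinates (1, 3), since 7 == 1 * 4 + 3.
def _example_flat_idx_to_idx():
    assert _flat_idx_to_idx(7, (3, 4)) == (1, 3)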
|
||||||
|
|
||||||
|
|
||||||
|
@torch.jit.ignore
|
||||||
|
def _get_minimal_slice_set(
|
||||||
|
start: Sequence[int],
|
||||||
|
end: Sequence[int],
|
||||||
|
dims: int,
|
||||||
|
start_edges: Optional[Sequence[bool]] = None,
|
||||||
|
end_edges: Optional[Sequence[bool]] = None,
|
||||||
|
) -> Sequence[Tuple[int]]:
|
||||||
|
"""
|
||||||
|
Produces an ordered sequence of tensor slices that, when used in
|
||||||
|
sequence on a tensor with shape dims, yields tensors that contain every
|
||||||
|
leaf in the contiguous range [start, end]. Care is taken to yield a
|
||||||
|
short sequence of slices, and perhaps even the shortest possible (I'm
|
||||||
|
pretty sure it's the latter).
|
||||||
|
|
||||||
|
end is INCLUSIVE.
|
||||||
|
"""
|
||||||
|
# start_edges and end_edges both indicate whether, starting from any given
|
||||||
|
# dimension, the start/end index is at the top/bottom edge of the
|
||||||
|
# corresponding tensor, modeled as a tree
|
||||||
|
def reduce_edge_list(ll):
|
||||||
|
tally = 1
|
||||||
|
for i in range(len(ll)):
|
||||||
|
reversed_idx = -1 * (i + 1)
|
||||||
|
ll[reversed_idx] *= tally
|
||||||
|
tally = ll[reversed_idx]
|
||||||
|
|
||||||
|
if(start_edges is None):
|
||||||
|
start_edges = [s == 0 for s in start]
|
||||||
|
reduce_edge_list(start_edges)
|
||||||
|
if(end_edges is None):
|
||||||
|
end_edges = [e == (d - 1) for e,d in zip(end, dims)]
|
||||||
|
reduce_edge_list(end_edges)
|
||||||
|
|
||||||
|
# Base cases. Either start/end are empty and we're done, or the final,
|
||||||
|
# one-dimensional tensor can be simply sliced
|
||||||
|
if(len(start) == 0):
|
||||||
|
return [tuple()]
|
||||||
|
elif(len(start) == 1):
|
||||||
|
return [(slice(start[0], end[0] + 1),)]
|
||||||
|
|
||||||
|
slices = []
|
||||||
|
path = []
|
||||||
|
|
||||||
|
# Dimensions common to start and end can be selected directly
|
||||||
|
for s,e in zip(start, end):
|
||||||
|
if(s == e):
|
||||||
|
path.append(slice(s, s + 1))
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
|
||||||
|
path = tuple(path)
|
||||||
|
divergence_idx = len(path)
|
||||||
|
|
||||||
|
# start == end, and we're done
|
||||||
|
if(divergence_idx == len(dims)):
|
||||||
|
return [tuple(path)]
|
||||||
|
|
||||||
|
def upper():
|
||||||
|
sdi = start[divergence_idx]
|
||||||
|
return [
|
||||||
|
path + (slice(sdi, sdi + 1),) + s for s in
|
||||||
|
_get_minimal_slice_set(
|
||||||
|
start[divergence_idx + 1:],
|
||||||
|
[d - 1 for d in dims[divergence_idx + 1:]],
|
||||||
|
dims[divergence_idx + 1:],
|
||||||
|
start_edges=start_edges[divergence_idx + 1:],
|
||||||
|
end_edges=[1 for _ in end_edges[divergence_idx + 1:]]
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
def lower():
|
||||||
|
edi = end[divergence_idx]
|
||||||
|
return [
|
||||||
|
path + (slice(edi, edi + 1),) + s for s in
|
||||||
|
_get_minimal_slice_set(
|
||||||
|
[0 for _ in start[divergence_idx + 1:]],
|
||||||
|
end[divergence_idx + 1:],
|
||||||
|
dims[divergence_idx + 1:],
|
||||||
|
start_edges=[1 for _ in start_edges[divergence_idx + 1:]],
|
||||||
|
end_edges=end_edges[divergence_idx + 1:],
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
# If both start and end are at the edges of the subtree rooted at
|
||||||
|
# divergence_idx, we can just select the whole subtree at once
|
||||||
|
if(start_edges[divergence_idx] and end_edges[divergence_idx]):
|
||||||
|
slices.append(
|
||||||
|
path + (slice(start[divergence_idx], end[divergence_idx] + 1),)
|
||||||
|
)
|
||||||
|
# If just start is at the edge, we can grab almost all of the subtree,
|
||||||
|
# treating only the ragged bottom edge as an edge case
|
||||||
|
elif(start_edges[divergence_idx]):
|
||||||
|
slices.append(
|
||||||
|
path + (slice(start[divergence_idx], end[divergence_idx]),)
|
||||||
|
)
|
||||||
|
slices.extend(lower())
|
||||||
|
# Analogous to the previous case, but the top is ragged this time
|
||||||
|
elif(end_edges[divergence_idx]):
|
||||||
|
slices.extend(upper())
|
||||||
|
slices.append(
|
||||||
|
path + (slice(start[divergence_idx] + 1, end[divergence_idx] + 1),)
|
||||||
|
)
|
||||||
|
# If both sides of the range are ragged, we need to handle both sides
|
||||||
|
# separately. If there's contiguous meat in between them, we can index it
|
||||||
|
# in one big chunk
|
||||||
|
else:
|
||||||
|
slices.extend(upper())
|
||||||
|
middle_ground = end[divergence_idx] - start[divergence_idx]
|
||||||
|
if(middle_ground > 1):
|
||||||
|
slices.append(
|
||||||
|
path + (slice(start[divergence_idx] + 1, end[divergence_idx]),)
|
||||||
|
)
|
||||||
|
slices.extend(lower())
|
||||||
|
|
||||||
|
return [tuple(s) for s in slices]
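# Illustrative sketch (hypothetical `_example_minimal_slice_set` helper, not part of the
# original file): for a (3, 4) tensor, the inclusive flat range 2..9 decomposes into a
# ragged first row, one full middle row, and a ragged last row; concatenating the sliced
# pieces recovers exactly that range.
def _example_minimal_slice_set():
    t = torch.arange(12).view(3, 4)
    slices = _get_minimal_slice_set([0, 2], [2, 1], (3, 4))
    gathered = torch.cat([t[s].reshape(-1) for s in slices])
    assert torch.equal(gathered, torch.arange(2, 10))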
|
||||||
|
|
||||||
|
|
||||||
|
@torch.jit.ignore
|
||||||
|
def _chunk_slice(
|
||||||
|
t: torch.Tensor,
|
||||||
|
flat_start: int,
|
||||||
|
flat_end: int,
|
||||||
|
no_batch_dims: int,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Equivalent to
|
||||||
|
|
||||||
|
t.reshape((-1,) + t.shape[no_batch_dims:])[flat_start:flat_end]
|
||||||
|
|
||||||
|
but without the need for the initial reshape call, which can be
|
||||||
|
memory-intensive in certain situations. The only reshape operations
|
||||||
|
in this function are performed on sub-tensors that scale with
|
||||||
|
(flat_end - flat_start), the chunk size.
|
||||||
|
"""
|
||||||
|
|
||||||
|
batch_dims = t.shape[:no_batch_dims]
|
||||||
|
start_idx = list(_flat_idx_to_idx(flat_start, batch_dims))
|
||||||
|
# _get_minimal_slice_set is inclusive
|
||||||
|
end_idx = list(_flat_idx_to_idx(flat_end - 1, batch_dims))
|
||||||
|
|
||||||
|
# Get an ordered list of slices to perform
|
||||||
|
slices = _get_minimal_slice_set(
|
||||||
|
start_idx,
|
||||||
|
end_idx,
|
||||||
|
batch_dims,
|
||||||
|
)
|
||||||
|
|
||||||
|
sliced_tensors = [t[s] for s in slices]
|
||||||
|
|
||||||
|
return torch.cat(
|
||||||
|
[s.view((-1,) + t.shape[no_batch_dims:]) for s in sliced_tensors]
|
||||||
|
)
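# Illustrative sketch (hypothetical `_example_chunk_slice` helper, not part of the
# original file): _chunk_slice matches flattening the batch dimensions and slicing,
# without materializing the full reshape.
def _example_chunk_slice():
    t = torch.randn(3, 4, 5)  # batch dims (3, 4), feature dim 5
    reference = t.reshape(-1, 5)[2:9]
    chunked = _chunk_slice(t, flat_start=2, flat_end=9, no_batch_dims=2)
    assert torch.equal(chunked, reference)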
|
||||||
|
|
||||||
|
|
||||||
|
def chunk_layer(
|
||||||
|
layer: Callable,
|
||||||
|
inputs: Dict[str, Any],
|
||||||
|
chunk_size: int,
|
||||||
|
no_batch_dims: int,
|
||||||
|
low_mem: bool = False,
|
||||||
|
) -> Any:
|
||||||
|
"""
|
||||||
|
Implements the "chunking" procedure described in section 1.11.8.
|
||||||
|
|
||||||
|
Layer outputs and inputs are assumed to be simple "pytrees,"
|
||||||
|
consisting only of (arbitrarily nested) lists, tuples, and dicts with
|
||||||
|
torch.Tensor leaves.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
layer:
|
||||||
|
The layer to be applied chunk-wise
|
||||||
|
inputs:
|
||||||
|
A (non-nested) dictionary of keyworded inputs. All leaves must
|
||||||
|
be tensors and must share the same batch dimensions.
|
||||||
|
chunk_size:
|
||||||
|
The number of sub-batches per chunk. If multiple batch
|
||||||
|
dimensions are specified, a "sub-batch" is defined as a single
|
||||||
|
indexing of all batch dimensions simultaneously (s.t. the
|
||||||
|
number of sub-batches is the product of the batch dimensions).
|
||||||
|
no_batch_dims:
|
||||||
|
How many of the initial dimensions of each input tensor can
|
||||||
|
be considered batch dimensions.
|
||||||
|
low_mem:
|
||||||
|
Avoids flattening potentially large input tensors. Unnecessary
|
||||||
|
in most cases, and is ever so slightly slower than the default
|
||||||
|
setting.
|
||||||
|
Returns:
|
||||||
|
The reassembled output of the layer on the inputs.
|
||||||
|
"""
|
||||||
|
if not (len(inputs) > 0):
|
||||||
|
raise ValueError("Must provide at least one input")
|
||||||
|
|
||||||
|
initial_dims = [shape[:no_batch_dims] for shape in _fetch_dims(inputs)]
|
||||||
|
orig_batch_dims = tuple([max(s) for s in zip(*initial_dims)])
|
||||||
|
|
||||||
|
def _prep_inputs(t):
|
||||||
|
# TODO: make this more memory-efficient; expanding and flattening every input up front is wasteful
|
||||||
|
if(not low_mem):
|
||||||
|
if not sum(t.shape[:no_batch_dims]) == no_batch_dims:
|
||||||
|
t = t.expand(orig_batch_dims + t.shape[no_batch_dims:])
|
||||||
|
t = t.reshape(-1, *t.shape[no_batch_dims:])
|
||||||
|
else:
|
||||||
|
t = t.expand(orig_batch_dims + t.shape[no_batch_dims:])
|
||||||
|
return t
|
||||||
|
|
||||||
|
prepped_inputs = tensor_tree_map(_prep_inputs, inputs)
|
||||||
|
|
||||||
|
flat_batch_dim = 1
|
||||||
|
for d in orig_batch_dims:
|
||||||
|
flat_batch_dim *= d
|
||||||
|
|
||||||
|
no_chunks = flat_batch_dim // chunk_size + (
|
||||||
|
flat_batch_dim % chunk_size != 0
|
||||||
|
)
|
||||||
|
|
||||||
|
i = 0
|
||||||
|
out = None
|
||||||
|
for _ in range(no_chunks):
|
||||||
|
# Chunk the input
|
||||||
|
if(not low_mem):
|
||||||
|
select_chunk = (
|
||||||
|
lambda t: t[i : i + chunk_size] if t.shape[0] != 1 else t
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
select_chunk = (
|
||||||
|
partial(
|
||||||
|
_chunk_slice,
|
||||||
|
flat_start=i,
|
||||||
|
flat_end=min(flat_batch_dim, i + chunk_size),
|
||||||
|
no_batch_dims=len(orig_batch_dims)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
chunks = tensor_tree_map(select_chunk, prepped_inputs)
|
||||||
|
|
||||||
|
# Run the layer on the chunk
|
||||||
|
output_chunk = layer(**chunks)
|
||||||
|
|
||||||
|
# Allocate space for the output
|
||||||
|
if out is None:
|
||||||
|
allocate = lambda t: t.new_zeros((flat_batch_dim,) + t.shape[1:])
|
||||||
|
out = tensor_tree_map(allocate, output_chunk)
|
||||||
|
|
||||||
|
# Put the chunk in its pre-allocated space
|
||||||
|
out_type = type(output_chunk)
|
||||||
|
if out_type is dict:
|
||||||
|
def assign(d1, d2):
|
||||||
|
for k, v in d1.items():
|
||||||
|
if type(v) is dict:
|
||||||
|
assign(v, d2[k])
|
||||||
|
else:
|
||||||
|
v[i : i + chunk_size] = d2[k]
|
||||||
|
|
||||||
|
assign(out, output_chunk)
|
||||||
|
elif out_type is tuple:
|
||||||
|
for x1, x2 in zip(out, output_chunk):
|
||||||
|
x1[i : i + chunk_size] = x2
|
||||||
|
elif out_type is torch.Tensor:
|
||||||
|
out[i : i + chunk_size] = output_chunk
|
||||||
|
else:
|
||||||
|
raise ValueError("Not supported")
|
||||||
|
|
||||||
|
i += chunk_size
|
||||||
|
|
||||||
|
reshape = lambda t: t.view(orig_batch_dims + t.shape[1:])
|
||||||
|
out = tensor_tree_map(reshape, out)
|
||||||
|
|
||||||
|
return out
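# Illustrative sketch (hypothetical `_example_chunk_layer` helper, not part of the
# original file): chunk_layer flattens the batch dimensions, applies the layer to
# sub-batches of chunk_size, and stitches the outputs back together, so the result
# matches a single full call.
def _example_chunk_layer():
    layer = lambda x, y: x + y
    x = torch.randn(2, 6, 4)
    y = torch.randn(2, 6, 4)
    full = layer(x=x, y=y)
    chunked = chunk_layer(layer, {"x": x, "y": y}, chunk_size=5, no_batch_dims=2)
    assert torch.allclose(full, chunked)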
|
|
@ -0,0 +1,139 @@
|
||||||
|
# Copyright 2021 AlQuraishi Laboratory
|
||||||
|
# Copyright 2021 DeepMind Technologies Limited
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from functools import partialmethod, partial
|
||||||
|
import math
|
||||||
|
from typing import Optional, List
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from .primitives import Linear, LayerNorm, Attention
|
||||||
|
from .tensor_utils import (
|
||||||
|
chunk_layer,
|
||||||
|
permute_final_dims,
|
||||||
|
flatten_final_dims,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TriangleAttention(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self, c_in, c_hidden, no_heads, starting, inf=1e9
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
c_in:
|
||||||
|
Input channel dimension
|
||||||
|
c_hidden:
|
||||||
|
Overall hidden channel dimension (not per-head)
|
||||||
|
no_heads:
|
||||||
|
Number of attention heads
|
||||||
|
"""
|
||||||
|
super(TriangleAttention, self).__init__()
|
||||||
|
|
||||||
|
self.c_in = c_in
|
||||||
|
self.c_hidden = c_hidden
|
||||||
|
self.no_heads = no_heads
|
||||||
|
self.starting = starting
|
||||||
|
self.inf = inf
|
||||||
|
|
||||||
|
self.layer_norm = LayerNorm(self.c_in)
|
||||||
|
|
||||||
|
self.linear = Linear(c_in, self.no_heads, bias=False, init="normal")
|
||||||
|
|
||||||
|
self.mha = Attention(
|
||||||
|
self.c_in, self.c_in, self.c_in, self.c_hidden, self.no_heads
|
||||||
|
)
|
||||||
|
|
||||||
|
@torch.jit.ignore
|
||||||
|
def _chunk(self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
biases: List[torch.Tensor],
|
||||||
|
chunk_size: int,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
mha_inputs = {
|
||||||
|
"q_x": x,
|
||||||
|
"kv_x": x,
|
||||||
|
"biases": biases,
|
||||||
|
}
|
||||||
|
return chunk_layer(
|
||||||
|
partial(self.mha),
|
||||||
|
mha_inputs,
|
||||||
|
chunk_size=chunk_size,
|
||||||
|
no_batch_dims=len(x.shape[:-2]),
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
mask: Optional[torch.Tensor] = None,
|
||||||
|
chunk_size: Optional[int] = None
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
x:
|
||||||
|
[*, I, J, C_in] input tensor (e.g. the pair representation)
|
||||||
|
Returns:
|
||||||
|
[*, I, J, C_in] output tensor
|
||||||
|
"""
|
||||||
|
if mask is None:
|
||||||
|
# [*, I, J]
|
||||||
|
mask = x.new_ones(
|
||||||
|
x.shape[:-1],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Shape annotations assume self.starting. Else, I and J are flipped
|
||||||
|
if not self.starting:
|
||||||
|
x = x.transpose(-2, -3)
|
||||||
|
mask = mask.transpose(-1, -2)
|
||||||
|
|
||||||
|
# [*, I, J, C_in]
|
||||||
|
x = self.layer_norm(x)
|
||||||
|
|
||||||
|
# [*, I, 1, 1, J]
|
||||||
|
mask_bias = (self.inf * (mask - 1))[..., :, None, None, :]
|
||||||
|
|
||||||
|
# [*, H, I, J]
|
||||||
|
triangle_bias = permute_final_dims(self.linear(x), (2, 0, 1))
|
||||||
|
|
||||||
|
# [*, 1, H, I, J]
|
||||||
|
triangle_bias = triangle_bias.unsqueeze(-4)
|
||||||
|
|
||||||
|
biases = [mask_bias, triangle_bias]
|
||||||
|
|
||||||
|
if chunk_size is not None:
|
||||||
|
x = self._chunk(x, biases, chunk_size)
|
||||||
|
else:
|
||||||
|
x = self.mha(q_x=x, kv_x=x, biases=biases)
|
||||||
|
|
||||||
|
if not self.starting:
|
||||||
|
x = x.transpose(-2, -3)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class TriangleAttentionStartingNode(TriangleAttention):
|
||||||
|
"""
|
||||||
|
Implements Algorithm 13.
|
||||||
|
"""
|
||||||
|
|
||||||
|
__init__ = partialmethod(TriangleAttention.__init__, starting=True)
|
||||||
|
|
||||||
|
|
||||||
|
class TriangleAttentionEndingNode(TriangleAttention):
|
||||||
|
"""
|
||||||
|
Implements Algorithm 14.
|
||||||
|
"""
|
||||||
|
|
||||||
|
__init__ = partialmethod(TriangleAttention.__init__, starting=False)
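# Illustrative sketch (hypothetical `_example_triangle_attention` helper, not part of
# the original file): both variants preserve the pair representation's shape, and
# chunk_size only changes how rows are processed, not the result. Channel sizes are
# arbitrary example values.
def _example_triangle_attention():
    block = TriangleAttentionStartingNode(c_in=32, c_hidden=8, no_heads=4)
    z = torch.randn(1, 16, 16, 32)  # [*, I, J, C_in]
    full = block(z)
    chunked = block(z, chunk_size=4)
    assert full.shape == z.shape
    assert torch.allclose(full, chunked, atol=1e-4)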
|
|
@ -0,0 +1,127 @@
|
||||||
|
# Copyright 2021 AlQuraishi Laboratory
|
||||||
|
# Copyright 2021 DeepMind Technologies Limited
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from functools import partialmethod
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from .primitives import Linear, LayerNorm
|
||||||
|
from .tensor_utils import permute_final_dims
|
||||||
|
|
||||||
|
|
||||||
|
class TriangleMultiplicativeUpdate(nn.Module):
|
||||||
|
"""
|
||||||
|
Implements Algorithms 11 and 12.
|
||||||
|
"""
|
||||||
|
def __init__(self, c_z, c_hidden, _outgoing=True):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
c_z:
|
||||||
|
Input channel dimension
|
||||||
|
c_hidden:
|
||||||
|
Hidden channel dimension
|
||||||
|
"""
|
||||||
|
super(TriangleMultiplicativeUpdate, self).__init__()
|
||||||
|
self.c_z = c_z
|
||||||
|
self.c_hidden = c_hidden
|
||||||
|
self._outgoing = _outgoing
|
||||||
|
|
||||||
|
self.linear_a_p = Linear(self.c_z, self.c_hidden)
|
||||||
|
self.linear_a_g = Linear(self.c_z, self.c_hidden, init="gating")
|
||||||
|
self.linear_b_p = Linear(self.c_z, self.c_hidden)
|
||||||
|
self.linear_b_g = Linear(self.c_z, self.c_hidden, init="gating")
|
||||||
|
self.linear_g = Linear(self.c_z, self.c_z, init="gating")
|
||||||
|
self.linear_z = Linear(self.c_hidden, self.c_z, init="final")
|
||||||
|
|
||||||
|
self.layer_norm_in = LayerNorm(self.c_z)
|
||||||
|
self.layer_norm_out = LayerNorm(self.c_hidden)
|
||||||
|
|
||||||
|
self.sigmoid = nn.Sigmoid()
|
||||||
|
|
||||||
|
def _combine_projections(self,
|
||||||
|
a: torch.Tensor,
|
||||||
|
b: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
raise NotImplementedError("This method needs to be overridden")
|
||||||
|
|
||||||
|
def forward(self,
|
||||||
|
z: torch.Tensor,
|
||||||
|
mask: Optional[torch.Tensor] = None
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
z:
|
||||||
|
[*, N_res, N_res, C_z] input tensor
|
||||||
|
mask:
|
||||||
|
[*, N_res, N_res] input mask
|
||||||
|
Returns:
|
||||||
|
[*, N_res, N_res, C_z] output tensor
|
||||||
|
"""
|
||||||
|
if mask is None:
|
||||||
|
mask = z.new_ones(z.shape[:-1])
|
||||||
|
|
||||||
|
mask = mask.unsqueeze(-1)
|
||||||
|
|
||||||
|
z = self.layer_norm_in(z)
|
||||||
|
a = self.linear_a_p(z) * self.sigmoid(self.linear_a_g(z))
|
||||||
|
a = a * mask
|
||||||
|
b = self.linear_b_p(z) * self.sigmoid(self.linear_b_g(z))
|
||||||
|
b = b * mask
|
||||||
|
x = self._combine_projections(a, b)
|
||||||
|
x = self.layer_norm_out(x)
|
||||||
|
x = self.linear_z(x)
|
||||||
|
g = self.sigmoid(self.linear_g(z))
|
||||||
|
z = x * g
|
||||||
|
|
||||||
|
return z
|
||||||
|
|
||||||
|
|
||||||
|
class TriangleMultiplicationOutgoing(TriangleMultiplicativeUpdate):
|
||||||
|
"""
|
||||||
|
Implements Algorithm 11.
|
||||||
|
"""
|
||||||
|
def _combine_projections(self,
|
||||||
|
a: torch.Tensor, # [*, N_i, N_k, C]
|
||||||
|
b: torch.Tensor, # [*, N_j, N_k, C]
|
||||||
|
):
|
||||||
|
# [*, C, N_i, N_j]
|
||||||
|
p = torch.matmul(
|
||||||
|
permute_final_dims(a, (2, 0, 1)),
|
||||||
|
permute_final_dims(b, (2, 1, 0)),
|
||||||
|
)
|
||||||
|
|
||||||
|
# [*, N_i, N_j, C]
|
||||||
|
return permute_final_dims(p, (1, 2, 0))
|
||||||
|
|
||||||
|
|
||||||
|
class TriangleMultiplicationIncoming(TriangleMultiplicativeUpdate):
|
||||||
|
"""
|
||||||
|
Implements Algorithm 12.
|
||||||
|
"""
|
||||||
|
def _combine_projections(self,
|
||||||
|
a: torch.Tensor, # [*, N_k, N_i, C]
|
||||||
|
b: torch.Tensor, # [*, N_k, N_j, C]
|
||||||
|
):
|
||||||
|
# [*, C, N_i, N_j]
|
||||||
|
p = torch.matmul(
|
||||||
|
permute_final_dims(a, (2, 1, 0)),
|
||||||
|
permute_final_dims(b, (2, 0, 1)),
|
||||||
|
)
|
||||||
|
|
||||||
|
# [*, N_i, N_j, C]
|
||||||
|
return permute_final_dims(p, (1, 2, 0))
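# Illustrative sketch (hypothetical `_example_outgoing_combination` helper, not part of
# the original file): the "outgoing" combination sums the element-wise product over the
# shared third index k, i.e. an einsum over a[..., i, k, :] * b[..., j, k, :]; the
# permute/matmul implementation above is a batched way of computing the same sum.
def _example_outgoing_combination():
    module = TriangleMultiplicationOutgoing(c_z=8, c_hidden=3)
    a = torch.randn(2, 5, 7, 3)  # [*, N_i, N_k, C]
    b = torch.randn(2, 5, 7, 3)  # [*, N_j, N_k, C]
    combined = module._combine_projections(a, b)
    reference = torch.einsum("...ikc,...jkc->...ijc", a, b)
    assert torch.allclose(combined, reference, atol=1e-5)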
|
||||||
|
|
|
@ -0,0 +1,113 @@
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
import torch.fx
|
||||||
|
import torch.multiprocessing as mp
|
||||||
|
|
||||||
|
import colossalai
|
||||||
|
from colossalai.core import global_context as gpc
|
||||||
|
from colossalai.fx import ColoTracer
|
||||||
|
from colossalai.fx._compatibility import is_compatible_with_meta
|
||||||
|
from colossalai.fx.codegen.activation_checkpoint_codegen import CODEGEN_AVAILABLE
|
||||||
|
from colossalai.fx.graph_module import ColoGraphModule
|
||||||
|
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
|
||||||
|
from colossalai.utils import free_port
|
||||||
|
from tests.test_autochunk.evoformer.evoformer import evoformer_base
|
||||||
|
|
||||||
|
if CODEGEN_AVAILABLE and is_compatible_with_meta():
|
||||||
|
from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
|
||||||
|
from colossalai.fx.profiler import MetaTensor
|
||||||
|
|
||||||
|
|
||||||
|
def _test_fwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair):
|
||||||
|
# for memory test
|
||||||
|
# torch.cuda.reset_peak_memory_stats()
|
||||||
|
# now_mem = torch.cuda.memory_allocated() / 1024**2
|
||||||
|
# with torch.no_grad():
|
||||||
|
# node1 = node.clone()
|
||||||
|
# pair1 = pair.clone()
|
||||||
|
# gm(node1, pair1)
|
||||||
|
# new_now_mem = torch.cuda.memory_allocated() / 1024**2
|
||||||
|
# new_max_mem = torch.cuda.max_memory_allocated() / 1024**2
|
||||||
|
# print(
|
||||||
|
# "autochunk now mem:%.2f max mem:%.2f"
|
||||||
|
# % (new_now_mem - now_mem, new_max_mem - now_mem)
|
||||||
|
# )
|
||||||
|
|
||||||
|
# test forward
|
||||||
|
with torch.no_grad():
|
||||||
|
non_fx_out = model(node, pair)
|
||||||
|
fx_out = gm(node, pair)
|
||||||
|
|
||||||
|
assert torch.allclose(non_fx_out[0], fx_out[0],
|
||||||
|
atol=1e-4), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(
|
||||||
|
torch.abs(non_fx_out[0] - fx_out[0]))
|
||||||
|
assert torch.allclose(non_fx_out[1], fx_out[1],
|
||||||
|
atol=1e-4), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(
|
||||||
|
torch.abs(non_fx_out[1] - fx_out[1]))
|
||||||
|
|
||||||
|
|
||||||
|
def _test_autochunk_codegen(rank, msa_len, pair_len, max_memory):
|
||||||
|
# launch colossalai
|
||||||
|
colossalai.launch(
|
||||||
|
config={},
|
||||||
|
rank=rank,
|
||||||
|
world_size=1,
|
||||||
|
host="localhost",
|
||||||
|
port=free_port(),
|
||||||
|
backend="nccl",
|
||||||
|
)
|
||||||
|
|
||||||
|
# build model and input
|
||||||
|
model = evoformer_base().cuda()
|
||||||
|
node = torch.randn(1, msa_len, pair_len, 256).cuda()
|
||||||
|
pair = torch.randn(1, pair_len, pair_len, 128).cuda()
|
||||||
|
|
||||||
|
# trace the module and replace codegen
|
||||||
|
graph = ColoTracer().trace(
|
||||||
|
model,
|
||||||
|
meta_args={
|
||||||
|
"node": node.to(torch.device("meta")),
|
||||||
|
"pair": pair.to(torch.device("meta")),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
gm_prop = torch.fx.symbolic_trace(model) # must use symbolic_trace
|
||||||
|
interp = MetaInfoProp(gm_prop)
|
||||||
|
interp.propagate(MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0"))
|
||||||
|
|
||||||
|
# now run it twice to get meta info in graph module, not necessary
|
||||||
|
gm = torch.fx.GraphModule(model, graph)
|
||||||
|
interp = MetaInfoProp(gm)
|
||||||
|
interp.propagate(MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0"))
|
||||||
|
|
||||||
|
codegen = AutoChunkCodeGen(gm_prop, max_memory=max_memory)
|
||||||
|
graph.set_codegen(codegen)
|
||||||
|
gm = ColoGraphModule(model, graph)
|
||||||
|
gm.recompile()
|
||||||
|
|
||||||
|
# assert we have inserted chunk
|
||||||
|
code = graph.python_code("self").src
|
||||||
|
assert "chunk_size" in code
|
||||||
|
# print(code)
|
||||||
|
|
||||||
|
_test_fwd(model, gm, node, pair)
|
||||||
|
gpc.destroy()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(not (CODEGEN_AVAILABLE and is_compatible_with_meta()), reason='torch version is lower than 1.12.0')
|
||||||
|
@pytest.mark.parametrize("max_memory", [None, 20, 25, 30])
|
||||||
|
@pytest.mark.parametrize("msa_len", [32])
|
||||||
|
@pytest.mark.parametrize("pair_len", [64])
|
||||||
|
def test_autochunk_codegen(msa_len, pair_len, max_memory):
|
||||||
|
run_func = partial(
|
||||||
|
_test_autochunk_codegen,
|
||||||
|
msa_len=msa_len,
|
||||||
|
pair_len=pair_len,
|
||||||
|
max_memory=max_memory,
|
||||||
|
)
|
||||||
|
mp.spawn(run_func, nprocs=1)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
_test_autochunk_codegen(0, 32, 64, 25)
|
|
@ -0,0 +1,102 @@
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
import torch.fx
|
||||||
|
import torch.multiprocessing as mp
|
||||||
|
|
||||||
|
import colossalai
|
||||||
|
from colossalai.core import global_context as gpc
|
||||||
|
from colossalai.fx._compatibility import is_compatible_with_meta
|
||||||
|
from colossalai.fx.codegen.activation_checkpoint_codegen import CODEGEN_AVAILABLE
|
||||||
|
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
|
||||||
|
from colossalai.utils import free_port
|
||||||
|
from tests.test_autochunk.evoformer.evoformer import evoformer_base
|
||||||
|
|
||||||
|
if CODEGEN_AVAILABLE and is_compatible_with_meta():
|
||||||
|
from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
|
||||||
|
from colossalai.fx.profiler import MetaTensor
|
||||||
|
|
||||||
|
|
||||||
|
def assert_chunk_infos(chunk_infos, max_memory, msa_len, pair_len):
|
||||||
|
found_regions = [i["region"] for i in chunk_infos]
|
||||||
|
|
||||||
|
if msa_len == 32 and pair_len == 64:
|
||||||
|
if max_memory is None:
|
||||||
|
target_regions = [(142, 154), (366, 373), (233, 283), (301, 351), (127, 134), (204, 228), (167, 191),
|
||||||
|
(161, 166), (198, 203), (6, 69)]
|
||||||
|
elif max_memory == 20:
|
||||||
|
target_regions = [(142, 154), (369, 373), (233, 269), (301, 351)]
|
||||||
|
elif max_memory == 25:
|
||||||
|
target_regions = [(144, 154), (369, 370)]
|
||||||
|
elif max_memory == 30:
|
||||||
|
target_regions = [(144, 154)]
|
||||||
|
else:
|
||||||
|
raise NotImplementedError()
|
||||||
|
else:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
assert len(found_regions) == len(
|
||||||
|
target_regions), "len of found regions %s doesn't equal len of target regions %s" % (
|
||||||
|
str(found_regions),
|
||||||
|
str(target_regions),
|
||||||
|
)
|
||||||
|
for region in target_regions:
|
||||||
|
assert (region in found_regions), "region:%s not in found regions for msa:%d, pair:%d, maxmem:%d" % (
|
||||||
|
str(region),
|
||||||
|
msa_len,
|
||||||
|
pair_len,
|
||||||
|
max_memory,
|
||||||
|
)
|
||||||
|
for region in found_regions:
|
||||||
|
assert (region in target_regions), "region:%s should not be found for msa:%d, pair:%d, maxmem:%d" % (
|
||||||
|
str(region),
|
||||||
|
msa_len,
|
||||||
|
pair_len,
|
||||||
|
max_memory,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _test_autochunk_search(rank, msa_len, pair_len, max_memory):
|
||||||
|
# launch colossalai
|
||||||
|
colossalai.launch(
|
||||||
|
config={},
|
||||||
|
rank=rank,
|
||||||
|
world_size=1,
|
||||||
|
host="localhost",
|
||||||
|
port=free_port(),
|
||||||
|
backend="nccl",
|
||||||
|
)
|
||||||
|
|
||||||
|
# build model and input
|
||||||
|
model = evoformer_base().cuda()
|
||||||
|
node = torch.randn(1, msa_len, pair_len, 256).cuda()
|
||||||
|
pair = torch.randn(1, pair_len, pair_len, 128).cuda()
|
||||||
|
|
||||||
|
gm_prop = torch.fx.symbolic_trace(model) # must use symbolic_trace
|
||||||
|
interp = MetaInfoProp(gm_prop)
|
||||||
|
interp.propagate(MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0"))
|
||||||
|
|
||||||
|
codegen = AutoChunkCodeGen(gm_prop, max_memory=max_memory)
|
||||||
|
chunk_infos = codegen.chunk_infos
|
||||||
|
assert_chunk_infos(chunk_infos, max_memory, msa_len, pair_len)
|
||||||
|
|
||||||
|
gpc.destroy()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(not (CODEGEN_AVAILABLE and is_compatible_with_meta()), reason="torch version is lower than 1.12.0")
|
||||||
|
@pytest.mark.parametrize("max_memory", [None, 20, 25, 30])
|
||||||
|
@pytest.mark.parametrize("msa_len", [32])
|
||||||
|
@pytest.mark.parametrize("pair_len", [64])
|
||||||
|
def test_autochunk_search(msa_len, pair_len, max_memory):
|
||||||
|
run_func = partial(
|
||||||
|
_test_autochunk_search,
|
||||||
|
msa_len=msa_len,
|
||||||
|
pair_len=pair_len,
|
||||||
|
max_memory=max_memory,
|
||||||
|
)
|
||||||
|
mp.spawn(run_func, nprocs=1)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
_test_autochunk_search(0, 32, 64, 20)
|