mirror of https://github.com/hpcaitech/ColossalAI
[autochunk] support evoformer tracer (#2485)
support full evoformer tracer, which is a main module of alphafold. previously we just support a simplifed version of it. 1. support some evoformer's op in fx 2. support evoformer test 3. add repos for test codepull/2488/head
parent
67e1912b59
commit
4953b4ace1
|
@ -48,9 +48,7 @@ def _gen_chunk_slice_dim(chunk_dim: int, chunk_indice_name: str, shape: List) ->
|
|||
return new_shape
|
||||
|
||||
|
||||
def _gen_loop_start(
|
||||
chunk_input: List[Node], chunk_output: Node, chunk_ouput_dim: int, chunk_size=2
|
||||
) -> str:
|
||||
def _gen_loop_start(chunk_input: List[Node], chunk_output: Node, chunk_ouput_dim: int, chunk_size=2) -> str:
|
||||
"""
|
||||
Generate chunk loop start
|
||||
|
||||
|
@ -72,9 +70,8 @@ def _gen_loop_start(
|
|||
out_shape = get_node_shape(chunk_output)
|
||||
out_str = str(list(out_shape))
|
||||
context = (
|
||||
"chunk_result = torch.empty(%s, dtype=%s.dtype, device=%s.device); chunk_size = %d\nfor chunk_idx in range"
|
||||
% (out_str, input_node.name, input_node.name, chunk_size)
|
||||
)
|
||||
"chunk_result = torch.empty(%s, dtype=%s.dtype, device=%s.device); chunk_size = %d\nfor chunk_idx in range" %
|
||||
(out_str, input_node.name, input_node.name, chunk_size))
|
||||
context += "(0, %d, chunk_size):\n" % (out_shape[chunk_ouput_dim])
|
||||
return context
|
||||
|
||||
|
@ -105,26 +102,17 @@ def _gen_loop_end(
|
|||
chunk_outputs_name = chunk_outputs.name
|
||||
chunk_outputs_idx = find_idx_by_name(chunk_outputs_name, node_list)
|
||||
chunk_output_shape = chunk_outputs.meta["tensor_meta"].shape
|
||||
chunk_slice = _gen_chunk_slice_dim(
|
||||
chunk_outputs_dim, "chunk_idx", chunk_output_shape
|
||||
)
|
||||
chunk_slice = _gen_chunk_slice_dim(chunk_outputs_dim, "chunk_idx", chunk_output_shape)
|
||||
context = " chunk_result%s = %s; %s = None\n" % (
|
||||
chunk_slice,
|
||||
chunk_outputs_name,
|
||||
chunk_outputs_name,
|
||||
)
|
||||
context += (
|
||||
chunk_outputs_name + " = chunk_result; chunk_result = None; chunk_size = None"
|
||||
)
|
||||
context += (chunk_outputs_name + " = chunk_result; chunk_result = None; chunk_size = None")
|
||||
|
||||
# determine if its the last use for chunk input
|
||||
for chunk_input in chunk_inputs + chunk_non_compute_inputs:
|
||||
if all(
|
||||
[
|
||||
find_idx_by_name(user.name, node_list) <= chunk_outputs_idx
|
||||
for user in chunk_input.users.keys()
|
||||
]
|
||||
):
|
||||
if all([find_idx_by_name(user.name, node_list) <= chunk_outputs_idx for user in chunk_input.users.keys()]):
|
||||
context += "; %s = None" % chunk_input.name
|
||||
|
||||
context += "\n"
|
||||
|
@ -171,17 +159,10 @@ def _replace_ones_like(
|
|||
chunk_dim = chunk_infos[region_idx]["node_chunk_dim"][meta_node]["chunk_dim"]
|
||||
if get_node_shape(meta_node)[chunk_dim] != 1:
|
||||
source_node = meta_node.args[0].args[0]
|
||||
if (
|
||||
source_node not in chunk_infos[region_idx]["node_chunk_dim"]
|
||||
or chunk_infos[region_idx]["node_chunk_dim"][source_node]["chunk_dim"]
|
||||
is None
|
||||
):
|
||||
chunk_slice = _gen_chunk_slice_dim(
|
||||
chunk_dim, "chunk_idx", get_node_shape(node)
|
||||
)
|
||||
body[-1] = _replace_name(
|
||||
body[-1], node.args[0].name, node.args[0].name + chunk_slice
|
||||
)
|
||||
if (source_node not in chunk_infos[region_idx]["node_chunk_dim"]
|
||||
or chunk_infos[region_idx]["node_chunk_dim"][source_node]["chunk_dim"] is None):
|
||||
chunk_slice = _gen_chunk_slice_dim(chunk_dim, "chunk_idx", get_node_shape(node))
|
||||
body[-1] = _replace_name(body[-1], node.args[0].name, node.args[0].name + chunk_slice)
|
||||
return body
|
||||
|
||||
|
||||
|
@ -198,12 +179,8 @@ def _replace_input_node(
|
|||
for input_node_idx, input_node in enumerate(chunk_inputs[region_idx]):
|
||||
for idx, dim in chunk_inputs_dim[region_idx][input_node_idx].items():
|
||||
if idx == node_idx:
|
||||
chunk_slice = _gen_chunk_slice_dim(
|
||||
dim[0], "chunk_idx", get_node_shape(input_node)
|
||||
)
|
||||
body[-1] = _replace_name(
|
||||
body[-1], input_node.name, input_node.name + chunk_slice
|
||||
)
|
||||
chunk_slice = _gen_chunk_slice_dim(dim[0], "chunk_idx", get_node_shape(input_node))
|
||||
body[-1] = _replace_name(body[-1], input_node.name, input_node.name + chunk_slice)
|
||||
return body
|
||||
|
||||
|
||||
|
@ -236,14 +213,10 @@ def emit_code_with_chunk(
|
|||
chunk_ends = [i["region"][1] for i in chunk_infos]
|
||||
|
||||
# chunk inputs
|
||||
chunk_inputs = [i["inputs"] for i in chunk_infos] # input with chunk
|
||||
chunk_inputs_non_chunk = [
|
||||
i["inputs_non_chunk"] for i in chunk_infos
|
||||
] # input without chunk
|
||||
chunk_inputs_dim = [i["inputs_dim"] for i in chunk_infos] # input chunk dim
|
||||
chunk_inputs_names = [j.name for i in chunk_inputs for j in i] + [
|
||||
j.name for i in chunk_inputs_non_chunk for j in i
|
||||
]
|
||||
chunk_inputs = [i["inputs"] for i in chunk_infos] # input with chunk
|
||||
chunk_inputs_non_chunk = [i["inputs_non_chunk"] for i in chunk_infos] # input without chunk
|
||||
chunk_inputs_dim = [i["inputs_dim"] for i in chunk_infos] # input chunk dim
|
||||
chunk_inputs_names = [j.name for i in chunk_inputs for j in i] + [j.name for i in chunk_inputs_non_chunk for j in i]
|
||||
|
||||
# chunk outputs
|
||||
chunk_outputs = [i["outputs"][0] for i in chunk_infos]
|
||||
|
@ -267,23 +240,16 @@ def emit_code_with_chunk(
|
|||
chunk_outputs[region_idx],
|
||||
chunk_outputs_dim[region_idx],
|
||||
chunk_infos[region_idx]["chunk_size"],
|
||||
)
|
||||
)
|
||||
))
|
||||
|
||||
if within_chunk_region:
|
||||
emit_node_func(node, body)
|
||||
# replace input var with chunk var
|
||||
body = _replace_input_node(
|
||||
chunk_inputs, region_idx, chunk_inputs_dim, node_idx, body
|
||||
)
|
||||
body = _replace_input_node(chunk_inputs, region_idx, chunk_inputs_dim, node_idx, body)
|
||||
# ones like
|
||||
body = _replace_ones_like(
|
||||
search_chunk, chunk_infos, region_idx, node_idx, node, body
|
||||
)
|
||||
body = _replace_ones_like(search_chunk, chunk_infos, region_idx, node_idx, node, body)
|
||||
# reassgin reshape size
|
||||
body[-1] = _replace_reshape_size(
|
||||
body[-1], node.name, chunk_infos[region_idx]["reshape_size"]
|
||||
)
|
||||
body[-1] = _replace_reshape_size(body[-1], node.name, chunk_infos[region_idx]["reshape_size"])
|
||||
body[-1] = " " + body[-1]
|
||||
delete_unused_value_func(node, body, chunk_inputs_names)
|
||||
else:
|
||||
|
@ -300,8 +266,7 @@ def emit_code_with_chunk(
|
|||
chunk_outputs[region_idx],
|
||||
chunk_outputs_dim[region_idx],
|
||||
node_list,
|
||||
)
|
||||
)
|
||||
))
|
||||
within_chunk_region = False
|
||||
|
||||
node_idx += 1
|
||||
|
@ -310,18 +275,14 @@ def emit_code_with_chunk(
|
|||
if CODEGEN_AVAILABLE:
|
||||
|
||||
class AutoChunkCodeGen(CodeGen):
|
||||
|
||||
def __init__(self, meta_graph, max_memory=None, print_mem=False):
|
||||
super().__init__()
|
||||
self.meta_graph = meta_graph
|
||||
self.max_memory = max_memory
|
||||
self.meta_node = list(meta_graph.graph.nodes)
|
||||
# find the chunk regions
|
||||
self.search_chunk = SearchChunk(meta_graph, max_memory, print_mem)
|
||||
self.chunk_infos = self.search_chunk.search_region()
|
||||
|
||||
def _gen_python_code(
|
||||
self, nodes, root_module: str, namespace: _Namespace
|
||||
) -> PythonCode:
|
||||
def _gen_python_code(self, nodes, root_module: str, namespace: _Namespace) -> PythonCode:
|
||||
free_vars: List[str] = []
|
||||
body: List[str] = []
|
||||
globals_: Dict[str, Any] = {}
|
||||
|
@ -338,9 +299,7 @@ if CODEGEN_AVAILABLE:
|
|||
|
||||
Returns: the global name that should be used to reference 'obj' in generated source.
|
||||
"""
|
||||
if (
|
||||
_is_from_torch(obj) and obj != torch.device
|
||||
): # to support registering torch.device
|
||||
if (_is_from_torch(obj) and obj != torch.device): # to support registering torch.device
|
||||
# HACK: workaround for how torch custom ops are registered. We
|
||||
# can't import them like normal modules so they must retain their
|
||||
# fully qualified name.
|
||||
|
@ -356,9 +315,7 @@ if CODEGEN_AVAILABLE:
|
|||
return global_name
|
||||
|
||||
# set _custom_builtins here so that we needn't import colossalai in forward
|
||||
_custom_builtins["colossalai"] = _CustomBuiltin(
|
||||
"import colossalai", colossalai
|
||||
)
|
||||
_custom_builtins["colossalai"] = _CustomBuiltin("import colossalai", colossalai)
|
||||
|
||||
# Pre-fill the globals table with registered builtins.
|
||||
for name, (_, obj) in _custom_builtins.items():
|
||||
|
@ -394,9 +351,8 @@ if CODEGEN_AVAILABLE:
|
|||
# Common case: this is a regular module name like 'foo.bar.baz'
|
||||
return add_global(typename, o)
|
||||
|
||||
def _format_args(
|
||||
args: Tuple[Argument, ...], kwargs: Dict[str, Argument]
|
||||
) -> str:
|
||||
def _format_args(args: Tuple[Argument, ...], kwargs: Dict[str, Argument]) -> str:
|
||||
|
||||
def _get_repr(arg):
|
||||
# Handle NamedTuples (if it has `_fields`) via add_global.
|
||||
if isinstance(arg, tuple) and hasattr(arg, "_fields"):
|
||||
|
@ -444,26 +400,18 @@ if CODEGEN_AVAILABLE:
|
|||
nodes_to_delete = user_to_last_uses.get(user, [])
|
||||
nodes_to_delete = [i for i in nodes_to_delete if i.name not in to_keep]
|
||||
if len(nodes_to_delete):
|
||||
to_delete_str = " = ".join(
|
||||
[repr(n) for n in nodes_to_delete] + ["None"]
|
||||
)
|
||||
to_delete_str = " = ".join([repr(n) for n in nodes_to_delete] + ["None"])
|
||||
body.append(f"; {to_delete_str}\n")
|
||||
else:
|
||||
body.append("\n")
|
||||
|
||||
# NOTE: we add a variable to distinguish body and ckpt_func
|
||||
def emit_node(node: Node, body):
|
||||
maybe_type_annotation = (
|
||||
"" if node.type is None else f" : {type_repr(node.type)}"
|
||||
)
|
||||
maybe_type_annotation = ("" if node.type is None else f" : {type_repr(node.type)}")
|
||||
if node.op == "placeholder":
|
||||
assert isinstance(node.target, str)
|
||||
maybe_default_arg = (
|
||||
"" if not node.args else f" = {repr(node.args[0])}"
|
||||
)
|
||||
free_vars.append(
|
||||
f"{node.target}{maybe_type_annotation}{maybe_default_arg}"
|
||||
)
|
||||
maybe_default_arg = ("" if not node.args else f" = {repr(node.args[0])}")
|
||||
free_vars.append(f"{node.target}{maybe_type_annotation}{maybe_default_arg}")
|
||||
raw_name = node.target.replace("*", "")
|
||||
if raw_name != repr(node):
|
||||
body.append(f"{repr(node)} = {raw_name}\n")
|
||||
|
@ -472,68 +420,46 @@ if CODEGEN_AVAILABLE:
|
|||
assert isinstance(node.target, str)
|
||||
body.append(
|
||||
f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.target)}"
|
||||
f"({_format_args(node.args[1:], node.kwargs)})"
|
||||
)
|
||||
f"({_format_args(node.args[1:], node.kwargs)})")
|
||||
return
|
||||
elif node.op == "call_function":
|
||||
assert callable(node.target)
|
||||
# pretty print operators
|
||||
if (
|
||||
node.target.__module__ == "_operator"
|
||||
and node.target.__name__ in magic_methods
|
||||
):
|
||||
if (node.target.__module__ == "_operator" and node.target.__name__ in magic_methods):
|
||||
assert isinstance(node.args, tuple)
|
||||
body.append(
|
||||
f"{repr(node)}{maybe_type_annotation} = "
|
||||
f"{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}"
|
||||
)
|
||||
body.append(f"{repr(node)}{maybe_type_annotation} = "
|
||||
f"{magic_methods[node.target.__name__].format(*(repr(a) for a in node.args))}")
|
||||
return
|
||||
|
||||
# pretty print inplace operators; required for jit.script to work properly
|
||||
# not currently supported in normal FX graphs, but generated by torchdynamo
|
||||
if (
|
||||
node.target.__module__ == "_operator"
|
||||
and node.target.__name__ in inplace_methods
|
||||
):
|
||||
body.append(
|
||||
f"{inplace_methods[node.target.__name__].format(*(repr(a) for a in node.args))}; "
|
||||
f"{repr(node)}{maybe_type_annotation} = {repr(node.args[0])}"
|
||||
)
|
||||
if (node.target.__module__ == "_operator" and node.target.__name__ in inplace_methods):
|
||||
body.append(f"{inplace_methods[node.target.__name__].format(*(repr(a) for a in node.args))}; "
|
||||
f"{repr(node)}{maybe_type_annotation} = {repr(node.args[0])}")
|
||||
return
|
||||
|
||||
qualified_name = _get_qualified_name(node.target)
|
||||
global_name = add_global(qualified_name, node.target)
|
||||
# special case for getattr: node.args could be 2-argument or 3-argument
|
||||
# 2-argument: attribute access; 3-argument: fall through to attrib function call with default value
|
||||
if (
|
||||
global_name == "getattr"
|
||||
and isinstance(node.args, tuple)
|
||||
and isinstance(node.args[1], str)
|
||||
and node.args[1].isidentifier()
|
||||
and len(node.args) == 2
|
||||
):
|
||||
if (global_name == "getattr" and isinstance(node.args, tuple) and isinstance(node.args[1], str)
|
||||
and node.args[1].isidentifier() and len(node.args) == 2):
|
||||
body.append(
|
||||
f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}"
|
||||
)
|
||||
f"{repr(node)}{maybe_type_annotation} = {_format_target(repr(node.args[0]), node.args[1])}")
|
||||
return
|
||||
body.append(
|
||||
f"{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})"
|
||||
)
|
||||
f"{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})")
|
||||
if node.meta.get("is_wrapped", False):
|
||||
wrapped_fns.setdefault(global_name)
|
||||
return
|
||||
elif node.op == "call_module":
|
||||
assert isinstance(node.target, str)
|
||||
body.append(
|
||||
f"{repr(node)}{maybe_type_annotation} = "
|
||||
f"{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})"
|
||||
)
|
||||
body.append(f"{repr(node)}{maybe_type_annotation} = "
|
||||
f"{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})")
|
||||
return
|
||||
elif node.op == "get_attr":
|
||||
assert isinstance(node.target, str)
|
||||
body.append(
|
||||
f"{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}"
|
||||
)
|
||||
body.append(f"{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}")
|
||||
return
|
||||
elif node.op == "output":
|
||||
if node.type is not None:
|
||||
|
@ -564,9 +490,7 @@ if CODEGEN_AVAILABLE:
|
|||
|
||||
if len(wrapped_fns) > 0:
|
||||
wrap_name = add_global("wrap", torch.fx.wrap)
|
||||
wrap_stmts = "\n".join(
|
||||
[f'{wrap_name}("{name}")' for name in wrapped_fns]
|
||||
)
|
||||
wrap_stmts = "\n".join([f'{wrap_name}("{name}")' for name in wrapped_fns])
|
||||
else:
|
||||
wrap_stmts = ""
|
||||
|
||||
|
|
|
@ -10,6 +10,7 @@ from .utils import (
|
|||
|
||||
|
||||
class TraceFlow(object):
|
||||
|
||||
def __init__(self, trace_indice: TraceIndice) -> None:
|
||||
self.trace_indice = trace_indice
|
||||
|
||||
|
@ -28,9 +29,7 @@ class TraceFlow(object):
|
|||
start_node_idx = find_idx_by_name(start_node.name, self.trace_indice.node_list)
|
||||
end_node_trace = self.trace_indice._find_trace_from_node(end_node)
|
||||
end_node_trace_source = end_node_trace["source"][end_dim]
|
||||
sorted_source = sorted(
|
||||
end_node_trace_source.items(), key=lambda d: d[0], reverse=True
|
||||
)
|
||||
sorted_source = sorted(end_node_trace_source.items(), key=lambda d: d[0], reverse=True)
|
||||
for node_idx, node_dim in sorted_source:
|
||||
if node_idx == start_node_idx and start_dim in node_dim:
|
||||
return True
|
||||
|
@ -70,10 +69,8 @@ class TraceFlow(object):
|
|||
input_node_idx = find_idx_by_name(input_node.name, self.trace_indice.node_list)
|
||||
node_trace_source = self.trace_indice._find_source_trace_from_node(node)
|
||||
for node_dim in range(len(get_node_shape(node))):
|
||||
if (
|
||||
input_node_idx in node_trace_source[node_dim]
|
||||
and input_dim[0] in node_trace_source[node_dim][input_node_idx]
|
||||
):
|
||||
if (input_node_idx in node_trace_source[node_dim]
|
||||
and input_dim[0] in node_trace_source[node_dim][input_node_idx]):
|
||||
return node_dim
|
||||
return None
|
||||
|
||||
|
@ -81,15 +78,11 @@ class TraceFlow(object):
|
|||
input_dim_after_node = {}
|
||||
for input_node_idx, input_node in enumerate(chunk_infos["inputs"]):
|
||||
for k, v in chunk_infos["inputs_dim"][input_node_idx].items():
|
||||
inherit_dim = self._find_inherit_dim(
|
||||
input_node, v, self.trace_indice.node_list[k]
|
||||
)
|
||||
inherit_dim = self._find_inherit_dim(input_node, v, self.trace_indice.node_list[k])
|
||||
if inherit_dim:
|
||||
input_dim_after_node[k] = inherit_dim
|
||||
|
||||
for node in self.trace_indice.node_list[
|
||||
chunk_infos["region"][0] : chunk_infos["region"][1] + 1
|
||||
]:
|
||||
for node in self.trace_indice.node_list[chunk_infos["region"][0]:chunk_infos["region"][1] + 1]:
|
||||
if is_non_compute_node_except_placeholder(node):
|
||||
continue
|
||||
count = 0
|
||||
|
@ -159,9 +152,7 @@ class TraceFlow(object):
|
|||
if arg_node in all_node_info:
|
||||
if all_node_info[arg_node]["chunk_dim"] != arg_dim:
|
||||
return False
|
||||
all_node_info[arg_node]["fix_dim"] = list(
|
||||
set(all_node_info[arg_node]["fix_dim"] + arg_fix_dim)
|
||||
)
|
||||
all_node_info[arg_node]["fix_dim"] = list(set(all_node_info[arg_node]["fix_dim"] + arg_fix_dim))
|
||||
# else add it to list
|
||||
else:
|
||||
all_node_info[arg_node] = {"chunk_dim": arg_dim, "fix_dim": arg_fix_dim}
|
||||
|
@ -170,9 +161,7 @@ class TraceFlow(object):
|
|||
return True
|
||||
|
||||
def _get_all_node_info(self, end_dim, start_idx, end_idx):
|
||||
cur_node_list = [
|
||||
self.trace_indice.node_list[end_idx]
|
||||
] # start from the last node
|
||||
cur_node_list = [self.trace_indice.node_list[end_idx]] # start from the last node
|
||||
all_node_info = {cur_node_list[0]: {"chunk_dim": end_dim, "fix_dim": []}}
|
||||
|
||||
while len(cur_node_list) > 0:
|
||||
|
@ -183,12 +172,8 @@ class TraceFlow(object):
|
|||
cur_node_chunk_dim = all_node_info[cur_node]["chunk_dim"]
|
||||
cur_node_fix_dim = all_node_info[cur_node]["fix_dim"]
|
||||
if cur_node_chunk_dim:
|
||||
cur_node_compute = self.trace_indice._find_compute_trace_from_node(
|
||||
cur_node
|
||||
)
|
||||
cur_node_source = self.trace_indice._find_source_trace_from_node(
|
||||
cur_node
|
||||
)
|
||||
cur_node_compute = self.trace_indice._find_compute_trace_from_node(cur_node)
|
||||
cur_node_source = self.trace_indice._find_source_trace_from_node(cur_node)
|
||||
else:
|
||||
cur_node_compute = cur_node_source = None
|
||||
|
||||
|
@ -215,15 +200,9 @@ class TraceFlow(object):
|
|||
return None
|
||||
|
||||
if len(arg_list) == 2:
|
||||
if any(i in cur_node.name for i in ["add", "mul"]):
|
||||
if any(i in cur_node.name for i in ["add", "mul", "truediv"]):
|
||||
for arg in arg_list:
|
||||
if not (
|
||||
start_idx
|
||||
<= find_idx_by_name(
|
||||
arg.name, self.trace_indice.node_list
|
||||
)
|
||||
< end_idx
|
||||
):
|
||||
if not (start_idx <= find_idx_by_name(arg.name, self.trace_indice.node_list) < end_idx):
|
||||
continue
|
||||
arg_chunk_dim = all_node_info[arg]["chunk_dim"]
|
||||
arg_fix_dim = all_node_info[arg]["fix_dim"]
|
||||
|
@ -249,9 +228,7 @@ class TraceFlow(object):
|
|||
remove_inputs = []
|
||||
for input_node in inputs:
|
||||
input_dict = {}
|
||||
input_node_idx = find_idx_by_name(
|
||||
input_node.name, self.trace_indice.node_list
|
||||
)
|
||||
input_node_idx = find_idx_by_name(input_node.name, self.trace_indice.node_list)
|
||||
for user in input_node.users.keys():
|
||||
if is_non_compute_node(user):
|
||||
continue
|
||||
|
@ -259,9 +236,7 @@ class TraceFlow(object):
|
|||
if start_idx <= user_idx <= end_idx:
|
||||
chunk_dim = all_node_info[user]["chunk_dim"]
|
||||
if chunk_dim is not None:
|
||||
user_source = self.trace_indice._find_source_trace_from_node(
|
||||
user
|
||||
)[chunk_dim]
|
||||
user_source = self.trace_indice._find_source_trace_from_node(user)[chunk_dim]
|
||||
if input_node_idx in user_source:
|
||||
input_dict[user_idx] = user_source[input_node_idx]
|
||||
else:
|
||||
|
@ -284,7 +259,7 @@ class TraceFlow(object):
|
|||
maybe_prepose_nodes.sort(
|
||||
key=lambda x: find_idx_by_name(x.name, self.trace_indice.node_list),
|
||||
reverse=True,
|
||||
) # from last node to first node
|
||||
) # from last node to first node
|
||||
prepose_nodes = []
|
||||
# set every node as root, search its args, if all legal, turn root and args as prepose nodes
|
||||
while len(maybe_prepose_nodes) > 0:
|
||||
|
@ -305,13 +280,8 @@ class TraceFlow(object):
|
|||
if type(cur_prepose_node_arg) != type(cur_prepose_node):
|
||||
continue
|
||||
# out of loop
|
||||
if not (
|
||||
start_idx
|
||||
<= find_idx_by_name(
|
||||
cur_prepose_node_arg.name, self.trace_indice.node_list
|
||||
)
|
||||
< end_idx
|
||||
):
|
||||
if not (start_idx <= find_idx_by_name(cur_prepose_node_arg.name, self.trace_indice.node_list) <
|
||||
end_idx):
|
||||
continue
|
||||
# compute op in loop
|
||||
elif cur_prepose_node_arg in all_node_info:
|
||||
|
@ -335,15 +305,13 @@ class TraceFlow(object):
|
|||
if n in maybe_prepose_nodes:
|
||||
maybe_prepose_nodes.remove(n)
|
||||
# sort by index
|
||||
prepose_nodes.sort(
|
||||
key=lambda x: find_idx_by_name(x.name, self.trace_indice.node_list)
|
||||
)
|
||||
prepose_nodes.sort(key=lambda x: find_idx_by_name(x.name, self.trace_indice.node_list))
|
||||
|
||||
return prepose_nodes
|
||||
|
||||
def _get_non_chunk_inputs(self, chunk_info, start_idx, end_idx):
|
||||
# we need to log input nodes to avoid deleteing them in the loop
|
||||
chunk_node_list = self.trace_indice.node_list[start_idx : end_idx + 1]
|
||||
chunk_node_list = self.trace_indice.node_list[start_idx:end_idx + 1]
|
||||
# also need to get some prepose node's arg out of non_chunk_inputs
|
||||
for n in chunk_info["args"]["prepose_nodes"]:
|
||||
chunk_node_list.remove(n)
|
||||
|
@ -354,9 +322,7 @@ class TraceFlow(object):
|
|||
return chunk_info
|
||||
|
||||
def flow_search(self, start_idx, start_dim, end_idx, end_dim):
|
||||
inputs, outputs = find_chunk_compute_input_and_output_nodes(
|
||||
self.trace_indice.node_list[start_idx : end_idx + 1]
|
||||
)
|
||||
inputs, outputs = find_chunk_compute_input_and_output_nodes(self.trace_indice.node_list[start_idx:end_idx + 1])
|
||||
# only single ouput
|
||||
if len(outputs) > 1:
|
||||
return None
|
||||
|
@ -367,9 +333,7 @@ class TraceFlow(object):
|
|||
return None
|
||||
|
||||
# get input nodes' chunk dim
|
||||
inputs, inputs_dim = self._get_input_nodes_dim(
|
||||
inputs, start_idx, end_idx, all_node_info
|
||||
)
|
||||
inputs, inputs_dim = self._get_input_nodes_dim(inputs, start_idx, end_idx, all_node_info)
|
||||
if inputs is None:
|
||||
return None
|
||||
|
||||
|
@ -385,9 +349,7 @@ class TraceFlow(object):
|
|||
}
|
||||
|
||||
# move useless nodes ahead of loop
|
||||
chunk_info["args"]["prepose_nodes"] = self._get_prepose_nodes(
|
||||
all_node_info, start_idx, end_idx
|
||||
)
|
||||
chunk_info["args"]["prepose_nodes"] = self._get_prepose_nodes(all_node_info, start_idx, end_idx)
|
||||
|
||||
# find non chunk inputs
|
||||
chunk_info = self._get_non_chunk_inputs(chunk_info, start_idx, end_idx)
|
||||
|
@ -400,10 +362,8 @@ class TraceFlow(object):
|
|||
def _reassgin_reshape_size(self, chunk_info):
|
||||
chunk_region = chunk_info["region"]
|
||||
reshape_size = {}
|
||||
chunk_shape = get_node_shape(chunk_info["outputs"][0])[
|
||||
chunk_info["outputs_dim"]
|
||||
]
|
||||
for node in self.trace_indice.node_list[chunk_region[0] : chunk_region[1] + 1]:
|
||||
chunk_shape = get_node_shape(chunk_info["outputs"][0])[chunk_info["outputs_dim"]]
|
||||
for node in self.trace_indice.node_list[chunk_region[0]:chunk_region[1] + 1]:
|
||||
if any(i in node.name for i in ["reshape", "view"]):
|
||||
reshape_args = node.args[1:]
|
||||
reshape_log = self.trace_indice.indice_view_list[node]
|
||||
|
@ -413,8 +373,6 @@ class TraceFlow(object):
|
|||
if reshape_arg_dim in reshape_log["dim_to"]:
|
||||
continue
|
||||
if reshape_arg_dim == chunk_dim:
|
||||
reshape_size[node.name][reshape_arg.name] = (
|
||||
"min(chunk_size, %d - chunk_idx)" % chunk_shape
|
||||
)
|
||||
reshape_size[node.name][reshape_arg.name] = ("min(chunk_size, %d - chunk_idx)" % chunk_shape)
|
||||
chunk_info["reshape_size"] = reshape_size
|
||||
return chunk_info
|
||||
|
|
|
@ -3,7 +3,7 @@ from typing import Dict, List, Tuple
|
|||
|
||||
from torch.fx.node import Node
|
||||
|
||||
from .utils import find_idx_by_name, get_node_shape
|
||||
from .utils import find_first_tensor_arg, find_idx_by_name, get_node_shape, unflat_list
|
||||
|
||||
|
||||
class TraceIndice(object):
|
||||
|
@ -79,9 +79,7 @@ class TraceIndice(object):
|
|||
node_from_trace = self._find_trace_from_node(node_from)
|
||||
node_to_trace = self._find_trace_from_node(node_to)
|
||||
node_to_trace["indice"][node_to_dim] = node_from_trace["indice"][node_from_dim]
|
||||
node_to_trace["compute"][node_to_dim] = copy.deepcopy(
|
||||
node_from_trace["compute"][node_from_dim]
|
||||
)
|
||||
node_to_trace["compute"][node_to_dim] = copy.deepcopy(node_from_trace["compute"][node_from_dim])
|
||||
self._add_source(node_from, node_from_dim, node_to, node_to_dim, init=True)
|
||||
|
||||
def _inherit_all_computation(self, node_from, node_to):
|
||||
|
@ -209,7 +207,7 @@ class TraceIndice(object):
|
|||
node_idx (int)
|
||||
"""
|
||||
if input_node == None:
|
||||
input_node = node.args[0]
|
||||
input_node = find_first_tensor_arg(node)
|
||||
input_node_idx = find_idx_by_name(input_node.name, self.node_list)
|
||||
input_node_idx_trace = self.indice_trace_list[input_node_idx]["indice"]
|
||||
|
||||
|
@ -227,6 +225,8 @@ class TraceIndice(object):
|
|||
node_idx (int)
|
||||
"""
|
||||
shape = node.meta["tensor_meta"].shape
|
||||
if shape is None:
|
||||
return
|
||||
new_trace = []
|
||||
for _ in shape:
|
||||
new_trace.append(self._add_indice())
|
||||
|
@ -259,7 +259,7 @@ class TraceIndice(object):
|
|||
node (node)
|
||||
node_idx (int)
|
||||
"""
|
||||
permute_dim = node.args[1:]
|
||||
permute_dim = unflat_list(node.args[1:])
|
||||
input_node = node.args[0]
|
||||
|
||||
self._assign_indice_as_input(node, node_idx, input_node)
|
||||
|
@ -359,6 +359,15 @@ class TraceIndice(object):
|
|||
left, right = patterns.split("->")
|
||||
left = left.split(",")
|
||||
|
||||
if '...' in right:
|
||||
replace_list = "!@#$%^&*"
|
||||
target_len = len(get_node_shape(node))
|
||||
add_len = target_len - len(right) + 3
|
||||
replace_str = replace_list[:add_len]
|
||||
right = right.replace("...", replace_str)
|
||||
for ll in range(len(left)):
|
||||
left[ll] = left[ll].replace("...", replace_str)
|
||||
|
||||
all_index = []
|
||||
for i in left:
|
||||
for c in i:
|
||||
|
@ -369,9 +378,7 @@ class TraceIndice(object):
|
|||
for left_idx, left_str in enumerate(left):
|
||||
if right_indice in left_str:
|
||||
source_idx = left_str.index(right_indice)
|
||||
self._inherit_indice(
|
||||
input_nodes[left_idx], source_idx, node, right_idx
|
||||
)
|
||||
self._inherit_indice(input_nodes[left_idx], source_idx, node, right_idx)
|
||||
|
||||
def _assign_softmax_indice(self, node, idx):
|
||||
"""
|
||||
|
@ -440,11 +447,12 @@ class TraceIndice(object):
|
|||
origin_node = node.args[0]
|
||||
origin_shape = origin_node.meta["tensor_meta"].shape
|
||||
target_shape = []
|
||||
for i in range(1, len(node.args)):
|
||||
if isinstance(node.args[i], int):
|
||||
target_shape.append(node.args[i])
|
||||
unflated_args = unflat_list(node.args)
|
||||
for i in range(1, len(unflated_args)):
|
||||
if isinstance(unflated_args[i], int):
|
||||
target_shape.append(unflated_args[i])
|
||||
else:
|
||||
target_shape.append(node.args[i].meta["fwd_out"][0])
|
||||
target_shape.append(unflated_args[i].meta["fwd_out"][0])
|
||||
|
||||
# compute the value of -1
|
||||
if -1 in target_shape:
|
||||
|
@ -472,13 +480,7 @@ class TraceIndice(object):
|
|||
dim_to = [dim_equal.index(False), dim_equal.index(False) + 1]
|
||||
self._del_dim(node_idx, -1)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"shape"
|
||||
+ str(origin_shape)
|
||||
+ "and"
|
||||
+ str(target_shape)
|
||||
+ "view not implemented"
|
||||
)
|
||||
raise NotImplementedError("shape" + str(origin_shape) + "and" + str(target_shape) + "view not implemented")
|
||||
|
||||
# get new indice
|
||||
origin_trace = self._find_indice_trace_from_node(origin_node)
|
||||
|
@ -521,6 +523,8 @@ class TraceIndice(object):
|
|||
self._assign_unsqueeze_indice(node, idx)
|
||||
elif any(i in node.name for i in ["to", "contiguous"]):
|
||||
self._assgin_no_change_indice(node, idx)
|
||||
elif "new_ones" in node.name:
|
||||
self._assign_ones_like_indice(node, idx)
|
||||
else:
|
||||
raise NotImplementedError(node.name, "method not implemented yet!")
|
||||
elif node.op == "call_function":
|
||||
|
@ -530,7 +534,7 @@ class TraceIndice(object):
|
|||
self._assign_matmul_indice(node, idx)
|
||||
elif "softmax" in node.name:
|
||||
self._assign_softmax_indice(node, idx)
|
||||
elif any(n in node.name for n in ["mul", "add", "sigmoid", "relu"]):
|
||||
elif any(n in node.name for n in ["mul", "add", "sigmoid", "relu", "sub", "truediv"]):
|
||||
self._assign_elementwise_indice(node, idx)
|
||||
elif "ones_like" in node.name:
|
||||
self._assign_ones_like_indice(node, idx)
|
||||
|
@ -538,21 +542,21 @@ class TraceIndice(object):
|
|||
self._assign_dropout_indice(node, idx)
|
||||
elif "einsum" in node.name:
|
||||
self._assign_einsum_indice(node, idx)
|
||||
elif "getattr" in node.name:
|
||||
continue # get attr like shape
|
||||
elif "getitem" in node.name:
|
||||
continue # get item in list
|
||||
elif "layer_norm" in node.name:
|
||||
self._assign_layernorm_indice(node, idx)
|
||||
elif any(i in node.name for i in ["getattr", "getitem", "eq", "_assert"]):
|
||||
continue
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
node.name, "function not implemented yet!"
|
||||
)
|
||||
raise NotImplementedError(node.name, "function not implemented yet!")
|
||||
elif node.op == "call_module":
|
||||
if any(n in node.name for n in ["layernorm", "norm"]):
|
||||
self._assign_layernorm_indice(node, idx)
|
||||
elif any(n in node.name for n in ["sigmoid", "dropout", "relu"]):
|
||||
self._assign_elementwise_indice(node, idx)
|
||||
else:
|
||||
raise NotImplementedError(node.name, "module not implemented yet!")
|
||||
elif node.op == "get_attr":
|
||||
self._assign_all_indice(node, idx) # get param
|
||||
self._assign_all_indice(node, idx) # get param
|
||||
elif node.op == "output":
|
||||
continue
|
||||
else:
|
||||
|
|
|
@ -3,10 +3,32 @@ from typing import Any, Callable, Dict, Iterable, List, Tuple
|
|||
from torch.fx.node import Node
|
||||
|
||||
|
||||
def unflat_list(inputs):
|
||||
"""
|
||||
unflat a list by recursion
|
||||
"""
|
||||
res = []
|
||||
for i in inputs:
|
||||
if isinstance(i, list) or isinstance(i, set) or isinstance(i, tuple):
|
||||
res.extend(unflat_list(i))
|
||||
else:
|
||||
res.append(i)
|
||||
return res
|
||||
|
||||
|
||||
def find_first_tensor_arg(node):
|
||||
"""
|
||||
Find the first input tensor arg for a node
|
||||
"""
|
||||
for arg in node.args:
|
||||
if type(arg) == type(node):
|
||||
return arg
|
||||
raise RuntimeError()
|
||||
|
||||
|
||||
def is_non_compute_node(node):
|
||||
if any(i in node.op for i in ["placeholder", "get_attr", "output"]) or any(
|
||||
i in node.name for i in ["getitem", "getattr"]
|
||||
):
|
||||
i in node.name for i in ["getitem", "getattr"]):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
@ -18,17 +40,13 @@ def get_node_shape(node):
|
|||
|
||||
|
||||
def is_non_compute_node_except_placeholder(node):
|
||||
if any(i in node.op for i in ["get_attr", "output"]) or any(
|
||||
i in node.name for i in ["getitem", "getattr"]
|
||||
):
|
||||
if any(i in node.op for i in ["get_attr", "output"]) or any(i in node.name for i in ["getitem", "getattr"]):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def is_non_compute_node_except_placeholder_output(node):
|
||||
if any(i in node.op for i in ["get_attr"]) or any(
|
||||
i in node.name for i in ["getitem", "getattr"]
|
||||
):
|
||||
if any(i in node.op for i in ["get_attr"]) or any(i in node.name for i in ["getitem", "getattr"]):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
@ -74,22 +92,16 @@ def find_chunk_compute_input_and_output_nodes(nodes: List[Node]):
|
|||
# we treat that input node as the input of the checkpoint function
|
||||
for node in nodes:
|
||||
for input_node in node._input_nodes.keys():
|
||||
if (
|
||||
input_node not in nodes
|
||||
and input_node not in input_nodes
|
||||
and not is_non_compute_node_except_placeholder(input_node)
|
||||
):
|
||||
if (input_node not in nodes and input_node not in input_nodes
|
||||
and not is_non_compute_node_except_placeholder(input_node)):
|
||||
input_nodes.append(input_node)
|
||||
|
||||
# if a node has a user node which is not in the node list
|
||||
# we treat that user node as the node receiving the current node output
|
||||
for node in nodes:
|
||||
for output_node in node.users.keys():
|
||||
if (
|
||||
output_node not in nodes
|
||||
and node not in output_nodes
|
||||
and not is_non_compute_node_except_placeholder_output(output_node)
|
||||
):
|
||||
if (output_node not in nodes and node not in output_nodes
|
||||
and not is_non_compute_node_except_placeholder_output(output_node)):
|
||||
output_nodes.append(node)
|
||||
|
||||
return input_nodes, output_nodes
|
||||
|
|
|
@ -249,6 +249,8 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
|
|||
aten.sum.default,
|
||||
aten.sum.dim_IntList,
|
||||
aten.mean.dim,
|
||||
aten.sub.Tensor,
|
||||
aten.sub_.Tensor,
|
||||
|
||||
# activation op
|
||||
aten.hardswish.default,
|
||||
|
@ -313,7 +315,8 @@ if version.parse(torch.__version__) >= version.parse('1.12.0'):
|
|||
aten.where.self,
|
||||
aten.zero_.default,
|
||||
aten.zeros_like.default,
|
||||
]
|
||||
aten.fill_.Scalar
|
||||
] # yapf: disable
|
||||
|
||||
for op in zero_flop_aten:
|
||||
flop_mapping[op] = zero_flop_jit
|
||||
|
|
|
@ -2,14 +2,13 @@ import time
|
|||
|
||||
import torch
|
||||
import torch.fx
|
||||
from simple_evoformer import base_evoformer, openfold_evoformer
|
||||
|
||||
from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
|
||||
from colossalai.fx import ColoTracer
|
||||
from colossalai.fx.graph_module import ColoGraphModule
|
||||
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
|
||||
from colossalai.fx.profiler import MetaTensor
|
||||
from tests.test_autochunk.evoformer.evoformer import evoformer_base
|
||||
from tests.test_autochunk.openfold.evoformer import EvoformerBlock
|
||||
|
||||
|
||||
def _benchmark_evoformer(model: torch.nn.Module, node, pair, title, chunk_size=None):
|
||||
|
@ -34,10 +33,7 @@ def _benchmark_evoformer(model: torch.nn.Module, node, pair, title, chunk_size=N
|
|||
time2 = time.time()
|
||||
|
||||
new_max_mem = torch.cuda.max_memory_allocated() / 1024**2
|
||||
print(
|
||||
"%s: time %.4fs, mem %dMB"
|
||||
% (title, (time2 - time1) / loop, new_max_mem - now_mem)
|
||||
)
|
||||
print("%s: time %.4fs, mem %dMB" % (title, (time2 - time1) / loop, new_max_mem - now_mem))
|
||||
|
||||
|
||||
def _build_autochunk(model, max_memory, node, pair):
|
||||
|
@ -50,18 +46,14 @@ def _build_autochunk(model, max_memory, node, pair):
|
|||
},
|
||||
)
|
||||
|
||||
gm_prop = torch.fx.symbolic_trace(model) # must use symbolic_trace
|
||||
gm_prop = torch.fx.symbolic_trace(model) # must use symbolic_trace
|
||||
interp = MetaInfoProp(gm_prop)
|
||||
interp.propagate(
|
||||
MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0")
|
||||
)
|
||||
interp.propagate(MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0"))
|
||||
|
||||
# now run it twice to get meta info in graph module, not necessary
|
||||
gm = torch.fx.GraphModule(model, graph)
|
||||
interp = MetaInfoProp(gm)
|
||||
interp.propagate(
|
||||
MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0")
|
||||
)
|
||||
interp.propagate(MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0"))
|
||||
|
||||
# set code_gen
|
||||
codegen = AutoChunkCodeGen(gm_prop, max_memory, print_mem=False)
|
||||
|
@ -75,42 +67,22 @@ def _build_autochunk(model, max_memory, node, pair):
|
|||
return gm
|
||||
|
||||
|
||||
def _build_openfold():
|
||||
model = EvoformerBlock(
|
||||
c_m=256,
|
||||
c_z=128,
|
||||
c_hidden_msa_att=32,
|
||||
c_hidden_opm=32,
|
||||
c_hidden_mul=128,
|
||||
c_hidden_pair_att=32,
|
||||
no_heads_msa=8,
|
||||
no_heads_pair=4,
|
||||
transition_n=4,
|
||||
msa_dropout=0.15,
|
||||
pair_dropout=0.15,
|
||||
inf=1e4,
|
||||
eps=1e-4,
|
||||
is_multimer=False,
|
||||
).cuda()
|
||||
return model
|
||||
|
||||
|
||||
def benchmark_evoformer():
|
||||
# init data and model
|
||||
msa_len = 256
|
||||
pair_len = 512
|
||||
msa_len = 128
|
||||
pair_len = 256
|
||||
node = torch.randn(1, msa_len, pair_len, 256).cuda()
|
||||
pair = torch.randn(1, pair_len, pair_len, 128).cuda()
|
||||
model = evoformer_base().cuda()
|
||||
model = base_evoformer().cuda()
|
||||
|
||||
# build autochunk model
|
||||
# max_memory = 1000 # MB, fit memory mode
|
||||
max_memory = None # min memory mode
|
||||
autochunk = _build_autochunk(evoformer_base().cuda(), max_memory, node, pair)
|
||||
max_memory = None # min memory mode
|
||||
autochunk = _build_autochunk(base_evoformer().cuda(), max_memory, node, pair)
|
||||
|
||||
# build openfold
|
||||
chunk_size = 64
|
||||
openfold = _build_openfold()
|
||||
openfold = openfold_evoformer().cuda()
|
||||
|
||||
# benchmark
|
||||
_benchmark_evoformer(model, node, pair, "base")
|
|
@ -1,59 +0,0 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from .msa import MSAStack
|
||||
from .ops import OutProductMean
|
||||
from .triangle import PairStack
|
||||
|
||||
|
||||
def print_memory(init_mem, text=None):
|
||||
now_mem = torch.cuda.memory_allocated() / 1024 ** 2 - init_mem
|
||||
max_mem = torch.cuda.max_memory_allocated() / 1024 ** 2 - init_mem
|
||||
print("%s now:%.2f max:%.2f" % ("" if text is None else text, now_mem, max_mem))
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
|
||||
|
||||
class EvoformerBlock(nn.Module):
|
||||
|
||||
def __init__(self, d_node, d_pair):
|
||||
super(EvoformerBlock, self).__init__()
|
||||
|
||||
self.msa_stack = MSAStack(d_node, d_pair, p_drop=0.15)
|
||||
self.communication = OutProductMean(n_feat=d_node, n_feat_out=d_pair, n_feat_proj=32)
|
||||
self.pair_stack = PairStack(d_pair=d_pair)
|
||||
|
||||
def forward(self, node, pair):
|
||||
node = self.msa_stack(node, pair)
|
||||
pair = pair + self.communication(node)
|
||||
pair = self.pair_stack(pair)
|
||||
return node, pair
|
||||
|
||||
|
||||
class Evoformer(nn.Module):
|
||||
|
||||
def __init__(self, d_node, d_pair):
|
||||
super(Evoformer, self).__init__()
|
||||
|
||||
self.blocks = nn.ModuleList()
|
||||
for _ in range(1):
|
||||
self.blocks.append(EvoformerBlock(d_node, d_pair))
|
||||
|
||||
def forward(self, node, pair):
|
||||
for b in self.blocks:
|
||||
node, pair = b(node, pair)
|
||||
return node, pair
|
||||
|
||||
|
||||
def evoformer_tiny():
|
||||
return Evoformer(d_node=64, d_pair=32)
|
||||
|
||||
|
||||
def evoformer_base():
|
||||
return Evoformer(d_node=256, d_pair=128)
|
||||
|
||||
|
||||
def evoformer_large():
|
||||
return Evoformer(d_node=512, d_pair=256)
|
||||
|
||||
|
||||
__all__ = ['Evoformer', 'evoformer_base', 'evoformer_large']
|
|
@ -1,29 +0,0 @@
|
|||
import math
|
||||
|
||||
import numpy as np
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def glorot_uniform_af(x, gain=1.0):
|
||||
"""
|
||||
initialize tensors the same as xavier_initializer in PyTorch, but the dimensions are different:
|
||||
In PyTorch:
|
||||
[feature_out, feature_in, n_head ...]
|
||||
In Jax:
|
||||
[... n_head, feature_in, feature_out]
|
||||
However, there is a feature in original Alphafold2 code that they use the Jax version initializer to initialize tensors like:
|
||||
[feature_in, n_head, feature_out]
|
||||
|
||||
In this function, we keep this feature to initialize [feature_in, n_head, ..., feature_out] tensors
|
||||
"""
|
||||
fan_in, fan_out = x.shape[-2:]
|
||||
if len(x.shape) > 2:
|
||||
receptive_field_size = np.prod(x.shape[:-2])
|
||||
fan_in *= receptive_field_size
|
||||
fan_out *= receptive_field_size
|
||||
std = gain * math.sqrt(2.0 / float(fan_in + fan_out))
|
||||
dev = math.sqrt(3.0) * std # Calculate uniform bounds from standard deviation
|
||||
|
||||
nn.init.uniform_(x, -dev, dev)
|
||||
|
||||
return x
|
|
@ -1,19 +0,0 @@
|
|||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def bias_sigmod_ele(y, bias, z):
|
||||
return torch.sigmoid(y + bias) * z
|
||||
|
||||
|
||||
def bias_dropout_add(x: torch.Tensor, bias: torch.Tensor, dropmask: torch.Tensor,
|
||||
residual: torch.Tensor, prob: float) -> torch.Tensor:
|
||||
out = (x + bias) * F.dropout(dropmask, p=prob, training=False)
|
||||
out = residual + out
|
||||
return out
|
||||
|
||||
|
||||
def bias_ele_dropout_residual(ab: torch.Tensor, b: torch.Tensor, g: torch.Tensor,
|
||||
dropout_mask: torch.Tensor, Z_raw: torch.Tensor,
|
||||
prob: float) -> torch.Tensor:
|
||||
return Z_raw + F.dropout(dropout_mask, p=prob, training=True) * (g * (ab + b))
|
|
@ -1,95 +0,0 @@
|
|||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
from torch.nn import LayerNorm
|
||||
|
||||
from .kernel import bias_dropout_add
|
||||
from .ops import SelfAttention, Transition
|
||||
|
||||
|
||||
class MSARowAttentionWithPairBias(nn.Module):
|
||||
|
||||
def __init__(self, d_node, d_pair, c=32, n_head=8, p_drop=0.15):
|
||||
super(MSARowAttentionWithPairBias, self).__init__()
|
||||
self.d_node = d_node
|
||||
self.d_pair = d_pair
|
||||
self.c = c
|
||||
self.n_head = n_head
|
||||
self.p_drop = p_drop
|
||||
|
||||
self.layernormM = LayerNorm(d_node)
|
||||
self.layernormZ = LayerNorm(d_pair)
|
||||
|
||||
_init_weights = torch.nn.init.normal_(torch.zeros([n_head, d_pair]),
|
||||
std=1.0 / math.sqrt(d_pair))
|
||||
self.linear_b_weights = nn.parameter.Parameter(data=_init_weights, requires_grad=True)
|
||||
|
||||
self.attention = SelfAttention(qkv_dim=d_node,
|
||||
c=c,
|
||||
n_head=n_head,
|
||||
out_dim=d_node,
|
||||
gating=True,
|
||||
last_bias_fuse=True)
|
||||
|
||||
self.out_bias = nn.parameter.Parameter(data=torch.zeros((d_node,)), requires_grad=True)
|
||||
|
||||
def forward(self, M_raw, Z):
|
||||
## Input projections
|
||||
M = self.layernormM(M_raw)
|
||||
Z = self.layernormZ(Z)
|
||||
b = F.linear(Z, self.linear_b_weights)
|
||||
b = b.permute(0, 3, 1, 2)
|
||||
# b = rearrange(b, 'b q k h -> b h q k')
|
||||
|
||||
M = self.attention(M, b)
|
||||
dropout_mask = torch.ones_like(M[:, 0:1, :, :]).to(M.device).to(M.dtype)
|
||||
|
||||
return bias_dropout_add(M, self.out_bias, dropout_mask, M_raw, prob=self.p_drop)
|
||||
|
||||
|
||||
class MSAColumnAttention(nn.Module):
|
||||
|
||||
def __init__(self, d_node, c=32, n_head=8):
|
||||
super(MSAColumnAttention, self).__init__()
|
||||
self.d_node = d_node
|
||||
self.c = c
|
||||
self.n_head = n_head
|
||||
|
||||
self.layernormM = LayerNorm(d_node)
|
||||
self.attention = SelfAttention(qkv_dim=d_node,
|
||||
c=c,
|
||||
n_head=n_head,
|
||||
out_dim=d_node,
|
||||
gating=True)
|
||||
|
||||
def forward(self, M_raw):
|
||||
M = M_raw.transpose(-2, -3)
|
||||
M = self.layernormM(M)
|
||||
|
||||
M = self.attention(M)
|
||||
|
||||
M = M.transpose(-2, -3)
|
||||
return M_raw + M
|
||||
|
||||
|
||||
class MSAStack(nn.Module):
|
||||
|
||||
def __init__(self, d_node, d_pair, p_drop=0.15):
|
||||
super(MSAStack, self).__init__()
|
||||
|
||||
self.MSARowAttentionWithPairBias = MSARowAttentionWithPairBias(d_node=d_node,
|
||||
d_pair=d_pair,
|
||||
p_drop=p_drop)
|
||||
|
||||
self.MSAColumnAttention = MSAColumnAttention(d_node=d_node)
|
||||
self.MSATransition = Transition(d=d_node)
|
||||
|
||||
def forward(self, node, pair):
|
||||
node = self.MSARowAttentionWithPairBias(node, pair)
|
||||
node = self.MSAColumnAttention(node)
|
||||
node = self.MSATransition(node)
|
||||
|
||||
return node
|
|
@ -1,176 +0,0 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
from torch.nn import LayerNorm
|
||||
|
||||
from .initializer import glorot_uniform_af
|
||||
from .kernel import bias_sigmod_ele
|
||||
|
||||
|
||||
class DropoutRowwise(nn.Module):
|
||||
|
||||
def __init__(self, p):
|
||||
super(DropoutRowwise, self).__init__()
|
||||
self.p = p
|
||||
self.dropout = nn.Dropout(p=p)
|
||||
|
||||
def forward(self, x):
|
||||
dropout_mask = torch.ones_like(x[:, 0:1, :, :])
|
||||
dropout_mask = self.dropout(dropout_mask)
|
||||
return dropout_mask * x
|
||||
|
||||
|
||||
class DropoutColumnwise(nn.Module):
|
||||
|
||||
def __init__(self, p):
|
||||
super(DropoutColumnwise, self).__init__()
|
||||
self.p = p
|
||||
self.dropout = nn.Dropout(p=p)
|
||||
|
||||
def forward(self, x):
|
||||
dropout_mask = torch.ones_like(x[:, :, 0:1, :])
|
||||
dropout_mask = self.dropout(dropout_mask)
|
||||
return dropout_mask * x
|
||||
|
||||
|
||||
class Transition(nn.Module):
|
||||
|
||||
def __init__(self, d, n=4):
|
||||
super(Transition, self).__init__()
|
||||
self.norm = LayerNorm(d)
|
||||
self.linear1 = Linear(d, n * d, initializer='relu')
|
||||
self.linear2 = Linear(n * d, d, initializer='zeros')
|
||||
|
||||
def forward(self, src):
|
||||
x = self.norm(src)
|
||||
x = self.linear2(F.relu(self.linear1(x)))
|
||||
return src + x
|
||||
|
||||
|
||||
class OutProductMean(nn.Module):
|
||||
|
||||
def __init__(self, n_feat=64, n_feat_out=128, n_feat_proj=32):
|
||||
super(OutProductMean, self).__init__()
|
||||
|
||||
self.layernormM = LayerNorm(n_feat)
|
||||
self.linear_a = Linear(n_feat, n_feat_proj)
|
||||
self.linear_b = Linear(n_feat, n_feat_proj)
|
||||
|
||||
self.o_linear = Linear(n_feat_proj * n_feat_proj,
|
||||
n_feat_out,
|
||||
initializer='zero',
|
||||
use_bias=True)
|
||||
|
||||
def forward(self, M):
|
||||
M = self.layernormM(M)
|
||||
left_act = self.linear_a(M)
|
||||
right_act = self.linear_b(M)
|
||||
|
||||
o = torch.einsum('bsid,bsje->bijde', left_act, right_act).contiguous()
|
||||
# O = rearrange(O, 'b i j d e -> b i j (d e)')
|
||||
o = o.reshape(o.shape[0], o.shape[1], o.shape[2], -1)
|
||||
Z = self.o_linear(o)
|
||||
|
||||
return Z
|
||||
|
||||
|
||||
class Linear(nn.Linear):
|
||||
"""
|
||||
A Linear layer with built-in nonstandard initializations. Called just
|
||||
like torch.nn.Linear.
|
||||
Implements the initializers in 1.11.4, plus some additional ones found
|
||||
in the code.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
feature_in: int,
|
||||
feature_out: int,
|
||||
initializer: str = 'linear',
|
||||
use_bias: bool = True,
|
||||
bias_init: float = 0.,
|
||||
):
|
||||
super(Linear, self).__init__(feature_in, feature_out, bias=use_bias)
|
||||
|
||||
self.use_bias = use_bias
|
||||
if initializer == 'linear':
|
||||
glorot_uniform_af(self.weight, gain=1.0)
|
||||
elif initializer == 'relu':
|
||||
glorot_uniform_af(self.weight, gain=2.0)
|
||||
elif initializer == 'zeros':
|
||||
nn.init.zeros_(self.weight)
|
||||
if self.use_bias:
|
||||
with torch.no_grad():
|
||||
self.bias.fill_(bias_init)
|
||||
|
||||
|
||||
class SelfAttention(nn.Module):
|
||||
"""
|
||||
Multi-Head SelfAttention dealing with [batch_size1, batch_size2, len, dim] tensors
|
||||
"""
|
||||
|
||||
def __init__(self, qkv_dim, c, n_head, out_dim, gating=True, last_bias_fuse=False):
|
||||
super(SelfAttention, self).__init__()
|
||||
self.qkv_dim = qkv_dim
|
||||
self.c = c
|
||||
self.n_head = n_head
|
||||
self.out_dim = out_dim
|
||||
self.gating = gating
|
||||
self.last_bias_fuse = last_bias_fuse
|
||||
|
||||
self.scaling = self.c**(-0.5)
|
||||
|
||||
# self.to_qkv = Linear(qkv_dim, 3 * n_head * c, initializer='linear')
|
||||
self.to_q = Linear(qkv_dim, n_head * c, initializer='linear', use_bias=False)
|
||||
self.to_k = Linear(qkv_dim, n_head * c, initializer='linear', use_bias=False)
|
||||
self.to_v = Linear(qkv_dim, n_head * c, initializer='linear', use_bias=False)
|
||||
|
||||
if gating:
|
||||
self.gating_bias = nn.parameter.Parameter(data=torch.ones((n_head * c,)))
|
||||
self.gating_linear = Linear(qkv_dim, n_head * c, initializer='zero', use_bias=False)
|
||||
|
||||
self.o_linear = Linear(n_head * c,
|
||||
out_dim,
|
||||
initializer='zero',
|
||||
use_bias=(not last_bias_fuse))
|
||||
|
||||
def forward(self, in_data, nonbatched_bias=None):
|
||||
"""
|
||||
:param in_data: [batch_size1, batch_size2, len_qkv, qkv_dim]
|
||||
:param bias: None or [batch_size1, batch_size2, n_head, len_q, len_kv]
|
||||
:param nonbatched_bias: None or [batch_size1, n_head, len_q, len_kv]
|
||||
"""
|
||||
|
||||
# qkv = self.to_qkv(in_data).chunk(3, dim=-1)
|
||||
# q, k, v = map(lambda t: rearrange(t, 'b1 b2 n (h d) -> b1 b2 h n d', h=self.n_head), qkv)
|
||||
|
||||
q = self.to_q(in_data)
|
||||
k = self.to_k(in_data)
|
||||
v = self.to_v(in_data)
|
||||
|
||||
# q, k, v = map(lambda t: rearrange(t, 'b1 b2 n (h d) -> b1 b2 h n d', h=self.n_head),
|
||||
# [q, k, v])
|
||||
q, k, v = map(lambda t: t.view(t.shape[0], t.shape[1], t.shape[2], self.n_head, -1).permute(0, 1, 3, 2, 4),
|
||||
[q, k, v])
|
||||
|
||||
q = q * self.scaling
|
||||
|
||||
logits = torch.matmul(q, k.transpose(-1, -2))
|
||||
|
||||
if nonbatched_bias is not None:
|
||||
logits += nonbatched_bias.unsqueeze(1)
|
||||
weights = torch.softmax(logits, dim=-1)
|
||||
# weights = softmax(logits)
|
||||
|
||||
weighted_avg = torch.matmul(weights, v)
|
||||
# weighted_avg = rearrange(weighted_avg, 'b1 b2 h n d -> b1 b2 n (h d)')
|
||||
weighted_avg = weighted_avg.permute(0, 1, 3, 2, 4)
|
||||
weighted_avg = weighted_avg.reshape(weighted_avg.shape[0], weighted_avg.shape[1], weighted_avg.shape[2], -1)
|
||||
|
||||
if self.gating:
|
||||
gate_values = self.gating_linear(in_data)
|
||||
weighted_avg = bias_sigmod_ele(gate_values, self.gating_bias, weighted_avg)
|
||||
|
||||
output = self.o_linear(weighted_avg)
|
||||
return output
|
|
@ -1,192 +0,0 @@
|
|||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn import LayerNorm
|
||||
|
||||
from .kernel import bias_dropout_add, bias_ele_dropout_residual
|
||||
from .ops import Linear, SelfAttention, Transition
|
||||
|
||||
|
||||
def permute_final_dims(tensor, inds):
|
||||
zero_index = -1 * len(inds)
|
||||
first_inds = list(range(len(tensor.shape[:zero_index])))
|
||||
return tensor.permute(first_inds + [zero_index + i for i in inds])
|
||||
|
||||
|
||||
class TriangleMultiplicationOutgoing(nn.Module):
|
||||
|
||||
def __init__(self, d_pair, p_drop, c=128):
|
||||
super(TriangleMultiplicationOutgoing, self).__init__()
|
||||
self.d_pair = d_pair
|
||||
self.c = c
|
||||
|
||||
self.layernorm1 = LayerNorm(d_pair)
|
||||
self.left_projection = Linear(d_pair, c)
|
||||
self.right_projection = Linear(d_pair, c)
|
||||
self.left_gate = Linear(d_pair, c, initializer='zeros', bias_init=1.)
|
||||
self.right_gate = Linear(d_pair, c, initializer='zeros', bias_init=1.)
|
||||
|
||||
self.output_gate = Linear(d_pair, d_pair, initializer='zeros', bias_init=1.)
|
||||
self.layernorm2 = LayerNorm(c)
|
||||
self.output_projection = Linear(d_pair, d_pair, initializer='zeros', use_bias=False)
|
||||
self.output_bias = nn.parameter.Parameter(data=torch.zeros((d_pair,)), requires_grad=True)
|
||||
|
||||
self.p_drop = p_drop
|
||||
|
||||
def forward(self, Z_raw):
|
||||
Z = self.layernorm1(Z_raw)
|
||||
left_proj_act = self.left_projection(Z)
|
||||
right_proj_act = self.right_projection(Z)
|
||||
|
||||
left_proj_act = left_proj_act * torch.sigmoid(self.left_gate(Z))
|
||||
right_proj_act = right_proj_act * torch.sigmoid(self.right_gate(Z))
|
||||
|
||||
g = torch.sigmoid(self.output_gate(Z))
|
||||
# p = torch.matmul(
|
||||
# permute_final_dims(left_proj_act, (2, 0, 1)),
|
||||
# permute_final_dims(right_proj_act, (2, 1, 0)),
|
||||
# )
|
||||
# ab = permute_final_dims(p, (1, 2, 0))
|
||||
|
||||
ab = torch.einsum('bikd,bjkd->bijd', left_proj_act, right_proj_act)
|
||||
ab = self.output_projection(self.layernorm2(ab))
|
||||
dropout_mask = torch.ones_like(Z[:, 0:1, :, :]).to(Z.device).to(Z.dtype)
|
||||
return bias_ele_dropout_residual(ab,
|
||||
self.output_bias,
|
||||
g,
|
||||
dropout_mask,
|
||||
Z_raw,
|
||||
prob=self.p_drop)
|
||||
|
||||
|
||||
class TriangleMultiplicationIncoming(nn.Module):
|
||||
|
||||
def __init__(self, d_pair, p_drop, c=128):
|
||||
super(TriangleMultiplicationIncoming, self).__init__()
|
||||
self.d_pair = d_pair
|
||||
self.c = c
|
||||
|
||||
self.layernorm1 = LayerNorm(d_pair)
|
||||
self.left_projection = Linear(d_pair, c)
|
||||
self.right_projection = Linear(d_pair, c)
|
||||
self.left_gate = Linear(d_pair, c, initializer='zeros', bias_init=1.)
|
||||
self.right_gate = Linear(d_pair, c, initializer='zeros', bias_init=1.)
|
||||
|
||||
self.output_gate = Linear(d_pair, d_pair, initializer='zeros', bias_init=1.)
|
||||
self.layernorm2 = LayerNorm(c)
|
||||
self.output_projection = Linear(d_pair, d_pair, initializer='zeros', use_bias=False)
|
||||
self.output_bias = nn.parameter.Parameter(data=torch.zeros((d_pair,)), requires_grad=True)
|
||||
|
||||
self.p_drop = p_drop
|
||||
|
||||
def forward(self, Z_raw):
|
||||
Z = self.layernorm1(Z_raw)
|
||||
left_proj_act = self.left_projection(Z)
|
||||
right_proj_act = self.right_projection(Z)
|
||||
|
||||
left_proj_act = left_proj_act * torch.sigmoid(self.left_gate(Z))
|
||||
right_proj_act = right_proj_act * torch.sigmoid(self.right_gate(Z))
|
||||
|
||||
g = torch.sigmoid(self.output_gate(Z))
|
||||
# p = torch.matmul(
|
||||
# permute_final_dims(left_proj_act, (2, 1, 0)),
|
||||
# permute_final_dims(right_proj_act, (2, 0, 1)),
|
||||
# )
|
||||
# ab = permute_final_dims(p, (1, 2, 0))
|
||||
|
||||
ab = torch.einsum('bkid,bkjd->bijd', left_proj_act, right_proj_act)
|
||||
ab = self.output_projection(self.layernorm2(ab))
|
||||
dropout_mask = torch.ones_like(Z[:, 0:1, :, :]).to(Z.device).to(Z.dtype)
|
||||
return bias_ele_dropout_residual(ab,
|
||||
self.output_bias,
|
||||
g,
|
||||
dropout_mask,
|
||||
Z_raw,
|
||||
prob=self.p_drop)
|
||||
|
||||
|
||||
class TriangleAttentionStartingNode(nn.Module):
|
||||
|
||||
def __init__(self, d_pair, p_drop, c=32, n_head=4):
|
||||
super(TriangleAttentionStartingNode, self).__init__()
|
||||
self.d_pair = d_pair
|
||||
self.c = c
|
||||
self.n_head = n_head
|
||||
self.p_drop = p_drop
|
||||
|
||||
self.layernorm1 = LayerNorm(d_pair)
|
||||
_init_weights = torch.nn.init.normal_(torch.zeros([d_pair, n_head]),
|
||||
std=1.0 / math.sqrt(d_pair))
|
||||
self.linear_b_weights = nn.parameter.Parameter(data=_init_weights)
|
||||
self.attention = SelfAttention(qkv_dim=d_pair,
|
||||
c=c,
|
||||
n_head=n_head,
|
||||
out_dim=d_pair,
|
||||
gating=True,
|
||||
last_bias_fuse=True)
|
||||
|
||||
self.out_bias = nn.parameter.Parameter(data=torch.zeros((d_pair,)), requires_grad=True)
|
||||
|
||||
def forward(self, Z_raw):
|
||||
Z = self.layernorm1(Z_raw)
|
||||
b = torch.einsum('bqkc,ch->bhqk', Z, self.linear_b_weights)
|
||||
|
||||
Z = self.attention(Z, b)
|
||||
|
||||
dropout_mask = torch.ones_like(Z[:, 0:1, :, :]).to(Z.device).to(Z.dtype)
|
||||
return bias_dropout_add(Z, self.out_bias, dropout_mask, Z_raw, prob=self.p_drop)
|
||||
|
||||
|
||||
class TriangleAttentionEndingNode(nn.Module):
|
||||
|
||||
def __init__(self, d_pair, p_drop, c=32, n_head=4):
|
||||
super(TriangleAttentionEndingNode, self).__init__()
|
||||
self.d_pair = d_pair
|
||||
self.c = c
|
||||
self.n_head = n_head
|
||||
self.p_drop = p_drop
|
||||
|
||||
self.layernorm1 = LayerNorm(d_pair)
|
||||
_init_weights = torch.nn.init.normal_(torch.zeros([d_pair, n_head]),
|
||||
std=1.0 / math.sqrt(d_pair))
|
||||
self.linear_b_weights = nn.parameter.Parameter(data=_init_weights)
|
||||
self.attention = SelfAttention(qkv_dim=d_pair,
|
||||
c=c,
|
||||
n_head=n_head,
|
||||
out_dim=d_pair,
|
||||
gating=True,
|
||||
last_bias_fuse=True)
|
||||
|
||||
self.out_bias = nn.parameter.Parameter(data=torch.zeros((d_pair,)), requires_grad=True)
|
||||
|
||||
def forward(self, Z_raw):
|
||||
Z = Z_raw.transpose(-2, -3)
|
||||
Z = self.layernorm1(Z)
|
||||
b = torch.einsum('bqkc,ch->bhqk', Z, self.linear_b_weights)
|
||||
|
||||
Z = self.attention(Z, b)
|
||||
|
||||
Z = Z.transpose(-2, -3)
|
||||
dropout_mask = torch.ones_like(Z[:, :, 0:1, :]).to(Z.device).to(Z.dtype)
|
||||
return bias_dropout_add(Z, self.out_bias, dropout_mask, Z_raw, prob=self.p_drop)
|
||||
|
||||
|
||||
class PairStack(nn.Module):
|
||||
|
||||
def __init__(self, d_pair, p_drop=0.25):
|
||||
super(PairStack, self).__init__()
|
||||
|
||||
self.TriangleMultiplicationOutgoing = TriangleMultiplicationOutgoing(d_pair, p_drop=p_drop)
|
||||
self.TriangleMultiplicationIncoming = TriangleMultiplicationIncoming(d_pair, p_drop=p_drop)
|
||||
self.TriangleAttentionStartingNode = TriangleAttentionStartingNode(d_pair, p_drop=p_drop)
|
||||
self.TriangleAttentionEndingNode = TriangleAttentionEndingNode(d_pair, p_drop=p_drop)
|
||||
self.PairTransition = Transition(d=d_pair)
|
||||
|
||||
def forward(self, pair):
|
||||
pair = self.TriangleMultiplicationOutgoing(pair)
|
||||
pair = self.TriangleMultiplicationIncoming(pair)
|
||||
pair = self.TriangleAttentionStartingNode(pair)
|
||||
pair = self.TriangleAttentionEndingNode(pair)
|
||||
pair = self.PairTransition(pair)
|
||||
return pair
|
|
@ -1,84 +0,0 @@
|
|||
# Copyright 2021 AlQuraishi Laboratory
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from typing import Any, Tuple, List, Callable, Optional
|
||||
|
||||
|
||||
BLOCK_ARG = Any
|
||||
BLOCK_ARGS = List[BLOCK_ARG]
|
||||
|
||||
|
||||
def get_checkpoint_fn():
|
||||
checkpoint = torch.utils.checkpoint.checkpoint
|
||||
|
||||
return checkpoint
|
||||
|
||||
|
||||
@torch.jit.ignore
|
||||
def checkpoint_blocks(
|
||||
blocks: List[Callable],
|
||||
args: BLOCK_ARGS,
|
||||
blocks_per_ckpt: Optional[int],
|
||||
) -> BLOCK_ARGS:
|
||||
"""
|
||||
Chunk a list of blocks and run each chunk with activation
|
||||
checkpointing. We define a "block" as a callable whose only inputs are
|
||||
the outputs of the previous block.
|
||||
|
||||
Implements Subsection 1.11.8
|
||||
|
||||
Args:
|
||||
blocks:
|
||||
List of blocks
|
||||
args:
|
||||
Tuple of arguments for the first block.
|
||||
blocks_per_ckpt:
|
||||
Size of each chunk. A higher value corresponds to fewer
|
||||
checkpoints, and trades memory for speed. If None, no checkpointing
|
||||
is performed.
|
||||
Returns:
|
||||
The output of the final block
|
||||
"""
|
||||
def wrap(a):
|
||||
return (a,) if type(a) is not tuple else a
|
||||
|
||||
def exec(b, a):
|
||||
for block in b:
|
||||
a = wrap(block(*a))
|
||||
return a
|
||||
|
||||
def chunker(s, e):
|
||||
def exec_sliced(*a):
|
||||
return exec(blocks[s:e], a)
|
||||
|
||||
return exec_sliced
|
||||
|
||||
# Avoids mishaps when the blocks take just one argument
|
||||
args = wrap(args)
|
||||
|
||||
if blocks_per_ckpt is None:
|
||||
return exec(blocks, args)
|
||||
elif blocks_per_ckpt < 1 or blocks_per_ckpt > len(blocks):
|
||||
raise ValueError("blocks_per_ckpt must be between 1 and len(blocks)")
|
||||
|
||||
checkpoint = get_checkpoint_fn()
|
||||
|
||||
for s in range(0, len(blocks), blocks_per_ckpt):
|
||||
e = s + blocks_per_ckpt
|
||||
args = checkpoint(chunker(s, e), *args)
|
||||
args = wrap(args)
|
||||
|
||||
return args
|
|
@ -1,78 +0,0 @@
|
|||
# Copyright 2021 AlQuraishi Laboratory
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from functools import partialmethod
|
||||
from typing import Union, List
|
||||
|
||||
|
||||
class Dropout(nn.Module):
|
||||
"""
|
||||
Implementation of dropout with the ability to share the dropout mask
|
||||
along a particular dimension.
|
||||
|
||||
If not in training mode, this module computes the identity function.
|
||||
"""
|
||||
|
||||
def __init__(self, r: float, batch_dim: Union[int, List[int]]):
|
||||
"""
|
||||
Args:
|
||||
r:
|
||||
Dropout rate
|
||||
batch_dim:
|
||||
Dimension(s) along which the dropout mask is shared
|
||||
"""
|
||||
super(Dropout, self).__init__()
|
||||
|
||||
self.r = r
|
||||
if type(batch_dim) == int:
|
||||
batch_dim = [batch_dim]
|
||||
self.batch_dim = batch_dim
|
||||
self.dropout = nn.Dropout(self.r)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
x:
|
||||
Tensor to which dropout is applied. Can have any shape
|
||||
compatible with self.batch_dim
|
||||
"""
|
||||
shape = list(x.shape)
|
||||
if self.batch_dim is not None:
|
||||
for bd in self.batch_dim:
|
||||
shape[bd] = 1
|
||||
mask = x.new_ones(shape)
|
||||
mask = self.dropout(mask)
|
||||
x *= mask
|
||||
return x
|
||||
|
||||
|
||||
class DropoutRowwise(Dropout):
|
||||
"""
|
||||
Convenience class for rowwise dropout as described in subsection
|
||||
1.11.6.
|
||||
"""
|
||||
|
||||
__init__ = partialmethod(Dropout.__init__, batch_dim=-3)
|
||||
|
||||
|
||||
class DropoutColumnwise(Dropout):
|
||||
"""
|
||||
Convenience class for columnwise dropout as described in subsection
|
||||
1.11.6.
|
||||
"""
|
||||
|
||||
__init__ = partialmethod(Dropout.__init__, batch_dim=-2)
|
|
@ -1,431 +0,0 @@
|
|||
# Copyright 2021 AlQuraishi Laboratory
|
||||
# Copyright 2021 DeepMind Technologies Limited
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from typing import Tuple, Optional
|
||||
from functools import partial
|
||||
|
||||
from .primitives import Linear, LayerNorm
|
||||
from .dropout import DropoutRowwise, DropoutColumnwise
|
||||
from .msa import (
|
||||
MSARowAttentionWithPairBias,
|
||||
MSAColumnAttention,
|
||||
MSAColumnGlobalAttention,
|
||||
)
|
||||
from .outer_product_mean import OuterProductMean
|
||||
from .pair_transition import PairTransition
|
||||
from .triangular_attention import (
|
||||
TriangleAttentionStartingNode,
|
||||
TriangleAttentionEndingNode,
|
||||
)
|
||||
from .triangular_multiplicative_update import (
|
||||
TriangleMultiplicationOutgoing,
|
||||
TriangleMultiplicationIncoming,
|
||||
)
|
||||
from .checkpointing import checkpoint_blocks, get_checkpoint_fn
|
||||
from .tensor_utils import chunk_layer
|
||||
|
||||
|
||||
class MSATransition(nn.Module):
|
||||
"""
|
||||
Feed-forward network applied to MSA activations after attention.
|
||||
|
||||
Implements Algorithm 9
|
||||
"""
|
||||
def __init__(self, c_m, n):
|
||||
"""
|
||||
Args:
|
||||
c_m:
|
||||
MSA channel dimension
|
||||
n:
|
||||
Factor multiplied to c_m to obtain the hidden channel
|
||||
dimension
|
||||
"""
|
||||
super(MSATransition, self).__init__()
|
||||
|
||||
self.c_m = c_m
|
||||
self.n = n
|
||||
|
||||
self.layer_norm = LayerNorm(self.c_m)
|
||||
self.linear_1 = Linear(self.c_m, self.n * self.c_m, init="relu")
|
||||
self.relu = nn.ReLU()
|
||||
self.linear_2 = Linear(self.n * self.c_m, self.c_m, init="final")
|
||||
|
||||
def _transition(self, m, mask):
|
||||
m = self.linear_1(m)
|
||||
m = self.relu(m)
|
||||
m = self.linear_2(m) * mask
|
||||
return m
|
||||
|
||||
@torch.jit.ignore
|
||||
def _chunk(self,
|
||||
m: torch.Tensor,
|
||||
mask: torch.Tensor,
|
||||
chunk_size: int,
|
||||
) -> torch.Tensor:
|
||||
return chunk_layer(
|
||||
self._transition,
|
||||
{"m": m, "mask": mask},
|
||||
chunk_size=chunk_size,
|
||||
no_batch_dims=len(m.shape[:-2]),
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
m: torch.Tensor,
|
||||
mask: Optional[torch.Tensor] = None,
|
||||
chunk_size: Optional[int] = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
m:
|
||||
[*, N_seq, N_res, C_m] MSA activation
|
||||
mask:
|
||||
[*, N_seq, N_res, C_m] MSA mask
|
||||
Returns:
|
||||
m:
|
||||
[*, N_seq, N_res, C_m] MSA activation update
|
||||
"""
|
||||
|
||||
# DISCREPANCY: DeepMind forgets to apply the MSA mask here.
|
||||
if mask is None:
|
||||
mask = m.new_ones(m.shape[:-1])
|
||||
|
||||
# [*, N_seq, N_res, 1]
|
||||
mask = mask.unsqueeze(-1)
|
||||
|
||||
m = self.layer_norm(m)
|
||||
|
||||
if chunk_size is not None:
|
||||
m = self._chunk(m, mask, chunk_size)
|
||||
else:
|
||||
m = self._transition(m, mask)
|
||||
|
||||
return m
|
||||
|
||||
|
||||
class EvoformerBlockCore(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
c_m: int,
|
||||
c_z: int,
|
||||
c_hidden_opm: int,
|
||||
c_hidden_mul: int,
|
||||
c_hidden_pair_att: int,
|
||||
no_heads_msa: int,
|
||||
no_heads_pair: int,
|
||||
transition_n: int,
|
||||
pair_dropout: float,
|
||||
inf: float,
|
||||
eps: float,
|
||||
_is_extra_msa_stack: bool = False,
|
||||
is_multimer: bool = False,
|
||||
):
|
||||
super(EvoformerBlockCore, self).__init__()
|
||||
self.is_multimer = is_multimer
|
||||
self.msa_transition = MSATransition(
|
||||
c_m=c_m,
|
||||
n=transition_n,
|
||||
)
|
||||
|
||||
self.outer_product_mean = OuterProductMean(
|
||||
c_m,
|
||||
c_z,
|
||||
c_hidden_opm,
|
||||
)
|
||||
|
||||
self.tri_mul_out = TriangleMultiplicationOutgoing(
|
||||
c_z,
|
||||
c_hidden_mul,
|
||||
)
|
||||
self.tri_mul_in = TriangleMultiplicationIncoming(
|
||||
c_z,
|
||||
c_hidden_mul,
|
||||
)
|
||||
|
||||
self.tri_att_start = TriangleAttentionStartingNode(
|
||||
c_z,
|
||||
c_hidden_pair_att,
|
||||
no_heads_pair,
|
||||
inf=inf,
|
||||
)
|
||||
self.tri_att_end = TriangleAttentionEndingNode(
|
||||
c_z,
|
||||
c_hidden_pair_att,
|
||||
no_heads_pair,
|
||||
inf=inf,
|
||||
)
|
||||
|
||||
self.pair_transition = PairTransition(
|
||||
c_z,
|
||||
transition_n,
|
||||
)
|
||||
|
||||
self.ps_dropout_row_layer = DropoutRowwise(pair_dropout)
|
||||
self.ps_dropout_col_layer = DropoutColumnwise(pair_dropout)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
m: torch.Tensor,
|
||||
z: torch.Tensor,
|
||||
chunk_size: Optional[int] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# DeepMind doesn't mask these transitions in the source, so _mask_trans
|
||||
# should be disabled to better approximate the exact activations of
|
||||
# the original.
|
||||
|
||||
m = m + self.msa_transition(
|
||||
m, chunk_size=chunk_size
|
||||
)
|
||||
z = z + self.outer_product_mean(
|
||||
m, chunk_size=chunk_size
|
||||
)
|
||||
z = z + self.ps_dropout_row_layer(self.tri_mul_out(z))
|
||||
z = z + self.ps_dropout_row_layer(self.tri_mul_in(z))
|
||||
z = z + self.ps_dropout_row_layer(
|
||||
self.tri_att_start(z, chunk_size=chunk_size)
|
||||
)
|
||||
z = z + self.ps_dropout_col_layer(
|
||||
self.tri_att_end(z, chunk_size=chunk_size)
|
||||
)
|
||||
z = z + self.pair_transition(
|
||||
z, chunk_size=chunk_size
|
||||
)
|
||||
|
||||
return m, z
|
||||
|
||||
|
||||
class EvoformerBlock(nn.Module):
|
||||
def __init__(self,
|
||||
c_m: int,
|
||||
c_z: int,
|
||||
c_hidden_msa_att: int,
|
||||
c_hidden_opm: int,
|
||||
c_hidden_mul: int,
|
||||
c_hidden_pair_att: int,
|
||||
no_heads_msa: int,
|
||||
no_heads_pair: int,
|
||||
transition_n: int,
|
||||
msa_dropout: float,
|
||||
pair_dropout: float,
|
||||
inf: float,
|
||||
eps: float,
|
||||
is_multimer: bool,
|
||||
):
|
||||
super(EvoformerBlock, self).__init__()
|
||||
|
||||
self.msa_att_row = MSARowAttentionWithPairBias(
|
||||
c_m=c_m,
|
||||
c_z=c_z,
|
||||
c_hidden=c_hidden_msa_att,
|
||||
no_heads=no_heads_msa,
|
||||
inf=inf,
|
||||
)
|
||||
|
||||
self.msa_att_col = MSAColumnAttention(
|
||||
c_m,
|
||||
c_hidden_msa_att,
|
||||
no_heads_msa,
|
||||
inf=inf,
|
||||
)
|
||||
|
||||
self.msa_dropout_layer = DropoutRowwise(msa_dropout)
|
||||
|
||||
self.core = EvoformerBlockCore(
|
||||
c_m=c_m,
|
||||
c_z=c_z,
|
||||
c_hidden_opm=c_hidden_opm,
|
||||
c_hidden_mul=c_hidden_mul,
|
||||
c_hidden_pair_att=c_hidden_pair_att,
|
||||
no_heads_msa=no_heads_msa,
|
||||
no_heads_pair=no_heads_pair,
|
||||
transition_n=transition_n,
|
||||
pair_dropout=pair_dropout,
|
||||
inf=inf,
|
||||
eps=eps,
|
||||
)
|
||||
|
||||
self.outer_product_mean = OuterProductMean(
|
||||
c_m,
|
||||
c_z,
|
||||
c_hidden_opm,
|
||||
)
|
||||
self.is_multimer = is_multimer
|
||||
|
||||
def forward(self,
|
||||
m: torch.Tensor,
|
||||
z: torch.Tensor,
|
||||
chunk_size: Optional[int] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
m = m + self.msa_dropout_layer(
|
||||
self.msa_att_row(m, z=z, chunk_size=chunk_size)
|
||||
)
|
||||
m = m + self.msa_att_col(m, chunk_size=chunk_size)
|
||||
m, z = self.core(
|
||||
m,
|
||||
z,
|
||||
chunk_size=chunk_size,
|
||||
)
|
||||
|
||||
return m, z
|
||||
|
||||
|
||||
class EvoformerStack(nn.Module):
|
||||
"""
|
||||
Main Evoformer trunk.
|
||||
|
||||
Implements Algorithm 6.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
c_m: int,
|
||||
c_z: int,
|
||||
c_hidden_msa_att: int,
|
||||
c_hidden_opm: int,
|
||||
c_hidden_mul: int,
|
||||
c_hidden_pair_att: int,
|
||||
c_s: int,
|
||||
no_heads_msa: int,
|
||||
no_heads_pair: int,
|
||||
no_blocks: int,
|
||||
transition_n: int,
|
||||
msa_dropout: float,
|
||||
pair_dropout: float,
|
||||
blocks_per_ckpt: int,
|
||||
inf: float,
|
||||
eps: float,
|
||||
clear_cache_between_blocks: bool = False,
|
||||
is_multimer: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
c_m:
|
||||
MSA channel dimension
|
||||
c_z:
|
||||
Pair channel dimension
|
||||
c_hidden_msa_att:
|
||||
Hidden dimension in MSA attention
|
||||
c_hidden_opm:
|
||||
Hidden dimension in outer product mean module
|
||||
c_hidden_mul:
|
||||
Hidden dimension in multiplicative updates
|
||||
c_hidden_pair_att:
|
||||
Hidden dimension in triangular attention
|
||||
c_s:
|
||||
Channel dimension of the output "single" embedding
|
||||
no_heads_msa:
|
||||
Number of heads used for MSA attention
|
||||
no_heads_pair:
|
||||
Number of heads used for pair attention
|
||||
no_blocks:
|
||||
Number of Evoformer blocks in the stack
|
||||
transition_n:
|
||||
Factor by which to multiply c_m to obtain the MSATransition
|
||||
hidden dimension
|
||||
msa_dropout:
|
||||
Dropout rate for MSA activations
|
||||
pair_dropout:
|
||||
Dropout used for pair activations
|
||||
blocks_per_ckpt:
|
||||
Number of Evoformer blocks in each activation checkpoint
|
||||
clear_cache_between_blocks:
|
||||
Whether to clear CUDA's GPU memory cache between blocks of the
|
||||
stack. Slows down each block but can reduce fragmentation
|
||||
"""
|
||||
super(EvoformerStack, self).__init__()
|
||||
|
||||
self.blocks_per_ckpt = blocks_per_ckpt
|
||||
self.clear_cache_between_blocks = clear_cache_between_blocks
|
||||
|
||||
self.blocks = nn.ModuleList()
|
||||
|
||||
for _ in range(no_blocks):
|
||||
block = EvoformerBlock(
|
||||
c_m=c_m,
|
||||
c_z=c_z,
|
||||
c_hidden_msa_att=c_hidden_msa_att,
|
||||
c_hidden_opm=c_hidden_opm,
|
||||
c_hidden_mul=c_hidden_mul,
|
||||
c_hidden_pair_att=c_hidden_pair_att,
|
||||
no_heads_msa=no_heads_msa,
|
||||
no_heads_pair=no_heads_pair,
|
||||
transition_n=transition_n,
|
||||
msa_dropout=msa_dropout,
|
||||
pair_dropout=pair_dropout,
|
||||
inf=inf,
|
||||
eps=eps,
|
||||
is_multimer=is_multimer,
|
||||
)
|
||||
self.blocks.append(block)
|
||||
|
||||
self.linear = Linear(c_m, c_s)
|
||||
|
||||
def forward(self,
|
||||
m: torch.Tensor,
|
||||
z: torch.Tensor,
|
||||
msa_mask: torch.Tensor,
|
||||
pair_mask: torch.Tensor,
|
||||
chunk_size: int,
|
||||
_mask_trans: bool = True,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
|
||||
"""
|
||||
Args:
|
||||
m:
|
||||
[*, N_seq, N_res, C_m] MSA embedding
|
||||
z:
|
||||
[*, N_res, N_res, C_z] pair embedding
|
||||
msa_mask:
|
||||
[*, N_seq, N_res] MSA mask
|
||||
pair_mask:
|
||||
[*, N_res, N_res] pair mask
|
||||
Returns:
|
||||
m:
|
||||
[*, N_seq, N_res, C_m] MSA embedding
|
||||
z:
|
||||
[*, N_res, N_res, C_z] pair embedding
|
||||
s:
|
||||
[*, N_res, C_s] single embedding (or None if extra MSA stack)
|
||||
"""
|
||||
blocks = [
|
||||
partial(
|
||||
b,
|
||||
msa_mask=msa_mask,
|
||||
pair_mask=pair_mask,
|
||||
chunk_size=chunk_size,
|
||||
_mask_trans=_mask_trans,
|
||||
)
|
||||
for b in self.blocks
|
||||
]
|
||||
|
||||
if(self.clear_cache_between_blocks):
|
||||
def block_with_cache_clear(block, *args):
|
||||
torch.cuda.empty_cache()
|
||||
return block(*args)
|
||||
|
||||
blocks = [partial(block_with_cache_clear, b) for b in blocks]
|
||||
|
||||
m, z = checkpoint_blocks(
|
||||
blocks,
|
||||
args=(m, z),
|
||||
blocks_per_ckpt=self.blocks_per_ckpt if self.training else None,
|
||||
)
|
||||
|
||||
s = self.linear(m[..., 0, :, :])
|
||||
|
||||
return m, z, s
|
|
@ -1,331 +0,0 @@
|
|||
# Copyright 2021 AlQuraishi Laboratory
|
||||
# Copyright 2021 DeepMind Technologies Limited
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from typing import Optional, List, Tuple
|
||||
|
||||
from .primitives import (
|
||||
Linear,
|
||||
LayerNorm,
|
||||
Attention,
|
||||
GlobalAttention,
|
||||
_attention_chunked_trainable,
|
||||
)
|
||||
from .checkpointing import get_checkpoint_fn
|
||||
from .tensor_utils import (
|
||||
chunk_layer,
|
||||
permute_final_dims,
|
||||
flatten_final_dims,
|
||||
)
|
||||
|
||||
|
||||
class MSAAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
c_in,
|
||||
c_hidden,
|
||||
no_heads,
|
||||
pair_bias=False,
|
||||
c_z=None,
|
||||
inf=1e9,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
c_in:
|
||||
Input channel dimension
|
||||
c_hidden:
|
||||
Per-head hidden channel dimension
|
||||
no_heads:
|
||||
Number of attention heads
|
||||
pair_bias:
|
||||
Whether to use pair embedding bias
|
||||
c_z:
|
||||
Pair embedding channel dimension. Ignored unless pair_bias
|
||||
is true
|
||||
inf:
|
||||
A large number to be used in computing the attention mask
|
||||
"""
|
||||
super(MSAAttention, self).__init__()
|
||||
|
||||
self.c_in = c_in
|
||||
self.c_hidden = c_hidden
|
||||
self.no_heads = no_heads
|
||||
self.pair_bias = pair_bias
|
||||
self.c_z = c_z
|
||||
self.inf = inf
|
||||
|
||||
self.layer_norm_m = LayerNorm(self.c_in)
|
||||
|
||||
self.layer_norm_z = None
|
||||
self.linear_z = None
|
||||
if self.pair_bias:
|
||||
self.layer_norm_z = LayerNorm(self.c_z)
|
||||
self.linear_z = Linear(
|
||||
self.c_z, self.no_heads, bias=False, init="normal"
|
||||
)
|
||||
|
||||
self.mha = Attention(
|
||||
self.c_in, self.c_in, self.c_in, self.c_hidden, self.no_heads
|
||||
)
|
||||
|
||||
@torch.jit.ignore
|
||||
def _chunk(self,
|
||||
m: torch.Tensor,
|
||||
biases: List[torch.Tensor],
|
||||
chunk_size: int,
|
||||
) -> torch.Tensor:
|
||||
return chunk_layer(
|
||||
self.mha,
|
||||
{"q_x": m, "kv_x": m, "biases": biases},
|
||||
chunk_size=chunk_size,
|
||||
no_batch_dims=len(m.shape[:-2]),
|
||||
)
|
||||
|
||||
def _prep_inputs(self,
|
||||
m: torch.Tensor,
|
||||
z: Optional[torch.Tensor],
|
||||
mask: Optional[torch.Tensor]
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
# [*, N_seq, N_res, C_m]
|
||||
m = self.layer_norm_m(m)
|
||||
|
||||
n_seq, n_res = m.shape[-3:-1]
|
||||
if mask is None:
|
||||
# [*, N_seq, N_res]
|
||||
mask = m.new_ones(
|
||||
m.shape[:-3] + (n_seq, n_res),
|
||||
)
|
||||
|
||||
# [*, N_seq, 1, 1, N_res]
|
||||
mask_bias = (self.inf * (mask - 1))[..., :, None, None, :]
|
||||
|
||||
# This step simply returns a larger view of the bias, and does not
|
||||
# consume additional memory.
|
||||
# [*, N_seq, no_heads, N_res, N_res]
|
||||
#bias = bias.expand(
|
||||
# ((-1,) * len(bias.shape[:-4])) + (-1, self.no_heads, n_res, -1)
|
||||
#)
|
||||
|
||||
if (self.pair_bias and
|
||||
z is not None and # For the
|
||||
self.layer_norm_z is not None and # benefit of
|
||||
self.linear_z is not None # TorchScript
|
||||
):
|
||||
# [*, N_res, N_res, C_z]
|
||||
z = self.layer_norm_z(z)
|
||||
|
||||
# [*, N_res, N_res, no_heads]
|
||||
z = self.linear_z(z)
|
||||
|
||||
# [*, 1, no_heads, N_res, N_res]
|
||||
z = permute_final_dims(z, (2, 0, 1)).unsqueeze(-4)
|
||||
|
||||
return m, mask_bias, z
|
||||
|
||||
|
||||
def forward(self,
|
||||
m: torch.Tensor,
|
||||
z: Optional[torch.Tensor] = None,
|
||||
mask: Optional[torch.Tensor] = None,
|
||||
chunk_size: Optional[int] = None,
|
||||
_chunk_logits: Optional[int] = None,
|
||||
_checkpoint_chunks: Optional[bool] = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
m:
|
||||
[*, N_seq, N_res, C_m] MSA embedding
|
||||
z:
|
||||
[*, N_res, N_res, C_z] pair embedding. Required only if
|
||||
pair_bias is True
|
||||
mask:
|
||||
[*, N_seq, N_res] MSA mask
|
||||
chunk_size:
|
||||
Size of chunks into which the inputs are split along their
|
||||
batch dimensions. A low value decreases memory overhead at the
|
||||
cost of slower execution. Chunking is not performed by default.
|
||||
|
||||
"""
|
||||
m, mask_bias, z = self._prep_inputs(m, z, mask)
|
||||
|
||||
biases = [mask_bias]
|
||||
if(z is not None):
|
||||
biases.append(z)
|
||||
|
||||
if chunk_size is not None:
|
||||
m = self._chunk(m, biases, chunk_size)
|
||||
else:
|
||||
m = self.mha(
|
||||
q_x=m,
|
||||
kv_x=m,
|
||||
biases=biases
|
||||
)
|
||||
|
||||
return m
|
||||
|
||||
|
||||
class MSARowAttentionWithPairBias(MSAAttention):
|
||||
"""
|
||||
Implements Algorithm 7.
|
||||
"""
|
||||
|
||||
def __init__(self, c_m, c_z, c_hidden, no_heads, inf=1e9):
|
||||
"""
|
||||
Args:
|
||||
c_m:
|
||||
Input channel dimension
|
||||
c_z:
|
||||
Pair embedding channel dimension
|
||||
c_hidden:
|
||||
Per-head hidden channel dimension
|
||||
no_heads:
|
||||
Number of attention heads
|
||||
inf:
|
||||
Large number used to construct attention masks
|
||||
"""
|
||||
super(MSARowAttentionWithPairBias, self).__init__(
|
||||
c_m,
|
||||
c_hidden,
|
||||
no_heads,
|
||||
pair_bias=True,
|
||||
c_z=c_z,
|
||||
inf=inf,
|
||||
)
|
||||
|
||||
|
||||
class MSAColumnAttention(nn.Module):
|
||||
"""
|
||||
Implements Algorithm 8.
|
||||
|
||||
By rights, this should also be a subclass of MSAAttention. Alas,
|
||||
most inheritance isn't supported by TorchScript.
|
||||
"""
|
||||
|
||||
def __init__(self, c_m, c_hidden, no_heads, inf=1e9):
|
||||
"""
|
||||
Args:
|
||||
c_m:
|
||||
MSA channel dimension
|
||||
c_hidden:
|
||||
Per-head hidden channel dimension
|
||||
no_heads:
|
||||
Number of attention heads
|
||||
inf:
|
||||
Large number used to construct attention masks
|
||||
"""
|
||||
super(MSAColumnAttention, self).__init__()
|
||||
|
||||
self.c_m = c_m
|
||||
self.c_hidden = c_hidden
|
||||
self.no_heads = no_heads
|
||||
self.inf = inf
|
||||
|
||||
self._msa_att = MSAAttention(
|
||||
c_in=c_m,
|
||||
c_hidden=c_hidden,
|
||||
no_heads=no_heads,
|
||||
pair_bias=False,
|
||||
c_z=None,
|
||||
inf=inf,
|
||||
)
|
||||
|
||||
def forward(self,
|
||||
m: torch.Tensor,
|
||||
mask: Optional[torch.Tensor] = None,
|
||||
chunk_size: Optional[int] = None
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
m:
|
||||
[*, N_seq, N_res, C_m] MSA embedding
|
||||
mask:
|
||||
[*, N_seq, N_res] MSA mask
|
||||
chunk_size:
|
||||
Size of chunks into which the inputs are split along their
|
||||
batch dimensions. A low value decreases memory overhead at the
|
||||
cost of slower execution. Chunking is not performed by default.
|
||||
"""
|
||||
# [*, N_res, N_seq, C_in]
|
||||
m = m.transpose(-2, -3)
|
||||
|
||||
m = self._msa_att(m, chunk_size=chunk_size)
|
||||
|
||||
# [*, N_seq, N_res, C_in]
|
||||
m = m.transpose(-2, -3)
|
||||
|
||||
return m
|
||||
|
||||
|
||||
class MSAColumnGlobalAttention(nn.Module):
|
||||
def __init__(
|
||||
self, c_in, c_hidden, no_heads, inf=1e9, eps=1e-10,
|
||||
):
|
||||
super(MSAColumnGlobalAttention, self).__init__()
|
||||
|
||||
self.c_in = c_in
|
||||
self.c_hidden = c_hidden
|
||||
self.no_heads = no_heads
|
||||
self.inf = inf
|
||||
self.eps = eps
|
||||
|
||||
self.layer_norm_m = nn.LayerNorm(c_in)
|
||||
|
||||
self.global_attention = GlobalAttention(
|
||||
c_in=c_in,
|
||||
c_hidden=c_hidden,
|
||||
no_heads=no_heads,
|
||||
inf=inf,
|
||||
eps=eps,
|
||||
)
|
||||
|
||||
@torch.jit.ignore
|
||||
def _chunk(self,
|
||||
m: torch.Tensor,
|
||||
chunk_size: int,
|
||||
) -> torch.Tensor:
|
||||
mha_input = {
|
||||
"m": m,
|
||||
}
|
||||
return chunk_layer(
|
||||
self.global_attention,
|
||||
mha_input,
|
||||
chunk_size=chunk_size,
|
||||
no_batch_dims=len(m.shape[:-2]),
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
m: torch.Tensor,
|
||||
chunk_size: Optional[int] = None,
|
||||
) -> torch.Tensor:
|
||||
n_seq, n_res, c_in = m.shape[-3:]
|
||||
|
||||
# [*, N_res, N_seq, C_in]
|
||||
m = m.transpose(-2, -3)
|
||||
|
||||
# [*, N_res, N_seq, C_in]
|
||||
m = self.layer_norm_m(m)
|
||||
|
||||
if chunk_size is not None:
|
||||
m = self._chunk(m, chunk_size)
|
||||
else:
|
||||
m = self.global_attention(m=m)
|
||||
|
||||
# [*, N_seq, N_res, C_in]
|
||||
m = m.transpose(-2, -3)
|
||||
|
||||
return m
|
|
@ -1,129 +0,0 @@
|
|||
# Copyright 2021 AlQuraishi Laboratory
|
||||
# Copyright 2021 DeepMind Technologies Limited
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from functools import partial
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from .primitives import Linear
|
||||
from .tensor_utils import chunk_layer
|
||||
|
||||
|
||||
class OuterProductMean(nn.Module):
|
||||
"""
|
||||
Implements Algorithm 10.
|
||||
"""
|
||||
|
||||
def __init__(self, c_m, c_z, c_hidden, eps=1e-3):
|
||||
"""
|
||||
Args:
|
||||
c_m:
|
||||
MSA embedding channel dimension
|
||||
c_z:
|
||||
Pair embedding channel dimension
|
||||
c_hidden:
|
||||
Hidden channel dimension
|
||||
"""
|
||||
super(OuterProductMean, self).__init__()
|
||||
|
||||
self.c_m = c_m
|
||||
self.c_z = c_z
|
||||
self.c_hidden = c_hidden
|
||||
self.eps = eps
|
||||
|
||||
self.layer_norm = nn.LayerNorm(c_m)
|
||||
self.linear_1 = Linear(c_m, c_hidden)
|
||||
self.linear_2 = Linear(c_m, c_hidden)
|
||||
self.linear_out = Linear(c_hidden ** 2, c_z, init="final")
|
||||
|
||||
def _opm(self, a, b):
|
||||
# [*, N_res, N_res, C, C]
|
||||
outer = torch.einsum("...bac,...dae->...bdce", a, b)
|
||||
|
||||
# [*, N_res, N_res, C * C]
|
||||
outer = outer.reshape(outer.shape[:-2] + (-1,))
|
||||
|
||||
# [*, N_res, N_res, C_z]
|
||||
outer = self.linear_out(outer)
|
||||
|
||||
return outer
|
||||
|
||||
@torch.jit.ignore
|
||||
def _chunk(self,
|
||||
a: torch.Tensor,
|
||||
b: torch.Tensor,
|
||||
chunk_size: int
|
||||
) -> torch.Tensor:
|
||||
# Since the "batch dim" in this case is not a true batch dimension
|
||||
# (in that the shape of the output depends on it), we need to
|
||||
# iterate over it ourselves
|
||||
a_reshape = a.reshape((-1,) + a.shape[-3:])
|
||||
b_reshape = b.reshape((-1,) + b.shape[-3:])
|
||||
out = []
|
||||
for a_prime, b_prime in zip(a_reshape, b_reshape):
|
||||
outer = chunk_layer(
|
||||
partial(self._opm, b=b_prime),
|
||||
{"a": a_prime},
|
||||
chunk_size=chunk_size,
|
||||
no_batch_dims=1,
|
||||
)
|
||||
out.append(outer)
|
||||
outer = torch.stack(out, dim=0)
|
||||
outer = outer.reshape(a.shape[:-3] + outer.shape[1:])
|
||||
|
||||
return outer
|
||||
|
||||
def forward(self,
|
||||
m: torch.Tensor,
|
||||
mask: Optional[torch.Tensor] = None,
|
||||
chunk_size: Optional[int] = None
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
m:
|
||||
[*, N_seq, N_res, C_m] MSA embedding
|
||||
mask:
|
||||
[*, N_seq, N_res] MSA mask
|
||||
Returns:
|
||||
[*, N_res, N_res, C_z] pair embedding update
|
||||
"""
|
||||
if mask is None:
|
||||
mask = m.new_ones(m.shape[:-1])
|
||||
|
||||
# [*, N_seq, N_res, C_m]
|
||||
m = self.layer_norm(m)
|
||||
|
||||
# [*, N_seq, N_res, C]
|
||||
mask = mask.unsqueeze(-1)
|
||||
a = self.linear_1(m) * mask
|
||||
b = self.linear_2(m) * mask
|
||||
|
||||
a = a.transpose(-2, -3)
|
||||
b = b.transpose(-2, -3)
|
||||
|
||||
if chunk_size is not None:
|
||||
outer = self._chunk(a, b, chunk_size)
|
||||
else:
|
||||
outer = self._opm(a, b)
|
||||
|
||||
# [*, N_res, N_res, 1]
|
||||
norm = torch.einsum("...abc,...adc->...bdc", mask, mask)
|
||||
|
||||
# [*, N_res, N_res, C_z]
|
||||
outer = outer / (self.eps + norm)
|
||||
|
||||
return outer
|
|
@ -1,99 +0,0 @@
|
|||
# Copyright 2021 AlQuraishi Laboratory
|
||||
# Copyright 2021 DeepMind Technologies Limited
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from .primitives import Linear, LayerNorm
|
||||
from .tensor_utils import chunk_layer
|
||||
|
||||
|
||||
class PairTransition(nn.Module):
|
||||
"""
|
||||
Implements Algorithm 15.
|
||||
"""
|
||||
|
||||
def __init__(self, c_z, n):
|
||||
"""
|
||||
Args:
|
||||
c_z:
|
||||
Pair transition channel dimension
|
||||
n:
|
||||
Factor by which c_z is multiplied to obtain hidden channel
|
||||
dimension
|
||||
"""
|
||||
super(PairTransition, self).__init__()
|
||||
|
||||
self.c_z = c_z
|
||||
self.n = n
|
||||
|
||||
self.layer_norm = LayerNorm(self.c_z)
|
||||
self.linear_1 = Linear(self.c_z, self.n * self.c_z, init="relu")
|
||||
self.relu = nn.ReLU()
|
||||
self.linear_2 = Linear(self.n * self.c_z, c_z, init="final")
|
||||
|
||||
def _transition(self, z, mask):
|
||||
# [*, N_res, N_res, C_hidden]
|
||||
z = self.linear_1(z)
|
||||
z = self.relu(z)
|
||||
|
||||
# [*, N_res, N_res, C_z]
|
||||
z = self.linear_2(z) * mask
|
||||
|
||||
return z
|
||||
|
||||
@torch.jit.ignore
|
||||
def _chunk(self,
|
||||
z: torch.Tensor,
|
||||
mask: torch.Tensor,
|
||||
chunk_size: int,
|
||||
) -> torch.Tensor:
|
||||
return chunk_layer(
|
||||
self._transition,
|
||||
{"z": z, "mask": mask},
|
||||
chunk_size=chunk_size,
|
||||
no_batch_dims=len(z.shape[:-2]),
|
||||
)
|
||||
|
||||
|
||||
def forward(self,
|
||||
z: torch.Tensor,
|
||||
mask: Optional[torch.Tensor] = None,
|
||||
chunk_size: Optional[int] = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
z:
|
||||
[*, N_res, N_res, C_z] pair embedding
|
||||
Returns:
|
||||
[*, N_res, N_res, C_z] pair embedding update
|
||||
"""
|
||||
# DISCREPANCY: DeepMind forgets to apply the mask in this module.
|
||||
if mask is None:
|
||||
mask = z.new_ones(z.shape[:-1])
|
||||
|
||||
# [*, N_res, N_res, 1]
|
||||
mask = mask.unsqueeze(-1)
|
||||
|
||||
# [*, N_res, N_res, C_z]
|
||||
z = self.layer_norm(z)
|
||||
|
||||
if chunk_size is not None:
|
||||
z = self._chunk(z, mask, chunk_size)
|
||||
else:
|
||||
z = self._transition(z=z, mask=mask)
|
||||
|
||||
return z
|
|
@ -1,529 +0,0 @@
|
|||
# Copyright 2021 AlQuraishi Laboratory
|
||||
# Copyright 2021 DeepMind Technologies Limited
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from functools import partial
|
||||
import math
|
||||
from typing import Optional, Callable, List, Tuple, Sequence
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from .checkpointing import get_checkpoint_fn
|
||||
from .tensor_utils import (
|
||||
permute_final_dims,
|
||||
flatten_final_dims,
|
||||
_chunk_slice,
|
||||
)
|
||||
|
||||
|
||||
def _prod(nums):
|
||||
out = 1
|
||||
for n in nums:
|
||||
out = out * n
|
||||
return out
|
||||
|
||||
|
||||
def _calculate_fan(linear_weight_shape, fan="fan_in"):
|
||||
fan_out, fan_in = linear_weight_shape
|
||||
|
||||
if fan == "fan_in":
|
||||
f = fan_in
|
||||
elif fan == "fan_out":
|
||||
f = fan_out
|
||||
elif fan == "fan_avg":
|
||||
f = (fan_in + fan_out) / 2
|
||||
else:
|
||||
raise ValueError("Invalid fan option")
|
||||
|
||||
return f
|
||||
|
||||
|
||||
def glorot_uniform_init_(weights):
|
||||
nn.init.xavier_uniform_(weights, gain=1)
|
||||
|
||||
|
||||
def final_init_(weights):
|
||||
with torch.no_grad():
|
||||
weights.fill_(0.0)
|
||||
|
||||
|
||||
def gating_init_(weights):
|
||||
with torch.no_grad():
|
||||
weights.fill_(0.0)
|
||||
|
||||
|
||||
def normal_init_(weights):
|
||||
torch.nn.init.kaiming_normal_(weights, nonlinearity="linear")
|
||||
|
||||
|
||||
def ipa_point_weights_init_(weights):
|
||||
with torch.no_grad():
|
||||
softplus_inverse_1 = 0.541324854612918
|
||||
weights.fill_(softplus_inverse_1)
|
||||
|
||||
|
||||
class Linear(nn.Linear):
|
||||
"""
|
||||
A Linear layer with built-in nonstandard initializations. Called just
|
||||
like torch.nn.Linear.
|
||||
|
||||
Implements the initializers in 1.11.4, plus some additional ones found
|
||||
in the code.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_dim: int,
|
||||
out_dim: int,
|
||||
bias: bool = True,
|
||||
init: str = "default",
|
||||
init_fn: Optional[Callable[[torch.Tensor, torch.Tensor], None]] = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
in_dim:
|
||||
The final dimension of inputs to the layer
|
||||
out_dim:
|
||||
The final dimension of layer outputs
|
||||
bias:
|
||||
Whether to learn an additive bias. True by default
|
||||
init:
|
||||
The initializer to use. Choose from:
|
||||
|
||||
"default": LeCun fan-in truncated normal initialization
|
||||
"relu": He initialization w/ truncated normal distribution
|
||||
"glorot": Fan-average Glorot uniform initialization
|
||||
"gating": Weights=0, Bias=1
|
||||
"normal": Normal initialization with std=1/sqrt(fan_in)
|
||||
"final": Weights=0, Bias=0
|
||||
|
||||
Overridden by init_fn if the latter is not None.
|
||||
init_fn:
|
||||
A custom initializer taking weight and bias as inputs.
|
||||
Overrides init if not None.
|
||||
"""
|
||||
super(Linear, self).__init__(in_dim, out_dim, bias=bias)
|
||||
|
||||
if bias:
|
||||
with torch.no_grad():
|
||||
self.bias.fill_(0)
|
||||
|
||||
if init_fn is not None:
|
||||
init_fn(self.weight, self.bias)
|
||||
else:
|
||||
if init == "default":
|
||||
normal_init_(self.weight)
|
||||
elif init == "relu":
|
||||
normal_init_(self.weight)
|
||||
elif init == "glorot":
|
||||
glorot_uniform_init_(self.weight)
|
||||
elif init == "gating":
|
||||
gating_init_(self.weight)
|
||||
if bias:
|
||||
with torch.no_grad():
|
||||
self.bias.fill_(1.0)
|
||||
elif init == "normal":
|
||||
normal_init_(self.weight)
|
||||
elif init == "final":
|
||||
final_init_(self.weight)
|
||||
else:
|
||||
raise ValueError("Invalid init string.")
|
||||
|
||||
|
||||
class LayerNorm(nn.Module):
|
||||
|
||||
def __init__(self, c_in, eps=1e-5):
|
||||
super(LayerNorm, self).__init__()
|
||||
|
||||
self.c_in = (c_in,)
|
||||
self.eps = eps
|
||||
|
||||
self.weight = nn.Parameter(torch.ones(c_in))
|
||||
self.bias = nn.Parameter(torch.zeros(c_in))
|
||||
|
||||
def forward(self, x):
|
||||
out = nn.functional.layer_norm(
|
||||
x,
|
||||
self.c_in,
|
||||
self.weight,
|
||||
self.bias,
|
||||
self.eps,
|
||||
)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
@torch.jit.ignore
|
||||
def softmax(t: torch.Tensor, dim: int = -1) -> torch.Tensor:
|
||||
"""
|
||||
Softmax, but without automatic casting to fp32 when the input is of
|
||||
type bfloat16
|
||||
"""
|
||||
s = torch.nn.functional.softmax(t, dim=dim)
|
||||
|
||||
return s
|
||||
|
||||
|
||||
#@torch.jit.script
|
||||
def _attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
|
||||
biases: List[torch.Tensor]) -> torch.Tensor:
|
||||
# [*, H, Q, C_hidden]
|
||||
query = permute_final_dims(query, (1, 0, 2))
|
||||
|
||||
# [*, H, C_hidden, K]
|
||||
key = permute_final_dims(key, (1, 2, 0))
|
||||
|
||||
# [*, H, V, C_hidden]
|
||||
value = permute_final_dims(value, (1, 0, 2))
|
||||
|
||||
# [*, H, Q, K]
|
||||
a = torch.matmul(query, key)
|
||||
|
||||
for b in biases:
|
||||
a += b
|
||||
|
||||
a = softmax(a, -1)
|
||||
|
||||
# [*, H, Q, C_hidden]
|
||||
a = torch.matmul(a, value)
|
||||
|
||||
# [*, Q, H, C_hidden]
|
||||
a = a.transpose(-2, -3)
|
||||
|
||||
return a
|
||||
|
||||
|
||||
@torch.jit.ignore
|
||||
def _attention_chunked_trainable(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
biases,
|
||||
chunk_size,
|
||||
chunk_dim,
|
||||
checkpoint,
|
||||
):
|
||||
if (checkpoint and len(biases) > 2):
|
||||
raise ValueError("Checkpointed version permits only permits two bias terms")
|
||||
|
||||
def _checkpointable_attention(q, k, v, b1, b2):
|
||||
bs = [b for b in [b1, b2] if b is not None]
|
||||
return _attention(q, k, v, bs)
|
||||
|
||||
o_chunks = []
|
||||
checkpoint_fn = get_checkpoint_fn()
|
||||
count = query.shape[chunk_dim]
|
||||
for start in range(0, count, chunk_size):
|
||||
end = start + chunk_size
|
||||
idx = [slice(None)] * len(query.shape)
|
||||
idx[chunk_dim] = slice(start, end)
|
||||
idx_tup = tuple(idx)
|
||||
q_chunk = query[idx_tup]
|
||||
k_chunk = key[idx_tup]
|
||||
v_chunk = value[idx_tup]
|
||||
|
||||
def _slice_bias(b):
|
||||
idx[chunk_dim] = (slice(start, end) if b.shape[chunk_dim] != 1 else slice(None))
|
||||
return b[tuple(idx)]
|
||||
|
||||
if (checkpoint):
|
||||
bias_1_chunk, bias_2_chunk = [
|
||||
_slice_bias(b) if b is not None else None for b in (biases + [None, None])[:2]
|
||||
]
|
||||
|
||||
o_chunk = checkpoint_fn(_checkpointable_attention, q_chunk, k_chunk, v_chunk,
|
||||
bias_1_chunk, bias_2_chunk)
|
||||
else:
|
||||
bias_chunks = [_slice_bias(b) for b in biases]
|
||||
|
||||
o_chunk = _attention(q_chunk, k_chunk, v_chunk, bias_chunks)
|
||||
|
||||
o_chunks.append(o_chunk)
|
||||
|
||||
o = torch.cat(o_chunks, dim=chunk_dim)
|
||||
return o
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
"""
|
||||
Standard multi-head attention using AlphaFold's default layer
|
||||
initialization. Allows multiple bias vectors.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
c_q: int,
|
||||
c_k: int,
|
||||
c_v: int,
|
||||
c_hidden: int,
|
||||
no_heads: int,
|
||||
gating: bool = True,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
c_q:
|
||||
Input dimension of query data
|
||||
c_k:
|
||||
Input dimension of key data
|
||||
c_v:
|
||||
Input dimension of value data
|
||||
c_hidden:
|
||||
Per-head hidden dimension
|
||||
no_heads:
|
||||
Number of attention heads
|
||||
gating:
|
||||
Whether the output should be gated using query data
|
||||
"""
|
||||
super(Attention, self).__init__()
|
||||
|
||||
self.c_q = c_q
|
||||
self.c_k = c_k
|
||||
self.c_v = c_v
|
||||
self.c_hidden = c_hidden
|
||||
self.no_heads = no_heads
|
||||
self.gating = gating
|
||||
|
||||
# DISCREPANCY: c_hidden is not the per-head channel dimension, as
|
||||
# stated in the supplement, but the overall channel dimension.
|
||||
|
||||
self.linear_q = Linear(self.c_q, self.c_hidden * self.no_heads, bias=False, init="glorot")
|
||||
self.linear_k = Linear(self.c_k, self.c_hidden * self.no_heads, bias=False, init="glorot")
|
||||
self.linear_v = Linear(self.c_v, self.c_hidden * self.no_heads, bias=False, init="glorot")
|
||||
self.linear_o = Linear(self.c_hidden * self.no_heads, self.c_q, init="final")
|
||||
|
||||
self.linear_g = None
|
||||
if self.gating:
|
||||
self.linear_g = Linear(self.c_q, self.c_hidden * self.no_heads, init="gating")
|
||||
|
||||
self.sigmoid = nn.Sigmoid()
|
||||
|
||||
def _prep_qkv(self, q_x: torch.Tensor,
|
||||
kv_x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
# [*, Q/K/V, H * C_hidden]
|
||||
q = self.linear_q(q_x)
|
||||
k = self.linear_k(kv_x)
|
||||
v = self.linear_v(kv_x)
|
||||
|
||||
# [*, Q/K, H, C_hidden]
|
||||
q = q.view(q.shape[:-1] + (self.no_heads, -1))
|
||||
k = k.view(k.shape[:-1] + (self.no_heads, -1))
|
||||
v = v.view(v.shape[:-1] + (self.no_heads, -1))
|
||||
|
||||
q /= math.sqrt(self.c_hidden)
|
||||
|
||||
return q, k, v
|
||||
|
||||
def _wrap_up(self, o: torch.Tensor, q_x: torch.Tensor) -> torch.Tensor:
|
||||
if (self.linear_g is not None):
|
||||
g = self.sigmoid(self.linear_g(q_x))
|
||||
|
||||
# [*, Q, H, C_hidden]
|
||||
g = g.view(g.shape[:-1] + (self.no_heads, -1))
|
||||
o = o * g
|
||||
|
||||
# [*, Q, H * C_hidden]
|
||||
o = flatten_final_dims(o, 2)
|
||||
|
||||
# [*, Q, C_q]
|
||||
o = self.linear_o(o)
|
||||
|
||||
return o
|
||||
|
||||
def forward(
|
||||
self,
|
||||
q_x: torch.Tensor,
|
||||
kv_x: torch.Tensor,
|
||||
biases: Optional[List[torch.Tensor]] = None,
|
||||
use_lma: bool = False,
|
||||
q_chunk_size: Optional[int] = None,
|
||||
kv_chunk_size: Optional[int] = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
q_x:
|
||||
[*, Q, C_q] query data
|
||||
kv_x:
|
||||
[*, K, C_k] key data
|
||||
biases:
|
||||
List of biases that broadcast to [*, H, Q, K]
|
||||
use_lma:
|
||||
Whether to use low-memory attention
|
||||
q_chunk_size:
|
||||
Query chunk size (for LMA)
|
||||
kv_chunk_size:
|
||||
Key/Value chunk size (for LMA)
|
||||
Returns
|
||||
[*, Q, C_q] attention update
|
||||
"""
|
||||
if (biases is None):
|
||||
biases = []
|
||||
if (use_lma and (q_chunk_size is None or kv_chunk_size is None)):
|
||||
raise ValueError("If use_lma is specified, q_chunk_size and kv_chunk_size must "
|
||||
"be provided")
|
||||
|
||||
q, k, v = self._prep_qkv(q_x, kv_x)
|
||||
|
||||
if (use_lma):
|
||||
biases = [b.expand(b.shape[:-2] + (q_x.shape[-2],) + (kv_x.shape[-2],)) for b in biases]
|
||||
|
||||
o = _lma(q, k, v, biases, q_chunk_size, kv_chunk_size)
|
||||
else:
|
||||
o = _attention(q, k, v, biases)
|
||||
|
||||
o = self._wrap_up(o, q_x)
|
||||
|
||||
return o
|
||||
|
||||
|
||||
class GlobalAttention(nn.Module):
|
||||
|
||||
def __init__(self, c_in, c_hidden, no_heads, inf, eps):
|
||||
super(GlobalAttention, self).__init__()
|
||||
|
||||
self.c_in = c_in
|
||||
self.c_hidden = c_hidden
|
||||
self.no_heads = no_heads
|
||||
self.inf = inf
|
||||
self.eps = eps
|
||||
|
||||
self.linear_q = Linear(c_in, c_hidden * no_heads, bias=False, init="glorot")
|
||||
|
||||
self.linear_k = Linear(
|
||||
c_in,
|
||||
c_hidden,
|
||||
bias=False,
|
||||
init="glorot",
|
||||
)
|
||||
self.linear_v = Linear(
|
||||
c_in,
|
||||
c_hidden,
|
||||
bias=False,
|
||||
init="glorot",
|
||||
)
|
||||
self.linear_g = Linear(c_in, c_hidden * no_heads, init="gating")
|
||||
self.linear_o = Linear(c_hidden * no_heads, c_in, init="final")
|
||||
|
||||
self.sigmoid = nn.Sigmoid()
|
||||
|
||||
def forward(self, m: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
|
||||
# [*, N_res, C_in]
|
||||
q = torch.sum(m * mask.unsqueeze(-1),
|
||||
dim=-2) / (torch.sum(mask, dim=-1)[..., None] + self.eps)
|
||||
|
||||
# [*, N_res, H * C_hidden]
|
||||
q = self.linear_q(q)
|
||||
q *= (self.c_hidden**(-0.5))
|
||||
|
||||
# [*, N_res, H, C_hidden]
|
||||
q = q.view(q.shape[:-1] + (self.no_heads, -1))
|
||||
|
||||
# [*, N_res, N_seq, C_hidden]
|
||||
k = self.linear_k(m)
|
||||
v = self.linear_v(m)
|
||||
|
||||
# [*, N_res, H, N_seq]
|
||||
a = torch.matmul(
|
||||
q,
|
||||
k.transpose(-1, -2), # [*, N_res, C_hidden, N_seq]
|
||||
)
|
||||
bias = (self.inf * (mask - 1))[..., :, None, :]
|
||||
a += bias
|
||||
a = softmax(a)
|
||||
|
||||
# [*, N_res, H, C_hidden]
|
||||
o = torch.matmul(
|
||||
a,
|
||||
v,
|
||||
)
|
||||
|
||||
# [*, N_res, N_seq, C_hidden]
|
||||
g = self.sigmoid(self.linear_g(m))
|
||||
|
||||
# [*, N_res, N_seq, H, C_hidden]
|
||||
g = g.view(g.shape[:-1] + (self.no_heads, -1))
|
||||
|
||||
# [*, N_res, N_seq, H, C_hidden]
|
||||
o = o.unsqueeze(-3) * g
|
||||
|
||||
# [*, N_res, N_seq, H * C_hidden]
|
||||
o = o.reshape(o.shape[:-2] + (-1,))
|
||||
|
||||
# [*, N_res, N_seq, C_in]
|
||||
m = self.linear_o(o)
|
||||
|
||||
return m
|
||||
|
||||
|
||||
def _lma(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
biases: List[torch.Tensor],
|
||||
q_chunk_size: int,
|
||||
kv_chunk_size: int,
|
||||
):
|
||||
no_q, no_kv = q.shape[-3], k.shape[-3]
|
||||
|
||||
# [*, Q, H, C_hidden]
|
||||
o = q.new_zeros(q.shape)
|
||||
for q_s in range(0, no_q, q_chunk_size):
|
||||
q_chunk = q[..., q_s:q_s + q_chunk_size, :, :]
|
||||
large_bias_chunks = [b[..., q_s:q_s + q_chunk_size, :] for b in biases]
|
||||
|
||||
maxes = []
|
||||
weights = []
|
||||
values = []
|
||||
for kv_s in range(0, no_kv, kv_chunk_size):
|
||||
k_chunk = k[..., kv_s:kv_s + kv_chunk_size, :, :]
|
||||
v_chunk = v[..., kv_s:kv_s + kv_chunk_size, :, :]
|
||||
small_bias_chunks = [b[..., kv_s:kv_s + kv_chunk_size] for b in large_bias_chunks]
|
||||
|
||||
a = torch.einsum(
|
||||
"...qhd,...khd->...hqk",
|
||||
q_chunk,
|
||||
k_chunk,
|
||||
)
|
||||
|
||||
for b in small_bias_chunks:
|
||||
a += b
|
||||
|
||||
a = a.transpose(-2, -3)
|
||||
|
||||
max_a = torch.max(a, dim=-1, keepdim=True)[0]
|
||||
exp_a = torch.exp(a - max_a)
|
||||
exp_v = torch.einsum("...vhf,...qhv->...qhf", v_chunk, exp_a)
|
||||
|
||||
maxes.append(max_a.detach().squeeze(-1))
|
||||
weights.append(torch.sum(exp_a, dim=-1))
|
||||
values.append(exp_v)
|
||||
|
||||
chunk_max = torch.stack(maxes, dim=-3)
|
||||
chunk_weights = torch.stack(weights, dim=-3)
|
||||
chunk_values = torch.stack(values, dim=-4)
|
||||
|
||||
global_max = torch.max(chunk_max, dim=-3, keepdim=True)[0]
|
||||
max_diffs = torch.exp(chunk_max - global_max)
|
||||
chunk_values *= max_diffs.unsqueeze(-1)
|
||||
chunk_weights *= max_diffs
|
||||
|
||||
all_values = torch.sum(chunk_values, dim=-4)
|
||||
all_weights = torch.sum(chunk_weights.unsqueeze(-1), dim=-4)
|
||||
|
||||
q_chunk_out = all_values / all_weights
|
||||
|
||||
o[..., q_s:q_s + q_chunk_size, :, :] = q_chunk_out
|
||||
|
||||
return o
|
|
@ -1,408 +0,0 @@
|
|||
# Copyright 2021 AlQuraishi Laboratory
|
||||
# Copyright 2021 DeepMind Technologies Limited
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from functools import partial
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from typing import Tuple, List, Callable, Any, Dict, Sequence, Optional
|
||||
|
||||
|
||||
def permute_final_dims(tensor: torch.Tensor, inds: List[int]):
|
||||
zero_index = -1 * len(inds)
|
||||
first_inds = list(range(len(tensor.shape[:zero_index])))
|
||||
return tensor.permute(first_inds + [zero_index + i for i in inds])
|
||||
|
||||
|
||||
def flatten_final_dims(t: torch.Tensor, no_dims: int):
|
||||
return t.reshape(t.shape[:-no_dims] + (-1,))
|
||||
|
||||
|
||||
def masked_mean(mask, value, dim, eps=1e-4):
|
||||
mask = mask.expand(*value.shape)
|
||||
return torch.sum(mask * value, dim=dim) / (eps + torch.sum(mask, dim=dim))
|
||||
|
||||
|
||||
def pts_to_distogram(pts, min_bin=2.3125, max_bin=21.6875, no_bins=64):
|
||||
boundaries = torch.linspace(
|
||||
min_bin, max_bin, no_bins - 1, device=pts.device
|
||||
)
|
||||
dists = torch.sqrt(
|
||||
torch.sum((pts.unsqueeze(-2) - pts.unsqueeze(-3)) ** 2, dim=-1)
|
||||
)
|
||||
return torch.bucketize(dists, boundaries)
|
||||
|
||||
|
||||
def dict_multimap(fn, dicts):
|
||||
first = dicts[0]
|
||||
new_dict = {}
|
||||
for k, v in first.items():
|
||||
all_v = [d[k] for d in dicts]
|
||||
if type(v) is dict:
|
||||
new_dict[k] = dict_multimap(fn, all_v)
|
||||
else:
|
||||
new_dict[k] = fn(all_v)
|
||||
|
||||
return new_dict
|
||||
|
||||
|
||||
def one_hot(x, v_bins):
|
||||
reshaped_bins = v_bins.view(((1,) * len(x.shape)) + (len(v_bins),))
|
||||
diffs = x[..., None] - reshaped_bins
|
||||
am = torch.argmin(torch.abs(diffs), dim=-1)
|
||||
return nn.functional.one_hot(am, num_classes=len(v_bins)).float()
|
||||
|
||||
|
||||
def batched_gather(data, inds, dim=0, no_batch_dims=0):
|
||||
ranges = []
|
||||
for i, s in enumerate(data.shape[:no_batch_dims]):
|
||||
r = torch.arange(s)
|
||||
r = r.view(*(*((1,) * i), -1, *((1,) * (len(inds.shape) - i - 1))))
|
||||
ranges.append(r)
|
||||
|
||||
remaining_dims = [
|
||||
slice(None) for _ in range(len(data.shape) - no_batch_dims)
|
||||
]
|
||||
remaining_dims[dim - no_batch_dims if dim >= 0 else dim] = inds
|
||||
ranges.extend(remaining_dims)
|
||||
return data[ranges]
|
||||
|
||||
|
||||
# With tree_map, a poor man's JAX tree_map
|
||||
def dict_map(fn, dic, leaf_type):
|
||||
new_dict = {}
|
||||
for k, v in dic.items():
|
||||
if type(v) is dict:
|
||||
new_dict[k] = dict_map(fn, v, leaf_type)
|
||||
else:
|
||||
new_dict[k] = tree_map(fn, v, leaf_type)
|
||||
|
||||
return new_dict
|
||||
|
||||
|
||||
def tree_map(fn, tree, leaf_type):
|
||||
if isinstance(tree, dict):
|
||||
return dict_map(fn, tree, leaf_type)
|
||||
elif isinstance(tree, list):
|
||||
return [tree_map(fn, x, leaf_type) for x in tree]
|
||||
elif isinstance(tree, tuple):
|
||||
return tuple([tree_map(fn, x, leaf_type) for x in tree])
|
||||
elif isinstance(tree, leaf_type):
|
||||
return fn(tree)
|
||||
else:
|
||||
print(type(tree))
|
||||
raise ValueError("Not supported")
|
||||
|
||||
|
||||
tensor_tree_map = partial(tree_map, leaf_type=torch.Tensor)
|
||||
|
||||
def _fetch_dims(tree):
|
||||
shapes = []
|
||||
tree_type = type(tree)
|
||||
if tree_type is dict:
|
||||
for v in tree.values():
|
||||
shapes.extend(_fetch_dims(v))
|
||||
elif tree_type is list or tree_type is tuple:
|
||||
for t in tree:
|
||||
shapes.extend(_fetch_dims(t))
|
||||
elif tree_type is torch.Tensor:
|
||||
shapes.append(tree.shape)
|
||||
else:
|
||||
raise ValueError("Not supported")
|
||||
|
||||
return shapes
|
||||
|
||||
|
||||
@torch.jit.ignore
|
||||
def _flat_idx_to_idx(
|
||||
flat_idx: int,
|
||||
dims: Tuple[int],
|
||||
) -> Tuple[int]:
|
||||
idx = []
|
||||
for d in reversed(dims):
|
||||
idx.append(flat_idx % d)
|
||||
flat_idx = flat_idx // d
|
||||
|
||||
return tuple(reversed(idx))
|
||||
|
||||
|
||||
@torch.jit.ignore
|
||||
def _get_minimal_slice_set(
|
||||
start: Sequence[int],
|
||||
end: Sequence[int],
|
||||
dims: int,
|
||||
start_edges: Optional[Sequence[bool]] = None,
|
||||
end_edges: Optional[Sequence[bool]] = None,
|
||||
) -> Sequence[Tuple[int]]:
|
||||
"""
|
||||
Produces an ordered sequence of tensor slices that, when used in
|
||||
sequence on a tensor with shape dims, yields tensors that contain every
|
||||
leaf in the contiguous range [start, end]. Care is taken to yield a
|
||||
short sequence of slices, and perhaps even the shortest possible (I'm
|
||||
pretty sure it's the latter).
|
||||
|
||||
end is INCLUSIVE.
|
||||
"""
|
||||
# start_edges and end_edges both indicate whether, starting from any given
|
||||
# dimension, the start/end index is at the top/bottom edge of the
|
||||
# corresponding tensor, modeled as a tree
|
||||
def reduce_edge_list(ll):
|
||||
tally = 1
|
||||
for i in range(len(ll)):
|
||||
reversed_idx = -1 * (i + 1)
|
||||
ll[reversed_idx] *= tally
|
||||
tally = ll[reversed_idx]
|
||||
|
||||
if(start_edges is None):
|
||||
start_edges = [s == 0 for s in start]
|
||||
reduce_edge_list(start_edges)
|
||||
if(end_edges is None):
|
||||
end_edges = [e == (d - 1) for e,d in zip(end, dims)]
|
||||
reduce_edge_list(end_edges)
|
||||
|
||||
# Base cases. Either start/end are empty and we're done, or the final,
|
||||
# one-dimensional tensor can be simply sliced
|
||||
if(len(start) == 0):
|
||||
return [tuple()]
|
||||
elif(len(start) == 1):
|
||||
return [(slice(start[0], end[0] + 1),)]
|
||||
|
||||
slices = []
|
||||
path = []
|
||||
|
||||
# Dimensions common to start and end can be selected directly
|
||||
for s,e in zip(start, end):
|
||||
if(s == e):
|
||||
path.append(slice(s, s + 1))
|
||||
else:
|
||||
break
|
||||
|
||||
path = tuple(path)
|
||||
divergence_idx = len(path)
|
||||
|
||||
# start == end, and we're done
|
||||
if(divergence_idx == len(dims)):
|
||||
return [tuple(path)]
|
||||
|
||||
def upper():
|
||||
sdi = start[divergence_idx]
|
||||
return [
|
||||
path + (slice(sdi, sdi + 1),) + s for s in
|
||||
_get_minimal_slice_set(
|
||||
start[divergence_idx + 1:],
|
||||
[d - 1 for d in dims[divergence_idx + 1:]],
|
||||
dims[divergence_idx + 1:],
|
||||
start_edges=start_edges[divergence_idx + 1:],
|
||||
end_edges=[1 for _ in end_edges[divergence_idx + 1:]]
|
||||
)
|
||||
]
|
||||
|
||||
def lower():
|
||||
edi = end[divergence_idx]
|
||||
return [
|
||||
path + (slice(edi, edi + 1),) + s for s in
|
||||
_get_minimal_slice_set(
|
||||
[0 for _ in start[divergence_idx + 1:]],
|
||||
end[divergence_idx + 1:],
|
||||
dims[divergence_idx + 1:],
|
||||
start_edges=[1 for _ in start_edges[divergence_idx + 1:]],
|
||||
end_edges=end_edges[divergence_idx + 1:],
|
||||
)
|
||||
]
|
||||
|
||||
# If both start and end are at the edges of the subtree rooted at
|
||||
# divergence_idx, we can just select the whole subtree at once
|
||||
if(start_edges[divergence_idx] and end_edges[divergence_idx]):
|
||||
slices.append(
|
||||
path + (slice(start[divergence_idx], end[divergence_idx] + 1),)
|
||||
)
|
||||
# If just start is at the edge, we can grab almost all of the subtree,
|
||||
# treating only the ragged bottom edge as an edge case
|
||||
elif(start_edges[divergence_idx]):
|
||||
slices.append(
|
||||
path + (slice(start[divergence_idx], end[divergence_idx]),)
|
||||
)
|
||||
slices.extend(lower())
|
||||
# Analogous to the previous case, but the top is ragged this time
|
||||
elif(end_edges[divergence_idx]):
|
||||
slices.extend(upper())
|
||||
slices.append(
|
||||
path + (slice(start[divergence_idx] + 1, end[divergence_idx] + 1),)
|
||||
)
|
||||
# If both sides of the range are ragged, we need to handle both sides
|
||||
# separately. If there's contiguous meat in between them, we can index it
|
||||
# in one big chunk
|
||||
else:
|
||||
slices.extend(upper())
|
||||
middle_ground = end[divergence_idx] - start[divergence_idx]
|
||||
if(middle_ground > 1):
|
||||
slices.append(
|
||||
path + (slice(start[divergence_idx] + 1, end[divergence_idx]),)
|
||||
)
|
||||
slices.extend(lower())
|
||||
|
||||
return [tuple(s) for s in slices]
|
||||
|
||||
|
||||
@torch.jit.ignore
|
||||
def _chunk_slice(
|
||||
t: torch.Tensor,
|
||||
flat_start: int,
|
||||
flat_end: int,
|
||||
no_batch_dims: int,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Equivalent to
|
||||
|
||||
t.reshape((-1,) + t.shape[no_batch_dims:])[flat_start:flat_end]
|
||||
|
||||
but without the need for the initial reshape call, which can be
|
||||
memory-intensive in certain situations. The only reshape operations
|
||||
in this function are performed on sub-tensors that scale with
|
||||
(flat_end - flat_start), the chunk size.
|
||||
"""
|
||||
|
||||
batch_dims = t.shape[:no_batch_dims]
|
||||
start_idx = list(_flat_idx_to_idx(flat_start, batch_dims))
|
||||
# _get_minimal_slice_set is inclusive
|
||||
end_idx = list(_flat_idx_to_idx(flat_end - 1, batch_dims))
|
||||
|
||||
# Get an ordered list of slices to perform
|
||||
slices = _get_minimal_slice_set(
|
||||
start_idx,
|
||||
end_idx,
|
||||
batch_dims,
|
||||
)
|
||||
|
||||
sliced_tensors = [t[s] for s in slices]
|
||||
|
||||
return torch.cat(
|
||||
[s.view((-1,) + t.shape[no_batch_dims:]) for s in sliced_tensors]
|
||||
)
|
||||
|
||||
|
||||
def chunk_layer(
|
||||
layer: Callable,
|
||||
inputs: Dict[str, Any],
|
||||
chunk_size: int,
|
||||
no_batch_dims: int,
|
||||
low_mem: bool = False,
|
||||
) -> Any:
|
||||
"""
|
||||
Implements the "chunking" procedure described in section 1.11.8.
|
||||
|
||||
Layer outputs and inputs are assumed to be simple "pytrees,"
|
||||
consisting only of (arbitrarily nested) lists, tuples, and dicts with
|
||||
torch.Tensor leaves.
|
||||
|
||||
Args:
|
||||
layer:
|
||||
The layer to be applied chunk-wise
|
||||
inputs:
|
||||
A (non-nested) dictionary of keyworded inputs. All leaves must
|
||||
be tensors and must share the same batch dimensions.
|
||||
chunk_size:
|
||||
The number of sub-batches per chunk. If multiple batch
|
||||
dimensions are specified, a "sub-batch" is defined as a single
|
||||
indexing of all batch dimensions simultaneously (s.t. the
|
||||
number of sub-batches is the product of the batch dimensions).
|
||||
no_batch_dims:
|
||||
How many of the initial dimensions of each input tensor can
|
||||
be considered batch dimensions.
|
||||
low_mem:
|
||||
Avoids flattening potentially large input tensors. Unnecessary
|
||||
in most cases, and is ever so slightly slower than the default
|
||||
setting.
|
||||
Returns:
|
||||
The reassembled output of the layer on the inputs.
|
||||
"""
|
||||
if not (len(inputs) > 0):
|
||||
raise ValueError("Must provide at least one input")
|
||||
|
||||
initial_dims = [shape[:no_batch_dims] for shape in _fetch_dims(inputs)]
|
||||
orig_batch_dims = tuple([max(s) for s in zip(*initial_dims)])
|
||||
|
||||
def _prep_inputs(t):
|
||||
# TODO: make this more memory efficient. This sucks
|
||||
if(not low_mem):
|
||||
if not sum(t.shape[:no_batch_dims]) == no_batch_dims:
|
||||
t = t.expand(orig_batch_dims + t.shape[no_batch_dims:])
|
||||
t = t.reshape(-1, *t.shape[no_batch_dims:])
|
||||
else:
|
||||
t = t.expand(orig_batch_dims + t.shape[no_batch_dims:])
|
||||
return t
|
||||
|
||||
prepped_inputs = tensor_tree_map(_prep_inputs, inputs)
|
||||
|
||||
flat_batch_dim = 1
|
||||
for d in orig_batch_dims:
|
||||
flat_batch_dim *= d
|
||||
|
||||
no_chunks = flat_batch_dim // chunk_size + (
|
||||
flat_batch_dim % chunk_size != 0
|
||||
)
|
||||
|
||||
i = 0
|
||||
out = None
|
||||
for _ in range(no_chunks):
|
||||
# Chunk the input
|
||||
if(not low_mem):
|
||||
select_chunk = (
|
||||
lambda t: t[i : i + chunk_size] if t.shape[0] != 1 else t
|
||||
)
|
||||
else:
|
||||
select_chunk = (
|
||||
partial(
|
||||
_chunk_slice,
|
||||
flat_start=i,
|
||||
flat_end=min(flat_batch_dim, i + chunk_size),
|
||||
no_batch_dims=len(orig_batch_dims)
|
||||
)
|
||||
)
|
||||
|
||||
chunks = tensor_tree_map(select_chunk, prepped_inputs)
|
||||
|
||||
# Run the layer on the chunk
|
||||
output_chunk = layer(**chunks)
|
||||
|
||||
# Allocate space for the output
|
||||
if out is None:
|
||||
allocate = lambda t: t.new_zeros((flat_batch_dim,) + t.shape[1:])
|
||||
out = tensor_tree_map(allocate, output_chunk)
|
||||
|
||||
# Put the chunk in its pre-allocated space
|
||||
out_type = type(output_chunk)
|
||||
if out_type is dict:
|
||||
def assign(d1, d2):
|
||||
for k, v in d1.items():
|
||||
if type(v) is dict:
|
||||
assign(v, d2[k])
|
||||
else:
|
||||
v[i : i + chunk_size] = d2[k]
|
||||
|
||||
assign(out, output_chunk)
|
||||
elif out_type is tuple:
|
||||
for x1, x2 in zip(out, output_chunk):
|
||||
x1[i : i + chunk_size] = x2
|
||||
elif out_type is torch.Tensor:
|
||||
out[i : i + chunk_size] = output_chunk
|
||||
else:
|
||||
raise ValueError("Not supported")
|
||||
|
||||
i += chunk_size
|
||||
|
||||
reshape = lambda t: t.view(orig_batch_dims + t.shape[1:])
|
||||
out = tensor_tree_map(reshape, out)
|
||||
|
||||
return out
|
|
@ -1,139 +0,0 @@
|
|||
# Copyright 2021 AlQuraishi Laboratory
|
||||
# Copyright 2021 DeepMind Technologies Limited
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from functools import partialmethod, partial
|
||||
import math
|
||||
from typing import Optional, List
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from .primitives import Linear, LayerNorm, Attention
|
||||
from .tensor_utils import (
|
||||
chunk_layer,
|
||||
permute_final_dims,
|
||||
flatten_final_dims,
|
||||
)
|
||||
|
||||
|
||||
class TriangleAttention(nn.Module):
|
||||
def __init__(
|
||||
self, c_in, c_hidden, no_heads, starting, inf=1e9
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
c_in:
|
||||
Input channel dimension
|
||||
c_hidden:
|
||||
Overall hidden channel dimension (not per-head)
|
||||
no_heads:
|
||||
Number of attention heads
|
||||
"""
|
||||
super(TriangleAttention, self).__init__()
|
||||
|
||||
self.c_in = c_in
|
||||
self.c_hidden = c_hidden
|
||||
self.no_heads = no_heads
|
||||
self.starting = starting
|
||||
self.inf = inf
|
||||
|
||||
self.layer_norm = LayerNorm(self.c_in)
|
||||
|
||||
self.linear = Linear(c_in, self.no_heads, bias=False, init="normal")
|
||||
|
||||
self.mha = Attention(
|
||||
self.c_in, self.c_in, self.c_in, self.c_hidden, self.no_heads
|
||||
)
|
||||
|
||||
@torch.jit.ignore
|
||||
def _chunk(self,
|
||||
x: torch.Tensor,
|
||||
biases: List[torch.Tensor],
|
||||
chunk_size: int,
|
||||
) -> torch.Tensor:
|
||||
mha_inputs = {
|
||||
"q_x": x,
|
||||
"kv_x": x,
|
||||
"biases": biases,
|
||||
}
|
||||
return chunk_layer(
|
||||
partial(self.mha),
|
||||
mha_inputs,
|
||||
chunk_size=chunk_size,
|
||||
no_batch_dims=len(x.shape[:-2]),
|
||||
)
|
||||
|
||||
def forward(self,
|
||||
x: torch.Tensor,
|
||||
mask: Optional[torch.Tensor] = None,
|
||||
chunk_size: Optional[int] = None
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
x:
|
||||
[*, I, J, C_in] input tensor (e.g. the pair representation)
|
||||
Returns:
|
||||
[*, I, J, C_in] output tensor
|
||||
"""
|
||||
if mask is None:
|
||||
# [*, I, J]
|
||||
mask = x.new_ones(
|
||||
x.shape[:-1],
|
||||
)
|
||||
|
||||
# Shape annotations assume self.starting. Else, I and J are flipped
|
||||
if not self.starting:
|
||||
x = x.transpose(-2, -3)
|
||||
mask = mask.transpose(-1, -2)
|
||||
|
||||
# [*, I, J, C_in]
|
||||
x = self.layer_norm(x)
|
||||
|
||||
# [*, I, 1, 1, J]
|
||||
mask_bias = (self.inf * (mask - 1))[..., :, None, None, :]
|
||||
|
||||
# [*, H, I, J]
|
||||
triangle_bias = permute_final_dims(self.linear(x), (2, 0, 1))
|
||||
|
||||
# [*, 1, H, I, J]
|
||||
triangle_bias = triangle_bias.unsqueeze(-4)
|
||||
|
||||
biases = [mask_bias, triangle_bias]
|
||||
|
||||
if chunk_size is not None:
|
||||
x = self._chunk(x, biases, chunk_size)
|
||||
else:
|
||||
x = self.mha(q_x=x, kv_x=x, biases=biases)
|
||||
|
||||
if not self.starting:
|
||||
x = x.transpose(-2, -3)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class TriangleAttentionStartingNode(TriangleAttention):
|
||||
"""
|
||||
Implements Algorithm 13.
|
||||
"""
|
||||
|
||||
__init__ = partialmethod(TriangleAttention.__init__, starting=True)
|
||||
|
||||
|
||||
class TriangleAttentionEndingNode(TriangleAttention):
|
||||
"""
|
||||
Implements Algorithm 14.
|
||||
"""
|
||||
|
||||
__init__ = partialmethod(TriangleAttention.__init__, starting=False)
|
|
@ -1,127 +0,0 @@
|
|||
# Copyright 2021 AlQuraishi Laboratory
|
||||
# Copyright 2021 DeepMind Technologies Limited
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from functools import partialmethod
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from .primitives import Linear, LayerNorm
|
||||
from .tensor_utils import permute_final_dims
|
||||
|
||||
|
||||
class TriangleMultiplicativeUpdate(nn.Module):
|
||||
"""
|
||||
Implements Algorithms 11 and 12.
|
||||
"""
|
||||
def __init__(self, c_z, c_hidden, _outgoing=True):
|
||||
"""
|
||||
Args:
|
||||
c_z:
|
||||
Input channel dimension
|
||||
c:
|
||||
Hidden channel dimension
|
||||
"""
|
||||
super(TriangleMultiplicativeUpdate, self).__init__()
|
||||
self.c_z = c_z
|
||||
self.c_hidden = c_hidden
|
||||
self._outgoing = _outgoing
|
||||
|
||||
self.linear_a_p = Linear(self.c_z, self.c_hidden)
|
||||
self.linear_a_g = Linear(self.c_z, self.c_hidden, init="gating")
|
||||
self.linear_b_p = Linear(self.c_z, self.c_hidden)
|
||||
self.linear_b_g = Linear(self.c_z, self.c_hidden, init="gating")
|
||||
self.linear_g = Linear(self.c_z, self.c_z, init="gating")
|
||||
self.linear_z = Linear(self.c_hidden, self.c_z, init="final")
|
||||
|
||||
self.layer_norm_in = LayerNorm(self.c_z)
|
||||
self.layer_norm_out = LayerNorm(self.c_hidden)
|
||||
|
||||
self.sigmoid = nn.Sigmoid()
|
||||
|
||||
def _combine_projections(self,
|
||||
a: torch.Tensor,
|
||||
b: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
raise NotImplementedError("This method needs to be overridden")
|
||||
|
||||
def forward(self,
|
||||
z: torch.Tensor,
|
||||
mask: Optional[torch.Tensor] = None
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
x:
|
||||
[*, N_res, N_res, C_z] input tensor
|
||||
mask:
|
||||
[*, N_res, N_res] input mask
|
||||
Returns:
|
||||
[*, N_res, N_res, C_z] output tensor
|
||||
"""
|
||||
if mask is None:
|
||||
mask = z.new_ones(z.shape[:-1])
|
||||
|
||||
mask = mask.unsqueeze(-1)
|
||||
|
||||
z = self.layer_norm_in(z)
|
||||
a = self.linear_a_p(z) * self.sigmoid(self.linear_a_g(z))
|
||||
a = a * mask
|
||||
b = self.linear_b_p(z) * self.sigmoid(self.linear_b_g(z))
|
||||
b = b * mask
|
||||
x = self._combine_projections(a, b)
|
||||
x = self.layer_norm_out(x)
|
||||
x = self.linear_z(x)
|
||||
g = self.sigmoid(self.linear_g(z))
|
||||
z = x * g
|
||||
|
||||
return z
|
||||
|
||||
|
||||
class TriangleMultiplicationOutgoing(TriangleMultiplicativeUpdate):
|
||||
"""
|
||||
Implements Algorithm 11.
|
||||
"""
|
||||
def _combine_projections(self,
|
||||
a: torch.Tensor, # [*, N_i, N_k, C]
|
||||
b: torch.Tensor, # [*, N_j, N_k, C]
|
||||
):
|
||||
# [*, C, N_i, N_j]
|
||||
p = torch.matmul(
|
||||
permute_final_dims(a, (2, 0, 1)),
|
||||
permute_final_dims(b, (2, 1, 0)),
|
||||
)
|
||||
|
||||
# [*, N_i, N_j, C]
|
||||
return permute_final_dims(p, (1, 2, 0))
|
||||
|
||||
|
||||
class TriangleMultiplicationIncoming(TriangleMultiplicativeUpdate):
|
||||
"""
|
||||
Implements Algorithm 12.
|
||||
"""
|
||||
def _combine_projections(self,
|
||||
a: torch.Tensor, # [*, N_k, N_i, C]
|
||||
b: torch.Tensor, # [*, N_k, N_j, C]
|
||||
):
|
||||
# [*, C, N_i, N_j]
|
||||
p = torch.matmul(
|
||||
permute_final_dims(a, (2, 1, 0)),
|
||||
permute_final_dims(b, (2, 0, 1)),
|
||||
)
|
||||
|
||||
# [*, N_i, N_j, C]
|
||||
return permute_final_dims(p, (1, 2, 0))
|
||||
|
|
@ -0,0 +1,164 @@
|
|||
from functools import partial
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.fx
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
try:
|
||||
from fastfold.model.nn.evoformer import EvoformerBlock
|
||||
HAS_REPO = True
|
||||
except:
|
||||
HAS_REPO = False
|
||||
|
||||
import colossalai
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.fx._compatibility import is_compatible_with_meta
|
||||
from colossalai.fx.codegen.activation_checkpoint_codegen import CODEGEN_AVAILABLE
|
||||
from colossalai.fx.graph_module import ColoGraphModule
|
||||
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
|
||||
from colossalai.utils import free_port
|
||||
|
||||
if CODEGEN_AVAILABLE and is_compatible_with_meta():
|
||||
from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
|
||||
from colossalai.fx.profiler import MetaTensor
|
||||
from colossalai.fx.tracer.experimental import ColoTracer, symbolic_trace
|
||||
|
||||
|
||||
def _test_fwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair, node_mask, pair_mask):
|
||||
# for memory test
|
||||
# torch.cuda.reset_peak_memory_stats()
|
||||
# now_mem = torch.cuda.memory_allocated() / 1024**2
|
||||
# with torch.no_grad():
|
||||
# node1 = node.clone()
|
||||
# pair1 = pair.clone()
|
||||
# gm(node1, pair1)
|
||||
# new_now_mem = torch.cuda.memory_allocated() / 1024**2
|
||||
# new_max_mem = torch.cuda.max_memory_allocated() / 1024**2
|
||||
# print(
|
||||
# "autochunk now mem:%.2f max mem:%.2f"
|
||||
# % (new_now_mem - now_mem, new_max_mem - now_mem)
|
||||
# )
|
||||
|
||||
# test forward
|
||||
model = model.cuda()
|
||||
with torch.no_grad():
|
||||
non_fx_out = model(node, pair, node_mask, pair_mask)
|
||||
fx_out = gm(node, pair, node_mask, pair_mask)
|
||||
|
||||
assert torch.allclose(non_fx_out[0], fx_out[0],
|
||||
atol=1e-4), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(
|
||||
torch.abs(non_fx_out[0] - fx_out[0]))
|
||||
assert torch.allclose(non_fx_out[1], fx_out[1],
|
||||
atol=1e-4), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(
|
||||
torch.abs(non_fx_out[1] - fx_out[1]))
|
||||
|
||||
|
||||
def _build_openfold():
|
||||
model = EvoformerBlock(
|
||||
c_m=256,
|
||||
c_z=128,
|
||||
c_hidden_msa_att=32,
|
||||
c_hidden_opm=32,
|
||||
c_hidden_mul=128,
|
||||
c_hidden_pair_att=32,
|
||||
no_heads_msa=8,
|
||||
no_heads_pair=4,
|
||||
transition_n=4,
|
||||
msa_dropout=0.15,
|
||||
pair_dropout=0.15,
|
||||
inf=1e4,
|
||||
eps=1e-4,
|
||||
is_multimer=False,
|
||||
).eval().cuda()
|
||||
return model
|
||||
|
||||
|
||||
def _test_evoformer_codegen(rank, msa_len, pair_len, max_memory):
|
||||
# launch colossalai
|
||||
colossalai.launch(
|
||||
config={},
|
||||
rank=rank,
|
||||
world_size=1,
|
||||
host="localhost",
|
||||
port=free_port(),
|
||||
backend="nccl",
|
||||
)
|
||||
|
||||
# build model and input
|
||||
model = _build_openfold()
|
||||
node = torch.randn(1, msa_len, pair_len, 256).cuda()
|
||||
node_mask = torch.randn(1, msa_len, pair_len).cuda()
|
||||
pair = torch.randn(1, pair_len, pair_len, 128).cuda()
|
||||
pair_mask = torch.randn(1, pair_len, pair_len).cuda()
|
||||
|
||||
# trace the meta graph and setup codegen
|
||||
meta_graph = symbolic_trace(
|
||||
model,
|
||||
meta_args={
|
||||
"m": node.to(torch.device("meta")),
|
||||
"z": pair.to(torch.device("meta")),
|
||||
"msa_mask": node_mask.to(torch.device("meta")),
|
||||
"pair_mask": pair_mask.to(torch.device("meta")),
|
||||
},
|
||||
concrete_args={
|
||||
"chunk_size": None,
|
||||
"_mask_trans": True,
|
||||
},
|
||||
)
|
||||
interp = MetaInfoProp(meta_graph)
|
||||
interp.propagate(
|
||||
MetaTensor(node, fake_device="cuda:0"),
|
||||
MetaTensor(pair, fake_device="cuda:0"),
|
||||
MetaTensor(node_mask, fake_device="cuda:0"),
|
||||
MetaTensor(pair_mask, fake_device="cuda:0"),
|
||||
)
|
||||
# codegen = AutoChunkCodeGen(meta_graph, max_memory=max_memory)
|
||||
|
||||
# trace and recompile
|
||||
# MetaInfoProp requires symbolic_trace but CodeGen requires ColoTracer
|
||||
graph = ColoTracer().trace(
|
||||
model,
|
||||
meta_args={
|
||||
"m": node.to(torch.device("meta")),
|
||||
"z": pair.to(torch.device("meta")),
|
||||
"msa_mask": node_mask.to(torch.device("meta")),
|
||||
"pair_mask": pair_mask.to(torch.device("meta")),
|
||||
},
|
||||
concrete_args={
|
||||
"chunk_size": None,
|
||||
"_mask_trans": True,
|
||||
},
|
||||
)
|
||||
# graph.set_codegen(codegen)
|
||||
gm = ColoGraphModule(model, graph)
|
||||
gm.recompile()
|
||||
|
||||
# assert we have inserted chunk
|
||||
code = graph.python_code("self").src
|
||||
assert "chunk_size" in code
|
||||
# print(code)
|
||||
|
||||
_test_fwd(model, gm, node, pair, node_mask, pair_mask)
|
||||
gpc.destroy()
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not (CODEGEN_AVAILABLE and is_compatible_with_meta() and HAS_REPO),
|
||||
reason="torch version is lower than 1.12.0",
|
||||
)
|
||||
@pytest.mark.parametrize("max_memory", [None, 20, 25, 30])
|
||||
@pytest.mark.parametrize("msa_len", [32])
|
||||
@pytest.mark.parametrize("pair_len", [64])
|
||||
def test_evoformer_codegen(msa_len, pair_len, max_memory):
|
||||
run_func = partial(
|
||||
_test_evoformer_codegen,
|
||||
msa_len=msa_len,
|
||||
pair_len=pair_len,
|
||||
max_memory=max_memory,
|
||||
)
|
||||
mp.spawn(run_func, nprocs=1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
_test_evoformer_codegen(0, 32, 64, 25)
|
|
@ -5,6 +5,12 @@ import torch
|
|||
import torch.fx
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
try:
|
||||
from simple_evoformer import base_evoformer
|
||||
HAS_REPO = True
|
||||
except:
|
||||
HAS_REPO = False
|
||||
|
||||
import colossalai
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.fx import ColoTracer
|
||||
|
@ -13,7 +19,6 @@ from colossalai.fx.codegen.activation_checkpoint_codegen import CODEGEN_AVAILABL
|
|||
from colossalai.fx.graph_module import ColoGraphModule
|
||||
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
|
||||
from colossalai.utils import free_port
|
||||
from tests.test_autochunk.evoformer.evoformer import evoformer_base
|
||||
|
||||
if CODEGEN_AVAILABLE and is_compatible_with_meta():
|
||||
from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
|
||||
|
@ -48,7 +53,7 @@ def _test_fwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair):
|
|||
torch.abs(non_fx_out[1] - fx_out[1]))
|
||||
|
||||
|
||||
def _test_autochunk_codegen(rank, msa_len, pair_len, max_memory):
|
||||
def _test_simple_evoformer_codegen(rank, msa_len, pair_len, max_memory):
|
||||
# launch colossalai
|
||||
colossalai.launch(
|
||||
config={},
|
||||
|
@ -60,7 +65,7 @@ def _test_autochunk_codegen(rank, msa_len, pair_len, max_memory):
|
|||
)
|
||||
|
||||
# build model and input
|
||||
model = evoformer_base().cuda()
|
||||
model = base_evoformer().cuda()
|
||||
node = torch.randn(1, msa_len, pair_len, 256).cuda()
|
||||
pair = torch.randn(1, pair_len, pair_len, 128).cuda()
|
||||
|
||||
|
@ -95,13 +100,14 @@ def _test_autochunk_codegen(rank, msa_len, pair_len, max_memory):
|
|||
gpc.destroy()
|
||||
|
||||
|
||||
@pytest.mark.skipif(not (CODEGEN_AVAILABLE and is_compatible_with_meta()), reason='torch version is lower than 1.12.0')
|
||||
@pytest.mark.skipif(not (CODEGEN_AVAILABLE and is_compatible_with_meta() and HAS_REPO),
|
||||
reason='torch version is lower than 1.12.0')
|
||||
@pytest.mark.parametrize("max_memory", [None, 20, 25, 30])
|
||||
@pytest.mark.parametrize("msa_len", [32])
|
||||
@pytest.mark.parametrize("pair_len", [64])
|
||||
def test_autochunk_codegen(msa_len, pair_len, max_memory):
|
||||
def test_simple_evoformer_codegen(msa_len, pair_len, max_memory):
|
||||
run_func = partial(
|
||||
_test_autochunk_codegen,
|
||||
_test_simple_evoformer_codegen,
|
||||
msa_len=msa_len,
|
||||
pair_len=pair_len,
|
||||
max_memory=max_memory,
|
||||
|
@ -110,4 +116,4 @@ def test_autochunk_codegen(msa_len, pair_len, max_memory):
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
_test_autochunk_codegen(0, 32, 64, 25)
|
||||
_test_simple_evoformer_codegen(0, 32, 64, 25)
|
|
@ -5,13 +5,18 @@ import torch
|
|||
import torch.fx
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
try:
|
||||
from simple_evoformer import base_evoformer
|
||||
HAS_REPO = True
|
||||
except:
|
||||
HAS_REPO = False
|
||||
|
||||
import colossalai
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.fx._compatibility import is_compatible_with_meta
|
||||
from colossalai.fx.codegen.activation_checkpoint_codegen import CODEGEN_AVAILABLE
|
||||
from colossalai.fx.passes.meta_info_prop import MetaInfoProp
|
||||
from colossalai.utils import free_port
|
||||
from tests.test_autochunk.evoformer.evoformer import evoformer_base
|
||||
|
||||
if CODEGEN_AVAILABLE and is_compatible_with_meta():
|
||||
from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen
|
||||
|
@ -57,7 +62,7 @@ def assert_chunk_infos(chunk_infos, max_memory, msa_len, pair_len):
|
|||
)
|
||||
|
||||
|
||||
def _test_autochunk_search(rank, msa_len, pair_len, max_memory):
|
||||
def _test_simple_evoformer_search(rank, msa_len, pair_len, max_memory):
|
||||
# launch colossalai
|
||||
colossalai.launch(
|
||||
config={},
|
||||
|
@ -69,7 +74,7 @@ def _test_autochunk_search(rank, msa_len, pair_len, max_memory):
|
|||
)
|
||||
|
||||
# build model and input
|
||||
model = evoformer_base().cuda()
|
||||
model = base_evoformer().cuda()
|
||||
node = torch.randn(1, msa_len, pair_len, 256).cuda()
|
||||
pair = torch.randn(1, pair_len, pair_len, 128).cuda()
|
||||
|
||||
|
@ -84,13 +89,14 @@ def _test_autochunk_search(rank, msa_len, pair_len, max_memory):
|
|||
gpc.destroy()
|
||||
|
||||
|
||||
@pytest.mark.skipif(not (CODEGEN_AVAILABLE and is_compatible_with_meta()), reason="torch version is lower than 1.12.0")
|
||||
@pytest.mark.skipif(not (CODEGEN_AVAILABLE and is_compatible_with_meta() and HAS_REPO),
|
||||
reason="torch version is lower than 1.12.0")
|
||||
@pytest.mark.parametrize("max_memory", [None, 20, 25, 30])
|
||||
@pytest.mark.parametrize("msa_len", [32])
|
||||
@pytest.mark.parametrize("pair_len", [64])
|
||||
def test_autochunk_search(msa_len, pair_len, max_memory):
|
||||
def test_simple_evoformer_search(msa_len, pair_len, max_memory):
|
||||
run_func = partial(
|
||||
_test_autochunk_search,
|
||||
_test_simple_evoformer_search,
|
||||
msa_len=msa_len,
|
||||
pair_len=pair_len,
|
||||
max_memory=max_memory,
|
||||
|
@ -99,4 +105,4 @@ def test_autochunk_search(msa_len, pair_len, max_memory):
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
_test_autochunk_search(0, 32, 64, 20)
|
||||
_test_simple_evoformer_search(0, 32, 64, 20)
|
Loading…
Reference in New Issue