mirror of https://github.com/hpcaitech/ColossalAI
add evoformer
parent 78cfe4362b
commit 86f2a31474
@@ -0,0 +1,47 @@
import torch
import torch.nn as nn

from .msa import MSAStack
from .ops import OutProductMean
from .triangle import PairStack


class EvoformerBlock(nn.Module):

    def __init__(self, d_node, d_pair):
        super(EvoformerBlock, self).__init__()

        self.msa_stack = MSAStack(d_node, d_pair, p_drop=0.15)
        self.communication = OutProductMean(n_feat=d_node, n_feat_out=d_pair, n_feat_proj=32)
        self.pair_stack = PairStack(d_pair=d_pair)

    def forward(self, node, pair):
        node = node + self.msa_stack(node, pair)
        pair = pair + self.communication(node)
        pair = pair + self.pair_stack(pair)
        return node, pair


class Evoformer(nn.Module):

    def __init__(self, d_node, d_pair):
        super(Evoformer, self).__init__()

        self.blocks = nn.ModuleList()
        for _ in range(3):
            self.blocks.append(EvoformerBlock(d_node, d_pair))

    def forward(self, node, pair):
        for b in self.blocks:
            node, pair = b(node, pair)
        return node, pair


def evoformer_base():
    return Evoformer(d_node=256, d_pair=128)


def evoformer_large():
    return Evoformer(d_node=512, d_pair=256)


__all__ = ['Evoformer', 'evoformer_base', 'evoformer_large']
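A minimal usage sketch for the Evoformer above (not part of the commit), assuming the definitions are importable and that node is the MSA tensor [batch, n_seq, n_res, d_node] while pair is [batch, n_res, n_res, d_pair], as implied by the einsum patterns in ops.py:

import torch

model = evoformer_base()               # 3 EvoformerBlocks with d_node=256, d_pair=128
node = torch.randn(1, 8, 64, 256)      # [batch, n_seq, n_res, d_node]
pair = torch.randn(1, 64, 64, 128)     # [batch, n_res, n_res, d_pair]
node, pair = model(node, pair)         # shapes preserved: [1, 8, 64, 256] and [1, 64, 64, 128]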
@@ -0,0 +1,29 @@
import math

import numpy as np
import torch.nn as nn


def glorot_uniform_af(x, gain=1.0):
    """
    Initialize a tensor the same way as PyTorch's xavier_uniform_, but with a different dimension convention:
    In PyTorch:
        [feature_out, feature_in, n_head, ...]
    In Jax:
        [..., n_head, feature_in, feature_out]
    However, the original AlphaFold2 code uses the Jax-style initializer on tensors laid out as:
        [feature_in, n_head, feature_out]

    This function keeps that behaviour and initializes [feature_in, n_head, ..., feature_out] tensors.
    """
    fan_in, fan_out = x.shape[-2:]
    if len(x.shape) > 2:
        receptive_field_size = np.prod(x.shape[:-2])
        fan_in *= receptive_field_size
        fan_out *= receptive_field_size
    std = gain * math.sqrt(2.0 / float(fan_in + fan_out))
    dev = math.sqrt(3.0) * std    # calculate uniform bounds from standard deviation

    nn.init.uniform_(x, -dev, dev)

    return x
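A quick check of the bound computed above (illustration only, not part of the commit): for a plain 2-D weight the call is equivalent to torch.nn.init.xavier_uniform_, since both sample from U(-a, a) with a = gain * sqrt(6 / (fan_in + fan_out)).

import math
import torch

w = torch.zeros(256, 128)
glorot_uniform_af(w, gain=1.0)
bound = math.sqrt(6.0 / (256 + 128))    # equals sqrt(3) * std with std = sqrt(2 / (fan_in + fan_out))
assert w.abs().max() <= bound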
@@ -0,0 +1,19 @@
import torch
import torch.nn.functional as F


def bias_sigmod_ele(y, bias, z):
    return torch.sigmoid(y + bias) * z


def bias_dropout_add(x: torch.Tensor, bias: torch.Tensor, dropmask: torch.Tensor,
                     residual: torch.Tensor, prob: float) -> torch.Tensor:
    out = (x + bias) * F.dropout(dropmask, p=prob, training=True)
    out = residual + out
    return out


def bias_ele_dropout_residual(ab: torch.Tensor, b: torch.Tensor, g: torch.Tensor,
                              dropout_mask: torch.Tensor, Z_raw: torch.Tensor,
                              prob: float) -> torch.Tensor:
    return Z_raw + F.dropout(dropout_mask, p=prob, training=True) * (g * (ab + b))
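These helpers fuse a bias add, scaled dropout with a shared mask, and a residual add. The callers in msa.py and triangle.py pass a mask of ones sliced along one axis, so the same dropout pattern is broadcast over an entire row or column; note that F.dropout is called with training=True, so dropout stays active even in eval mode. A small sketch of how bias_dropout_add is typically driven (shapes chosen for illustration, not part of the commit):

import torch

x = torch.randn(2, 8, 64, 256)                # e.g. attention output [batch, n_seq, n_res, d]
bias = torch.zeros(256)
residual = torch.randn_like(x)
dropmask = torch.ones_like(x[:, 0:1, :, :])   # one mask shared across the n_seq axis (row-wise dropout)
out = bias_dropout_add(x, bias, dropmask, residual, prob=0.15)
print(out.shape)                              # torch.Size([2, 8, 64, 256])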
@@ -0,0 +1,95 @@
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch.nn import LayerNorm

from .kernel import bias_dropout_add
from .ops import SelfAttention, Transition


class MSARowAttentionWithPairBias(nn.Module):

    def __init__(self, d_node, d_pair, c=32, n_head=8, p_drop=0.15):
        super(MSARowAttentionWithPairBias, self).__init__()
        self.d_node = d_node
        self.d_pair = d_pair
        self.c = c
        self.n_head = n_head
        self.p_drop = p_drop

        self.layernormM = LayerNorm(d_node)
        self.layernormZ = LayerNorm(d_pair)

        _init_weights = torch.nn.init.normal_(torch.zeros([n_head, d_pair]),
                                              std=1.0 / math.sqrt(d_pair))
        self.linear_b_weights = nn.parameter.Parameter(data=_init_weights, requires_grad=True)

        self.attention = SelfAttention(qkv_dim=d_node,
                                       c=c,
                                       n_head=n_head,
                                       out_dim=d_node,
                                       gating=True,
                                       last_bias_fuse=True)

        self.out_bias = nn.parameter.Parameter(data=torch.zeros((d_node,)), requires_grad=True)

    def forward(self, M_raw, Z):
        ## Input projections
        M = self.layernormM(M_raw)
        Z = self.layernormZ(Z)
        b = F.linear(Z, self.linear_b_weights)
        b = b.permute(0, 3, 1, 2)
        # b = rearrange(b, 'b q k h -> b h q k')

        M = self.attention(M, b)
        dropout_mask = torch.ones_like(M[:, 0:1, :, :], device=M.device, dtype=M.dtype)

        return bias_dropout_add(M, self.out_bias, dropout_mask, M_raw, prob=self.p_drop)


class MSAColumnAttention(nn.Module):

    def __init__(self, d_node, c=32, n_head=8):
        super(MSAColumnAttention, self).__init__()
        self.d_node = d_node
        self.c = c
        self.n_head = n_head

        self.layernormM = LayerNorm(d_node)
        self.attention = SelfAttention(qkv_dim=d_node,
                                       c=c,
                                       n_head=n_head,
                                       out_dim=d_node,
                                       gating=True)

    def forward(self, M_raw):
        M = M_raw.transpose(-2, -3)
        M = self.layernormM(M)

        M = self.attention(M)

        M = M.transpose(-2, -3)
        return M_raw + M


class MSAStack(nn.Module):

    def __init__(self, d_node, d_pair, p_drop=0.15):
        super(MSAStack, self).__init__()

        self.MSARowAttentionWithPairBias = MSARowAttentionWithPairBias(d_node=d_node,
                                                                       d_pair=d_pair,
                                                                       p_drop=p_drop)

        self.MSAColumnAttention = MSAColumnAttention(d_node=d_node)
        self.MSATransition = Transition(d=d_node)

    def forward(self, node, pair):
        node = self.MSARowAttentionWithPairBias(node, pair)
        node = self.MSAColumnAttention(node)
        node = self.MSATransition(node)

        return node
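A shape sketch for MSAStack (illustration only), with node as [batch, n_seq, n_res, d_node] and pair as [batch, n_res, n_res, d_pair]; the row attention is biased by the pair representation, and all three sub-modules keep the node shape unchanged:

import torch

stack = MSAStack(d_node=256, d_pair=128)
node = torch.randn(1, 8, 64, 256)    # [batch, n_seq, n_res, d_node]
pair = torch.randn(1, 64, 64, 128)   # [batch, n_res, n_res, d_pair]
out = stack(node, pair)
print(out.shape)                     # torch.Size([1, 8, 64, 256])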
@@ -0,0 +1,176 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch.nn import LayerNorm

from .initializer import glorot_uniform_af
from .kernel import bias_sigmod_ele


class DropoutRowwise(nn.Module):

    def __init__(self, p):
        super(DropoutRowwise, self).__init__()
        self.p = p
        self.dropout = nn.Dropout(p=p)

    def forward(self, x):
        dropout_mask = torch.ones_like(x[:, 0:1, :, :])
        dropout_mask = self.dropout(dropout_mask)
        return dropout_mask * x


class DropoutColumnwise(nn.Module):

    def __init__(self, p):
        super(DropoutColumnwise, self).__init__()
        self.p = p
        self.dropout = nn.Dropout(p=p)

    def forward(self, x):
        dropout_mask = torch.ones_like(x[:, :, 0:1, :])
        dropout_mask = self.dropout(dropout_mask)
        return dropout_mask * x


class Transition(nn.Module):

    def __init__(self, d, n=4):
        super(Transition, self).__init__()
        self.norm = LayerNorm(d)
        self.linear1 = Linear(d, n * d, initializer='relu')
        self.linear2 = Linear(n * d, d, initializer='zeros')

    def forward(self, src):
        x = self.norm(src)
        x = self.linear2(F.relu(self.linear1(x)))
        return src + x


class OutProductMean(nn.Module):

    def __init__(self, n_feat=64, n_feat_out=128, n_feat_proj=32):
        super(OutProductMean, self).__init__()

        self.layernormM = LayerNorm(n_feat)
        self.linear_a = Linear(n_feat, n_feat_proj)
        self.linear_b = Linear(n_feat, n_feat_proj)

        self.o_linear = Linear(n_feat_proj * n_feat_proj,
                               n_feat_out,
                               initializer='zeros',
                               use_bias=True)

    def forward(self, M):
        M = self.layernormM(M)
        left_act = self.linear_a(M)
        right_act = self.linear_b(M)

        O = torch.einsum('bsid,bsje->bijde', left_act, right_act).contiguous()
        # O = rearrange(O, 'b i j d e -> b i j (d e)')
        O = O.reshape(O.shape[0], O.shape[1], O.shape[2], -1)
        Z = self.o_linear(O)

        return Z


class Linear(nn.Linear):
    """
    A Linear layer with built-in nonstandard initializations. Called just
    like torch.nn.Linear.

    Implements the initializers in Section 1.11.4 of the AlphaFold 2
    supplementary information, plus some additional ones found in the code.
    """

    def __init__(
        self,
        feature_in: int,
        feature_out: int,
        initializer: str = 'linear',
        use_bias: bool = True,
        bias_init: float = 0.,
    ):
        super(Linear, self).__init__(feature_in, feature_out, bias=use_bias)

        self.use_bias = use_bias
        if initializer == 'linear':
            glorot_uniform_af(self.weight, gain=1.0)
        elif initializer == 'relu':
            glorot_uniform_af(self.weight, gain=2.0)
        elif initializer == 'zeros':
            nn.init.zeros_(self.weight)
        if self.use_bias:
            with torch.no_grad():
                self.bias.fill_(bias_init)


class SelfAttention(nn.Module):
    """
    Multi-head self-attention over [batch_size1, batch_size2, len, dim] tensors.
    """

    def __init__(self, qkv_dim, c, n_head, out_dim, gating=True, last_bias_fuse=False):
        super(SelfAttention, self).__init__()
        self.qkv_dim = qkv_dim
        self.c = c
        self.n_head = n_head
        self.out_dim = out_dim
        self.gating = gating
        self.last_bias_fuse = last_bias_fuse

        self.scaling = self.c**(-0.5)

        # self.to_qkv = Linear(qkv_dim, 3 * n_head * c, initializer='linear')
        self.to_q = Linear(qkv_dim, n_head * c, initializer='linear', use_bias=False)
        self.to_k = Linear(qkv_dim, n_head * c, initializer='linear', use_bias=False)
        self.to_v = Linear(qkv_dim, n_head * c, initializer='linear', use_bias=False)

        if gating:
            self.gating_bias = nn.parameter.Parameter(data=torch.ones((n_head * c,)))
            self.gating_linear = Linear(qkv_dim, n_head * c, initializer='zeros', use_bias=False)

        self.o_linear = Linear(n_head * c,
                               out_dim,
                               initializer='zeros',
                               use_bias=(not last_bias_fuse))

    def forward(self, in_data, nonbatched_bias=None):
        """
        :param in_data: [batch_size1, batch_size2, len_qkv, qkv_dim]
        :param nonbatched_bias: None or [batch_size1, n_head, len_q, len_kv]
        """

        # qkv = self.to_qkv(in_data).chunk(3, dim=-1)
        # q, k, v = map(lambda t: rearrange(t, 'b1 b2 n (h d) -> b1 b2 h n d', h=self.n_head), qkv)

        q = self.to_q(in_data)
        k = self.to_k(in_data)
        v = self.to_v(in_data)

        # q, k, v = map(lambda t: rearrange(t, 'b1 b2 n (h d) -> b1 b2 h n d', h=self.n_head),
        #               [q, k, v])
        q, k, v = map(lambda t: t.view(t.shape[0], t.shape[1], t.shape[2], self.n_head, -1).permute(0, 1, 3, 2, 4),
                      [q, k, v])

        q = q * self.scaling

        logits = torch.matmul(q, k.transpose(-1, -2))

        if nonbatched_bias is not None:
            logits += nonbatched_bias.unsqueeze(1)
        weights = torch.softmax(logits, dim=-1)
        # weights = softmax(logits)

        weighted_avg = torch.matmul(weights, v)
        # weighted_avg = rearrange(weighted_avg, 'b1 b2 h n d -> b1 b2 n (h d)')
        weighted_avg = weighted_avg.permute(0, 1, 3, 2, 4)
        weighted_avg = weighted_avg.reshape(weighted_avg.shape[0], weighted_avg.shape[1], weighted_avg.shape[2], -1)

        if self.gating:
            gate_values = self.gating_linear(in_data)
            weighted_avg = bias_sigmod_ele(gate_values, self.gating_bias, weighted_avg)

        output = self.o_linear(weighted_avg)
        return output
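A shape sketch for SelfAttention (not part of the commit): the first two axes are treated as batch dimensions, attention runs over the third, and nonbatched_bias is broadcast over the second axis via unsqueeze(1):

import torch

attn = SelfAttention(qkv_dim=64, c=32, n_head=8, out_dim=64, gating=True)
x = torch.randn(2, 4, 10, 64)        # [batch_size1, batch_size2, len, qkv_dim]
bias = torch.randn(2, 8, 10, 10)     # [batch_size1, n_head, len_q, len_kv]
y = attn(x, nonbatched_bias=bias)
print(y.shape)                       # torch.Size([2, 4, 10, 64])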
@@ -0,0 +1,192 @@
import math

import torch
import torch.nn as nn
from torch.nn import LayerNorm

from .kernel import bias_dropout_add, bias_ele_dropout_residual
from .ops import Linear, SelfAttention, Transition


def permute_final_dims(tensor, inds):
    zero_index = -1 * len(inds)
    first_inds = list(range(len(tensor.shape[:zero_index])))
    return tensor.permute(first_inds + [zero_index + i for i in inds])


class TriangleMultiplicationOutgoing(nn.Module):

    def __init__(self, d_pair, p_drop, c=128):
        super(TriangleMultiplicationOutgoing, self).__init__()
        self.d_pair = d_pair
        self.c = c

        self.layernorm1 = LayerNorm(d_pair)
        self.left_projection = Linear(d_pair, c)
        self.right_projection = Linear(d_pair, c)
        self.left_gate = Linear(d_pair, c, initializer='zeros', bias_init=1.)
        self.right_gate = Linear(d_pair, c, initializer='zeros', bias_init=1.)

        self.output_gate = Linear(d_pair, d_pair, initializer='zeros', bias_init=1.)
        self.layernorm2 = LayerNorm(c)
        # project the c-dim triangle update back to d_pair
        self.output_projection = Linear(c, d_pair, initializer='zeros', use_bias=False)
        self.output_bias = nn.parameter.Parameter(data=torch.zeros((d_pair,)), requires_grad=True)

        self.p_drop = p_drop

    def forward(self, Z_raw):
        Z = self.layernorm1(Z_raw)
        left_proj_act = self.left_projection(Z)
        right_proj_act = self.right_projection(Z)

        left_proj_act = left_proj_act * torch.sigmoid(self.left_gate(Z))
        right_proj_act = right_proj_act * torch.sigmoid(self.right_gate(Z))

        g = torch.sigmoid(self.output_gate(Z))
        # p = torch.matmul(
        #     permute_final_dims(left_proj_act, (2, 0, 1)),
        #     permute_final_dims(right_proj_act, (2, 1, 0)),
        # )
        # ab = permute_final_dims(p, (1, 2, 0))

        ab = torch.einsum('bikd,bjkd->bijd', left_proj_act, right_proj_act)
        ab = self.output_projection(self.layernorm2(ab))
        dropout_mask = torch.ones_like(Z[:, 0:1, :, :], device=Z.device, dtype=Z.dtype)
        return bias_ele_dropout_residual(ab,
                                         self.output_bias,
                                         g,
                                         dropout_mask,
                                         Z_raw,
                                         prob=self.p_drop)


class TriangleMultiplicationIncoming(nn.Module):

    def __init__(self, d_pair, p_drop, c=128):
        super(TriangleMultiplicationIncoming, self).__init__()
        self.d_pair = d_pair
        self.c = c

        self.layernorm1 = LayerNorm(d_pair)
        self.left_projection = Linear(d_pair, c)
        self.right_projection = Linear(d_pair, c)
        self.left_gate = Linear(d_pair, c, initializer='zeros', bias_init=1.)
        self.right_gate = Linear(d_pair, c, initializer='zeros', bias_init=1.)

        self.output_gate = Linear(d_pair, d_pair, initializer='zeros', bias_init=1.)
        self.layernorm2 = LayerNorm(c)
        # project the c-dim triangle update back to d_pair
        self.output_projection = Linear(c, d_pair, initializer='zeros', use_bias=False)
        self.output_bias = nn.parameter.Parameter(data=torch.zeros((d_pair,)), requires_grad=True)

        self.p_drop = p_drop

    def forward(self, Z_raw):
        Z = self.layernorm1(Z_raw)
        left_proj_act = self.left_projection(Z)
        right_proj_act = self.right_projection(Z)

        left_proj_act = left_proj_act * torch.sigmoid(self.left_gate(Z))
        right_proj_act = right_proj_act * torch.sigmoid(self.right_gate(Z))

        g = torch.sigmoid(self.output_gate(Z))
        # p = torch.matmul(
        #     permute_final_dims(left_proj_act, (2, 1, 0)),
        #     permute_final_dims(right_proj_act, (2, 0, 1)),
        # )
        # ab = permute_final_dims(p, (1, 2, 0))

        ab = torch.einsum('bkid,bkjd->bijd', left_proj_act, right_proj_act)
        ab = self.output_projection(self.layernorm2(ab))
        dropout_mask = torch.ones_like(Z[:, 0:1, :, :], device=Z.device, dtype=Z.dtype)
        return bias_ele_dropout_residual(ab,
                                         self.output_bias,
                                         g,
                                         dropout_mask,
                                         Z_raw,
                                         prob=self.p_drop)


class TriangleAttentionStartingNode(nn.Module):

    def __init__(self, d_pair, p_drop, c=32, n_head=4):
        super(TriangleAttentionStartingNode, self).__init__()
        self.d_pair = d_pair
        self.c = c
        self.n_head = n_head
        self.p_drop = p_drop

        self.layernorm1 = LayerNorm(d_pair)
        _init_weights = torch.nn.init.normal_(torch.zeros([d_pair, n_head]),
                                              std=1.0 / math.sqrt(d_pair))
        self.linear_b_weights = nn.parameter.Parameter(data=_init_weights)
        self.attention = SelfAttention(qkv_dim=d_pair,
                                       c=c,
                                       n_head=n_head,
                                       out_dim=d_pair,
                                       gating=True,
                                       last_bias_fuse=True)

        self.out_bias = nn.parameter.Parameter(data=torch.zeros((d_pair,)), requires_grad=True)

    def forward(self, Z_raw):
        Z = self.layernorm1(Z_raw)
        b = torch.einsum('bqkc,ch->bhqk', Z, self.linear_b_weights)

        Z = self.attention(Z, b)

        dropout_mask = torch.ones_like(Z[:, 0:1, :, :], device=Z.device, dtype=Z.dtype)
        return bias_dropout_add(Z, self.out_bias, dropout_mask, Z_raw, prob=self.p_drop)


class TriangleAttentionEndingNode(nn.Module):

    def __init__(self, d_pair, p_drop, c=32, n_head=4):
        super(TriangleAttentionEndingNode, self).__init__()
        self.d_pair = d_pair
        self.c = c
        self.n_head = n_head
        self.p_drop = p_drop

        self.layernorm1 = LayerNorm(d_pair)
        _init_weights = torch.nn.init.normal_(torch.zeros([d_pair, n_head]),
                                              std=1.0 / math.sqrt(d_pair))
        self.linear_b_weights = nn.parameter.Parameter(data=_init_weights)
        self.attention = SelfAttention(qkv_dim=d_pair,
                                       c=c,
                                       n_head=n_head,
                                       out_dim=d_pair,
                                       gating=True,
                                       last_bias_fuse=True)

        self.out_bias = nn.parameter.Parameter(data=torch.zeros((d_pair,)), requires_grad=True)

    def forward(self, Z_raw):
        Z = Z_raw.transpose(-2, -3)
        Z = self.layernorm1(Z)
        b = torch.einsum('bqkc,ch->bhqk', Z, self.linear_b_weights)

        Z = self.attention(Z, b)

        Z = Z.transpose(-2, -3)
        dropout_mask = torch.ones_like(Z[:, :, 0:1, :], device=Z.device, dtype=Z.dtype)
        return bias_dropout_add(Z, self.out_bias, dropout_mask, Z_raw, prob=self.p_drop)


class PairStack(nn.Module):

    def __init__(self, d_pair, p_drop=0.25):
        super(PairStack, self).__init__()

        self.TriangleMultiplicationOutgoing = TriangleMultiplicationOutgoing(d_pair, p_drop=p_drop)
        self.TriangleMultiplicationIncoming = TriangleMultiplicationIncoming(d_pair, p_drop=p_drop)
        self.TriangleAttentionStartingNode = TriangleAttentionStartingNode(d_pair, p_drop=p_drop)
        self.TriangleAttentionEndingNode = TriangleAttentionEndingNode(d_pair, p_drop=p_drop)
        self.PairTransition = Transition(d=d_pair)

    def forward(self, pair):
        pair = self.TriangleMultiplicationOutgoing(pair)
        pair = self.TriangleMultiplicationIncoming(pair)
        pair = self.TriangleAttentionStartingNode(pair)
        pair = self.TriangleAttentionEndingNode(pair)
        pair = self.PairTransition(pair)
        return pair
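A shape sketch for PairStack (illustration only): every sub-module maps a pair tensor [batch, n_res, n_res, d_pair] to the same shape, with the outgoing/incoming multiplicative updates contracting over the third residue index k in the einsums 'bikd,bjkd->bijd' and 'bkid,bkjd->bijd':

import torch

pair_stack = PairStack(d_pair=128)
pair = torch.randn(1, 64, 64, 128)   # [batch, n_res, n_res, d_pair]
out = pair_stack(pair)
print(out.shape)                     # torch.Size([1, 64, 64, 128])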