mirror of https://github.com/hpcaitech/ColossalAI
init openfold
parent
efe6fe3a33
commit
289f3a45c2
|
@ -0,0 +1,59 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from .msa import MSAStack
|
||||
from .ops import OutProductMean
|
||||
from .triangle import PairStack
|
||||
|
||||
|
||||
def print_memory(init_mem, text=None):
|
||||
now_mem = torch.cuda.memory_allocated() / 1024 ** 2 - init_mem
|
||||
max_mem = torch.cuda.max_memory_allocated() / 1024 ** 2 - init_mem
|
||||
print("%s now:%.2f max:%.2f" % ("" if text is None else text, now_mem, max_mem))
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
|
||||
|
||||
class EvoformerBlock(nn.Module):
|
||||
|
||||
def __init__(self, d_node, d_pair):
|
||||
super(EvoformerBlock, self).__init__()
|
||||
|
||||
self.msa_stack = MSAStack(d_node, d_pair, p_drop=0.15)
|
||||
self.communication = OutProductMean(n_feat=d_node, n_feat_out=d_pair, n_feat_proj=32)
|
||||
self.pair_stack = PairStack(d_pair=d_pair)
|
||||
|
||||
def forward(self, node, pair):
|
||||
node = self.msa_stack(node, pair)
|
||||
pair = pair + self.communication(node)
|
||||
pair = self.pair_stack(pair)
|
||||
return node, pair
|
||||
|
||||
|
||||
class Evoformer(nn.Module):
|
||||
|
||||
def __init__(self, d_node, d_pair):
|
||||
super(Evoformer, self).__init__()
|
||||
|
||||
self.blocks = nn.ModuleList()
|
||||
for _ in range(1):
|
||||
self.blocks.append(EvoformerBlock(d_node, d_pair))
|
||||
|
||||
def forward(self, node, pair):
|
||||
for b in self.blocks:
|
||||
node, pair = b(node, pair)
|
||||
return node, pair
|
||||
|
||||
|
||||
def evoformer_tiny():
|
||||
return Evoformer(d_node=64, d_pair=32)
|
||||
|
||||
|
||||
def evoformer_base():
|
||||
return Evoformer(d_node=256, d_pair=128)
|
||||
|
||||
|
||||
def evoformer_large():
|
||||
return Evoformer(d_node=512, d_pair=256)
|
||||
|
||||
|
||||
__all__ = ['Evoformer', 'evoformer_base', 'evoformer_large']
|
|
@ -0,0 +1,29 @@
|
|||
import math
|
||||
|
||||
import numpy as np
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def glorot_uniform_af(x, gain=1.0):
|
||||
"""
|
||||
initialize tensors the same as xavier_initializer in PyTorch, but the dimensions are different:
|
||||
In PyTorch:
|
||||
[feature_out, feature_in, n_head ...]
|
||||
In Jax:
|
||||
[... n_head, feature_in, feature_out]
|
||||
However, there is a feature in original Alphafold2 code that they use the Jax version initializer to initialize tensors like:
|
||||
[feature_in, n_head, feature_out]
|
||||
|
||||
In this function, we keep this feature to initialize [feature_in, n_head, ..., feature_out] tensors
|
||||
"""
|
||||
fan_in, fan_out = x.shape[-2:]
|
||||
if len(x.shape) > 2:
|
||||
receptive_field_size = np.prod(x.shape[:-2])
|
||||
fan_in *= receptive_field_size
|
||||
fan_out *= receptive_field_size
|
||||
std = gain * math.sqrt(2.0 / float(fan_in + fan_out))
|
||||
dev = math.sqrt(3.0) * std # Calculate uniform bounds from standard deviation
|
||||
|
||||
nn.init.uniform_(x, -dev, dev)
|
||||
|
||||
return x
|
|
@ -0,0 +1,19 @@
|
|||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def bias_sigmod_ele(y, bias, z):
|
||||
return torch.sigmoid(y + bias) * z
|
||||
|
||||
|
||||
def bias_dropout_add(x: torch.Tensor, bias: torch.Tensor, dropmask: torch.Tensor,
|
||||
residual: torch.Tensor, prob: float) -> torch.Tensor:
|
||||
out = (x + bias) * F.dropout(dropmask, p=prob, training=False)
|
||||
out = residual + out
|
||||
return out
|
||||
|
||||
|
||||
def bias_ele_dropout_residual(ab: torch.Tensor, b: torch.Tensor, g: torch.Tensor,
|
||||
dropout_mask: torch.Tensor, Z_raw: torch.Tensor,
|
||||
prob: float) -> torch.Tensor:
|
||||
return Z_raw + F.dropout(dropout_mask, p=prob, training=True) * (g * (ab + b))
|
|
@ -0,0 +1,95 @@
|
|||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
from torch.nn import LayerNorm
|
||||
|
||||
from .kernel import bias_dropout_add
|
||||
from .ops import SelfAttention, Transition
|
||||
|
||||
|
||||
class MSARowAttentionWithPairBias(nn.Module):
|
||||
|
||||
def __init__(self, d_node, d_pair, c=32, n_head=8, p_drop=0.15):
|
||||
super(MSARowAttentionWithPairBias, self).__init__()
|
||||
self.d_node = d_node
|
||||
self.d_pair = d_pair
|
||||
self.c = c
|
||||
self.n_head = n_head
|
||||
self.p_drop = p_drop
|
||||
|
||||
self.layernormM = LayerNorm(d_node)
|
||||
self.layernormZ = LayerNorm(d_pair)
|
||||
|
||||
_init_weights = torch.nn.init.normal_(torch.zeros([n_head, d_pair]),
|
||||
std=1.0 / math.sqrt(d_pair))
|
||||
self.linear_b_weights = nn.parameter.Parameter(data=_init_weights, requires_grad=True)
|
||||
|
||||
self.attention = SelfAttention(qkv_dim=d_node,
|
||||
c=c,
|
||||
n_head=n_head,
|
||||
out_dim=d_node,
|
||||
gating=True,
|
||||
last_bias_fuse=True)
|
||||
|
||||
self.out_bias = nn.parameter.Parameter(data=torch.zeros((d_node,)), requires_grad=True)
|
||||
|
||||
def forward(self, M_raw, Z):
|
||||
## Input projections
|
||||
M = self.layernormM(M_raw)
|
||||
Z = self.layernormZ(Z)
|
||||
b = F.linear(Z, self.linear_b_weights)
|
||||
b = b.permute(0, 3, 1, 2)
|
||||
# b = rearrange(b, 'b q k h -> b h q k')
|
||||
|
||||
M = self.attention(M, b)
|
||||
dropout_mask = torch.ones_like(M[:, 0:1, :, :]).to(M.device).to(M.dtype)
|
||||
|
||||
return bias_dropout_add(M, self.out_bias, dropout_mask, M_raw, prob=self.p_drop)
|
||||
|
||||
|
||||
class MSAColumnAttention(nn.Module):
|
||||
|
||||
def __init__(self, d_node, c=32, n_head=8):
|
||||
super(MSAColumnAttention, self).__init__()
|
||||
self.d_node = d_node
|
||||
self.c = c
|
||||
self.n_head = n_head
|
||||
|
||||
self.layernormM = LayerNorm(d_node)
|
||||
self.attention = SelfAttention(qkv_dim=d_node,
|
||||
c=c,
|
||||
n_head=n_head,
|
||||
out_dim=d_node,
|
||||
gating=True)
|
||||
|
||||
def forward(self, M_raw):
|
||||
M = M_raw.transpose(-2, -3)
|
||||
M = self.layernormM(M)
|
||||
|
||||
M = self.attention(M)
|
||||
|
||||
M = M.transpose(-2, -3)
|
||||
return M_raw + M
|
||||
|
||||
|
||||
class MSAStack(nn.Module):
|
||||
|
||||
def __init__(self, d_node, d_pair, p_drop=0.15):
|
||||
super(MSAStack, self).__init__()
|
||||
|
||||
self.MSARowAttentionWithPairBias = MSARowAttentionWithPairBias(d_node=d_node,
|
||||
d_pair=d_pair,
|
||||
p_drop=p_drop)
|
||||
|
||||
self.MSAColumnAttention = MSAColumnAttention(d_node=d_node)
|
||||
self.MSATransition = Transition(d=d_node)
|
||||
|
||||
def forward(self, node, pair):
|
||||
node = self.MSARowAttentionWithPairBias(node, pair)
|
||||
node = self.MSAColumnAttention(node)
|
||||
node = self.MSATransition(node)
|
||||
|
||||
return node
|
|
@ -0,0 +1,176 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
from torch.nn import LayerNorm
|
||||
|
||||
from .initializer import glorot_uniform_af
|
||||
from .kernel import bias_sigmod_ele
|
||||
|
||||
|
||||
class DropoutRowwise(nn.Module):
|
||||
|
||||
def __init__(self, p):
|
||||
super(DropoutRowwise, self).__init__()
|
||||
self.p = p
|
||||
self.dropout = nn.Dropout(p=p)
|
||||
|
||||
def forward(self, x):
|
||||
dropout_mask = torch.ones_like(x[:, 0:1, :, :])
|
||||
dropout_mask = self.dropout(dropout_mask)
|
||||
return dropout_mask * x
|
||||
|
||||
|
||||
class DropoutColumnwise(nn.Module):
|
||||
|
||||
def __init__(self, p):
|
||||
super(DropoutColumnwise, self).__init__()
|
||||
self.p = p
|
||||
self.dropout = nn.Dropout(p=p)
|
||||
|
||||
def forward(self, x):
|
||||
dropout_mask = torch.ones_like(x[:, :, 0:1, :])
|
||||
dropout_mask = self.dropout(dropout_mask)
|
||||
return dropout_mask * x
|
||||
|
||||
|
||||
class Transition(nn.Module):
|
||||
|
||||
def __init__(self, d, n=4):
|
||||
super(Transition, self).__init__()
|
||||
self.norm = LayerNorm(d)
|
||||
self.linear1 = Linear(d, n * d, initializer='relu')
|
||||
self.linear2 = Linear(n * d, d, initializer='zeros')
|
||||
|
||||
def forward(self, src):
|
||||
x = self.norm(src)
|
||||
x = self.linear2(F.relu(self.linear1(x)))
|
||||
return src + x
|
||||
|
||||
|
||||
class OutProductMean(nn.Module):
|
||||
|
||||
def __init__(self, n_feat=64, n_feat_out=128, n_feat_proj=32):
|
||||
super(OutProductMean, self).__init__()
|
||||
|
||||
self.layernormM = LayerNorm(n_feat)
|
||||
self.linear_a = Linear(n_feat, n_feat_proj)
|
||||
self.linear_b = Linear(n_feat, n_feat_proj)
|
||||
|
||||
self.o_linear = Linear(n_feat_proj * n_feat_proj,
|
||||
n_feat_out,
|
||||
initializer='zero',
|
||||
use_bias=True)
|
||||
|
||||
def forward(self, M):
|
||||
M = self.layernormM(M)
|
||||
left_act = self.linear_a(M)
|
||||
right_act = self.linear_b(M)
|
||||
|
||||
O = torch.einsum('bsid,bsje->bijde', left_act, right_act).contiguous()
|
||||
# O = rearrange(O, 'b i j d e -> b i j (d e)')
|
||||
O = O.reshape(O.shape[0], O.shape[1], O.shape[2], -1)
|
||||
Z = self.o_linear(O)
|
||||
|
||||
return Z
|
||||
|
||||
|
||||
class Linear(nn.Linear):
|
||||
"""
|
||||
A Linear layer with built-in nonstandard initializations. Called just
|
||||
like torch.nn.Linear.
|
||||
Implements the initializers in 1.11.4, plus some additional ones found
|
||||
in the code.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
feature_in: int,
|
||||
feature_out: int,
|
||||
initializer: str = 'linear',
|
||||
use_bias: bool = True,
|
||||
bias_init: float = 0.,
|
||||
):
|
||||
super(Linear, self).__init__(feature_in, feature_out, bias=use_bias)
|
||||
|
||||
self.use_bias = use_bias
|
||||
if initializer == 'linear':
|
||||
glorot_uniform_af(self.weight, gain=1.0)
|
||||
elif initializer == 'relu':
|
||||
glorot_uniform_af(self.weight, gain=2.0)
|
||||
elif initializer == 'zeros':
|
||||
nn.init.zeros_(self.weight)
|
||||
if self.use_bias:
|
||||
with torch.no_grad():
|
||||
self.bias.fill_(bias_init)
|
||||
|
||||
|
||||
class SelfAttention(nn.Module):
|
||||
"""
|
||||
Multi-Head SelfAttention dealing with [batch_size1, batch_size2, len, dim] tensors
|
||||
"""
|
||||
|
||||
def __init__(self, qkv_dim, c, n_head, out_dim, gating=True, last_bias_fuse=False):
|
||||
super(SelfAttention, self).__init__()
|
||||
self.qkv_dim = qkv_dim
|
||||
self.c = c
|
||||
self.n_head = n_head
|
||||
self.out_dim = out_dim
|
||||
self.gating = gating
|
||||
self.last_bias_fuse = last_bias_fuse
|
||||
|
||||
self.scaling = self.c**(-0.5)
|
||||
|
||||
# self.to_qkv = Linear(qkv_dim, 3 * n_head * c, initializer='linear')
|
||||
self.to_q = Linear(qkv_dim, n_head * c, initializer='linear', use_bias=False)
|
||||
self.to_k = Linear(qkv_dim, n_head * c, initializer='linear', use_bias=False)
|
||||
self.to_v = Linear(qkv_dim, n_head * c, initializer='linear', use_bias=False)
|
||||
|
||||
if gating:
|
||||
self.gating_bias = nn.parameter.Parameter(data=torch.ones((n_head * c,)))
|
||||
self.gating_linear = Linear(qkv_dim, n_head * c, initializer='zero', use_bias=False)
|
||||
|
||||
self.o_linear = Linear(n_head * c,
|
||||
out_dim,
|
||||
initializer='zero',
|
||||
use_bias=(not last_bias_fuse))
|
||||
|
||||
def forward(self, in_data, nonbatched_bias=None):
|
||||
"""
|
||||
:param in_data: [batch_size1, batch_size2, len_qkv, qkv_dim]
|
||||
:param bias: None or [batch_size1, batch_size2, n_head, len_q, len_kv]
|
||||
:param nonbatched_bias: None or [batch_size1, n_head, len_q, len_kv]
|
||||
"""
|
||||
|
||||
# qkv = self.to_qkv(in_data).chunk(3, dim=-1)
|
||||
# q, k, v = map(lambda t: rearrange(t, 'b1 b2 n (h d) -> b1 b2 h n d', h=self.n_head), qkv)
|
||||
|
||||
q = self.to_q(in_data)
|
||||
k = self.to_k(in_data)
|
||||
v = self.to_v(in_data)
|
||||
|
||||
# q, k, v = map(lambda t: rearrange(t, 'b1 b2 n (h d) -> b1 b2 h n d', h=self.n_head),
|
||||
# [q, k, v])
|
||||
q, k, v = map(lambda t: t.view(t.shape[0], t.shape[1], t.shape[2], self.n_head, -1).permute(0, 1, 3, 2, 4),
|
||||
[q, k, v])
|
||||
|
||||
q = q * self.scaling
|
||||
|
||||
logits = torch.matmul(q, k.transpose(-1, -2))
|
||||
|
||||
if nonbatched_bias is not None:
|
||||
logits += nonbatched_bias.unsqueeze(1)
|
||||
weights = torch.softmax(logits, dim=-1)
|
||||
# weights = softmax(logits)
|
||||
|
||||
weighted_avg = torch.matmul(weights, v)
|
||||
# weighted_avg = rearrange(weighted_avg, 'b1 b2 h n d -> b1 b2 n (h d)')
|
||||
weighted_avg = weighted_avg.permute(0, 1, 3, 2, 4)
|
||||
weighted_avg = weighted_avg.reshape(weighted_avg.shape[0], weighted_avg.shape[1], weighted_avg.shape[2], -1)
|
||||
|
||||
if self.gating:
|
||||
gate_values = self.gating_linear(in_data)
|
||||
weighted_avg = bias_sigmod_ele(gate_values, self.gating_bias, weighted_avg)
|
||||
|
||||
output = self.o_linear(weighted_avg)
|
||||
return output
|
|
@ -0,0 +1,192 @@
|
|||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn import LayerNorm
|
||||
|
||||
from .kernel import bias_dropout_add, bias_ele_dropout_residual
|
||||
from .ops import Linear, SelfAttention, Transition
|
||||
|
||||
|
||||
def permute_final_dims(tensor, inds):
|
||||
zero_index = -1 * len(inds)
|
||||
first_inds = list(range(len(tensor.shape[:zero_index])))
|
||||
return tensor.permute(first_inds + [zero_index + i for i in inds])
|
||||
|
||||
|
||||
class TriangleMultiplicationOutgoing(nn.Module):
|
||||
|
||||
def __init__(self, d_pair, p_drop, c=128):
|
||||
super(TriangleMultiplicationOutgoing, self).__init__()
|
||||
self.d_pair = d_pair
|
||||
self.c = c
|
||||
|
||||
self.layernorm1 = LayerNorm(d_pair)
|
||||
self.left_projection = Linear(d_pair, c)
|
||||
self.right_projection = Linear(d_pair, c)
|
||||
self.left_gate = Linear(d_pair, c, initializer='zeros', bias_init=1.)
|
||||
self.right_gate = Linear(d_pair, c, initializer='zeros', bias_init=1.)
|
||||
|
||||
self.output_gate = Linear(d_pair, d_pair, initializer='zeros', bias_init=1.)
|
||||
self.layernorm2 = LayerNorm(c)
|
||||
self.output_projection = Linear(d_pair, d_pair, initializer='zeros', use_bias=False)
|
||||
self.output_bias = nn.parameter.Parameter(data=torch.zeros((d_pair,)), requires_grad=True)
|
||||
|
||||
self.p_drop = p_drop
|
||||
|
||||
def forward(self, Z_raw):
|
||||
Z = self.layernorm1(Z_raw)
|
||||
left_proj_act = self.left_projection(Z)
|
||||
right_proj_act = self.right_projection(Z)
|
||||
|
||||
left_proj_act = left_proj_act * torch.sigmoid(self.left_gate(Z))
|
||||
right_proj_act = right_proj_act * torch.sigmoid(self.right_gate(Z))
|
||||
|
||||
g = torch.sigmoid(self.output_gate(Z))
|
||||
# p = torch.matmul(
|
||||
# permute_final_dims(left_proj_act, (2, 0, 1)),
|
||||
# permute_final_dims(right_proj_act, (2, 1, 0)),
|
||||
# )
|
||||
# ab = permute_final_dims(p, (1, 2, 0))
|
||||
|
||||
ab = torch.einsum('bikd,bjkd->bijd', left_proj_act, right_proj_act)
|
||||
ab = self.output_projection(self.layernorm2(ab))
|
||||
dropout_mask = torch.ones_like(Z[:, 0:1, :, :]).to(Z.device).to(Z.dtype)
|
||||
return bias_ele_dropout_residual(ab,
|
||||
self.output_bias,
|
||||
g,
|
||||
dropout_mask,
|
||||
Z_raw,
|
||||
prob=self.p_drop)
|
||||
|
||||
|
||||
class TriangleMultiplicationIncoming(nn.Module):
|
||||
|
||||
def __init__(self, d_pair, p_drop, c=128):
|
||||
super(TriangleMultiplicationIncoming, self).__init__()
|
||||
self.d_pair = d_pair
|
||||
self.c = c
|
||||
|
||||
self.layernorm1 = LayerNorm(d_pair)
|
||||
self.left_projection = Linear(d_pair, c)
|
||||
self.right_projection = Linear(d_pair, c)
|
||||
self.left_gate = Linear(d_pair, c, initializer='zeros', bias_init=1.)
|
||||
self.right_gate = Linear(d_pair, c, initializer='zeros', bias_init=1.)
|
||||
|
||||
self.output_gate = Linear(d_pair, d_pair, initializer='zeros', bias_init=1.)
|
||||
self.layernorm2 = LayerNorm(c)
|
||||
self.output_projection = Linear(d_pair, d_pair, initializer='zeros', use_bias=False)
|
||||
self.output_bias = nn.parameter.Parameter(data=torch.zeros((d_pair,)), requires_grad=True)
|
||||
|
||||
self.p_drop = p_drop
|
||||
|
||||
def forward(self, Z_raw):
|
||||
Z = self.layernorm1(Z_raw)
|
||||
left_proj_act = self.left_projection(Z)
|
||||
right_proj_act = self.right_projection(Z)
|
||||
|
||||
left_proj_act = left_proj_act * torch.sigmoid(self.left_gate(Z))
|
||||
right_proj_act = right_proj_act * torch.sigmoid(self.right_gate(Z))
|
||||
|
||||
g = torch.sigmoid(self.output_gate(Z))
|
||||
# p = torch.matmul(
|
||||
# permute_final_dims(left_proj_act, (2, 1, 0)),
|
||||
# permute_final_dims(right_proj_act, (2, 0, 1)),
|
||||
# )
|
||||
# ab = permute_final_dims(p, (1, 2, 0))
|
||||
|
||||
ab = torch.einsum('bkid,bkjd->bijd', left_proj_act, right_proj_act)
|
||||
ab = self.output_projection(self.layernorm2(ab))
|
||||
dropout_mask = torch.ones_like(Z[:, 0:1, :, :]).to(Z.device).to(Z.dtype)
|
||||
return bias_ele_dropout_residual(ab,
|
||||
self.output_bias,
|
||||
g,
|
||||
dropout_mask,
|
||||
Z_raw,
|
||||
prob=self.p_drop)
|
||||
|
||||
|
||||
class TriangleAttentionStartingNode(nn.Module):
|
||||
|
||||
def __init__(self, d_pair, p_drop, c=32, n_head=4):
|
||||
super(TriangleAttentionStartingNode, self).__init__()
|
||||
self.d_pair = d_pair
|
||||
self.c = c
|
||||
self.n_head = n_head
|
||||
self.p_drop = p_drop
|
||||
|
||||
self.layernorm1 = LayerNorm(d_pair)
|
||||
_init_weights = torch.nn.init.normal_(torch.zeros([d_pair, n_head]),
|
||||
std=1.0 / math.sqrt(d_pair))
|
||||
self.linear_b_weights = nn.parameter.Parameter(data=_init_weights)
|
||||
self.attention = SelfAttention(qkv_dim=d_pair,
|
||||
c=c,
|
||||
n_head=n_head,
|
||||
out_dim=d_pair,
|
||||
gating=True,
|
||||
last_bias_fuse=True)
|
||||
|
||||
self.out_bias = nn.parameter.Parameter(data=torch.zeros((d_pair,)), requires_grad=True)
|
||||
|
||||
def forward(self, Z_raw):
|
||||
Z = self.layernorm1(Z_raw)
|
||||
b = torch.einsum('bqkc,ch->bhqk', Z, self.linear_b_weights)
|
||||
|
||||
Z = self.attention(Z, b)
|
||||
|
||||
dropout_mask = torch.ones_like(Z[:, 0:1, :, :]).to(Z.device).to(Z.dtype)
|
||||
return bias_dropout_add(Z, self.out_bias, dropout_mask, Z_raw, prob=self.p_drop)
|
||||
|
||||
|
||||
class TriangleAttentionEndingNode(nn.Module):
|
||||
|
||||
def __init__(self, d_pair, p_drop, c=32, n_head=4):
|
||||
super(TriangleAttentionEndingNode, self).__init__()
|
||||
self.d_pair = d_pair
|
||||
self.c = c
|
||||
self.n_head = n_head
|
||||
self.p_drop = p_drop
|
||||
|
||||
self.layernorm1 = LayerNorm(d_pair)
|
||||
_init_weights = torch.nn.init.normal_(torch.zeros([d_pair, n_head]),
|
||||
std=1.0 / math.sqrt(d_pair))
|
||||
self.linear_b_weights = nn.parameter.Parameter(data=_init_weights)
|
||||
self.attention = SelfAttention(qkv_dim=d_pair,
|
||||
c=c,
|
||||
n_head=n_head,
|
||||
out_dim=d_pair,
|
||||
gating=True,
|
||||
last_bias_fuse=True)
|
||||
|
||||
self.out_bias = nn.parameter.Parameter(data=torch.zeros((d_pair,)), requires_grad=True)
|
||||
|
||||
def forward(self, Z_raw):
|
||||
Z = Z_raw.transpose(-2, -3)
|
||||
Z = self.layernorm1(Z)
|
||||
b = torch.einsum('bqkc,ch->bhqk', Z, self.linear_b_weights)
|
||||
|
||||
Z = self.attention(Z, b)
|
||||
|
||||
Z = Z.transpose(-2, -3)
|
||||
dropout_mask = torch.ones_like(Z[:, :, 0:1, :]).to(Z.device).to(Z.dtype)
|
||||
return bias_dropout_add(Z, self.out_bias, dropout_mask, Z_raw, prob=self.p_drop)
|
||||
|
||||
|
||||
class PairStack(nn.Module):
|
||||
|
||||
def __init__(self, d_pair, p_drop=0.25):
|
||||
super(PairStack, self).__init__()
|
||||
|
||||
self.TriangleMultiplicationOutgoing = TriangleMultiplicationOutgoing(d_pair, p_drop=p_drop)
|
||||
self.TriangleMultiplicationIncoming = TriangleMultiplicationIncoming(d_pair, p_drop=p_drop)
|
||||
self.TriangleAttentionStartingNode = TriangleAttentionStartingNode(d_pair, p_drop=p_drop)
|
||||
self.TriangleAttentionEndingNode = TriangleAttentionEndingNode(d_pair, p_drop=p_drop)
|
||||
self.PairTransition = Transition(d=d_pair)
|
||||
|
||||
def forward(self, pair):
|
||||
pair = self.TriangleMultiplicationOutgoing(pair)
|
||||
pair = self.TriangleMultiplicationIncoming(pair)
|
||||
pair = self.TriangleAttentionStartingNode(pair)
|
||||
pair = self.TriangleAttentionEndingNode(pair)
|
||||
pair = self.PairTransition(pair)
|
||||
return pair
|
Loading…
Reference in New Issue