# mirror of https://github.com/hpcaitech/ColossalAI
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch.nn import LayerNorm

from .kernel import bias_dropout_add
from .ops import SelfAttention, Transition


class MSARowAttentionWithPairBias(nn.Module):
    """MSA row-wise gated self-attention with a bias derived from the pair
    representation (AlphaFold 2, Algorithm 7)."""

    def __init__(self, d_node, d_pair, c=32, n_head=8, p_drop=0.15):
        super(MSARowAttentionWithPairBias, self).__init__()
        self.d_node = d_node
        self.d_pair = d_pair
        self.c = c
        self.n_head = n_head
        self.p_drop = p_drop

        self.layernormM = LayerNorm(d_node)
        self.layernormZ = LayerNorm(d_pair)

        # Projects each pair feature vector to one attention-bias logit per head.
        _init_weights = torch.nn.init.normal_(torch.zeros([n_head, d_pair]),
                                              std=1.0 / math.sqrt(d_pair))
        self.linear_b_weights = nn.parameter.Parameter(data=_init_weights, requires_grad=True)

        self.attention = SelfAttention(qkv_dim=d_node,
                                       c=c,
                                       n_head=n_head,
                                       out_dim=d_node,
                                       gating=True,
                                       last_bias_fuse=True)

        # The output bias is kept outside SelfAttention (last_bias_fuse=True) so
        # it can be fused with dropout and the residual add in bias_dropout_add.
        self.out_bias = nn.parameter.Parameter(data=torch.zeros((d_node,)), requires_grad=True)

    def forward(self, M_raw, Z):
        ## Input projections
        M = self.layernormM(M_raw)
        Z = self.layernormZ(Z)

        # Pair bias: one logit per head, with the head axis moved in front of
        # the (query, key) axes.
        b = F.linear(Z, self.linear_b_weights)
        b = b.permute(0, 3, 1, 2)
        # b = rearrange(b, 'b q k h -> b h q k')

        M = self.attention(M, b)

        # Shared dropout: a single mask broadcast over the n_seq axis
        # (ones_like already matches M's device and dtype).
        dropout_mask = torch.ones_like(M[:, 0:1, :, :])

        return bias_dropout_add(M, self.out_bias, dropout_mask, M_raw, prob=self.p_drop)


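# Reference-only sketch, not the shipped kernel: bias_dropout_add imported from
# .kernel is assumed to compute "residual + dropout(x + bias)", with the dropout
# pattern taken from the shared mask. The name and body below are illustrative.
def _bias_dropout_add_reference(x, bias, dropout_mask, residual, prob):
    # Drop (and rescale) entries of the shared mask, gate the biased
    # activations with it, then fold in the residual stream.
    out = (x + bias) * F.dropout(dropout_mask, p=prob, training=True)
    return residual + out

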
class MSAColumnAttention(nn.Module):
    """MSA column-wise gated self-attention (AlphaFold 2, Algorithm 8)."""

    def __init__(self, d_node, c=32, n_head=8):
        super(MSAColumnAttention, self).__init__()
        self.d_node = d_node
        self.c = c
        self.n_head = n_head

        self.layernormM = LayerNorm(d_node)
        self.attention = SelfAttention(qkv_dim=d_node,
                                       c=c,
                                       n_head=n_head,
                                       out_dim=d_node,
                                       gating=True)

    def forward(self, M_raw):
        # Swap the sequence and residue axes so attention runs over MSA columns.
        M = M_raw.transpose(-2, -3)
        M = self.layernormM(M)

        M = self.attention(M)

        # Restore the original [..., n_seq, n_res, d_node] layout.
        M = M.transpose(-2, -3)
        return M_raw + M


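# Shape sketch for the transpose trick above, with hypothetical sizes: column
# attention is row attention applied to the axis-swapped MSA.
#
#   M_raw:                   [batch, n_seq, n_res, d_node]
#   M_raw.transpose(-2, -3): [batch, n_res, n_seq, d_node]  # attend over n_seq
#   M.transpose(-2, -3):     [batch, n_seq, n_res, d_node]  # restore layout

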
class MSAStack(nn.Module):
    """One MSA update block: row attention with pair bias, column attention,
    and a transition applied in sequence."""

    def __init__(self, d_node, d_pair, p_drop=0.15):
        super(MSAStack, self).__init__()

        self.MSARowAttentionWithPairBias = MSARowAttentionWithPairBias(d_node=d_node,
                                                                       d_pair=d_pair,
                                                                       p_drop=p_drop)

        self.MSAColumnAttention = MSAColumnAttention(d_node=d_node)
        self.MSATransition = Transition(d=d_node)

    def forward(self, node, pair):
        node = self.MSARowAttentionWithPairBias(node, pair)
        node = self.MSAColumnAttention(node)
        node = self.MSATransition(node)

        return node


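# A minimal smoke test sketching the expected tensor shapes. It assumes this
# module is imported as part of its package (so the relative imports resolve)
# and uses hypothetical sizes; SelfAttention and Transition must follow the
# signatures used above.
if __name__ == "__main__":
    batch, n_seq, n_res, d_node, d_pair = 1, 8, 16, 64, 32
    node = torch.randn(batch, n_seq, n_res, d_node)    # MSA representation
    pair = torch.randn(batch, n_res, n_res, d_pair)    # pair representation

    stack = MSAStack(d_node=d_node, d_pair=d_pair)
    out = stack(node, pair)
    assert out.shape == node.shape                     # the block is shape-preserving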