From 86f2a3147415f2afe53019cd7b9d9414de1510e9 Mon Sep 17 00:00:00 2001 From: oahzxl Date: Wed, 2 Nov 2022 15:12:08 +0800 Subject: [PATCH] add evoformer --- evoformer/evoformer.py | 47 ++++++++++ evoformer/initializer.py | 29 ++++++ evoformer/kernel.py | 19 ++++ evoformer/msa.py | 95 +++++++++++++++++++ evoformer/ops.py | 176 +++++++++++++++++++++++++++++++++++ evoformer/triangle.py | 192 +++++++++++++++++++++++++++++++++++++++ 6 files changed, 558 insertions(+) create mode 100644 evoformer/evoformer.py create mode 100755 evoformer/initializer.py create mode 100644 evoformer/kernel.py create mode 100644 evoformer/msa.py create mode 100755 evoformer/ops.py create mode 100644 evoformer/triangle.py diff --git a/evoformer/evoformer.py b/evoformer/evoformer.py new file mode 100644 index 000000000..ef3df2769 --- /dev/null +++ b/evoformer/evoformer.py @@ -0,0 +1,47 @@ +import torch +import torch.nn as nn + +from .msa import MSAStack +from .ops import OutProductMean +from .triangle import PairStack + + +class EvoformerBlock(nn.Module): + + def __init__(self, d_node, d_pair): + super(EvoformerBlock, self).__init__() + + self.msa_stack = MSAStack(d_node, d_pair, p_drop=0.15) + self.communication = OutProductMean(n_feat=d_node, n_feat_out=d_pair, n_feat_proj=32) + self.pair_stack = PairStack(d_pair=d_pair) + + def forward(self, node, pair): + node = node + self.msa_stack(node, pair) + pair = pair + self.communication(node) + pair = pair + self.pair_stack(pair) + return node, pair + + +class Evoformer(nn.Module): + + def __init__(self, d_node, d_pair): + super(Evoformer, self).__init__() + + self.blocks = nn.ModuleList() + for _ in range(3): + self.blocks.append(EvoformerBlock(d_node, d_pair)) + + def forward(self, node, pair): + for b in self.blocks: + node, pair = b(node, pair) + return node, pair + +def evoformer_base(): + return Evoformer(d_node=256, d_pair=128) + + +def evoformer_large(): + return Evoformer(d_node=512, d_pair=256) + + +__all__ = ['Evoformer', 'evoformer_base', 'evoformer_large'] diff --git a/evoformer/initializer.py b/evoformer/initializer.py new file mode 100755 index 000000000..c6ce0659e --- /dev/null +++ b/evoformer/initializer.py @@ -0,0 +1,29 @@ +import math + +import numpy as np +import torch.nn as nn + + +def glorot_uniform_af(x, gain=1.0): + """ + initialize tensors the same as xavier_initializer in PyTorch, but the dimensions are different: + In PyTorch: + [feature_out, feature_in, n_head ...] + In Jax: + [... 
n_head, feature_in, feature_out]
+    However, the original AlphaFold2 code uses the Jax-style initializer on tensors shaped like:
+    [feature_in, n_head, feature_out]
+
+    This function keeps that convention and initializes [feature_in, n_head, ..., feature_out] tensors accordingly.
+    """
+    fan_in, fan_out = x.shape[-2:]
+    if len(x.shape) > 2:
+        receptive_field_size = np.prod(x.shape[:-2])
+        fan_in *= receptive_field_size
+        fan_out *= receptive_field_size
+    std = gain * math.sqrt(2.0 / float(fan_in + fan_out))
+    dev = math.sqrt(3.0) * std    # Calculate uniform bounds from standard deviation
+
+    nn.init.uniform_(x, -dev, dev)
+
+    return x
diff --git a/evoformer/kernel.py b/evoformer/kernel.py
new file mode 100644
index 000000000..2655901a2
--- /dev/null
+++ b/evoformer/kernel.py
@@ -0,0 +1,19 @@
+import torch
+import torch.nn.functional as F
+
+
+def bias_sigmod_ele(y, bias, z):
+    return torch.sigmoid(y + bias) * z
+
+
+def bias_dropout_add(x: torch.Tensor, bias: torch.Tensor, dropmask: torch.Tensor,
+                     residual: torch.Tensor, prob: float) -> torch.Tensor:
+    out = (x + bias) * F.dropout(dropmask, p=prob, training=True)    # dropmask is an all-ones mask, so dropout is shared along its broadcast dims
+    out = residual + out
+    return out
+
+
+def bias_ele_dropout_residual(ab: torch.Tensor, b: torch.Tensor, g: torch.Tensor,
+                              dropout_mask: torch.Tensor, Z_raw: torch.Tensor,
+                              prob: float) -> torch.Tensor:
+    return Z_raw + F.dropout(dropout_mask, p=prob, training=True) * (g * (ab + b))
\ No newline at end of file
diff --git a/evoformer/msa.py b/evoformer/msa.py
new file mode 100644
index 000000000..ccefa38c4
--- /dev/null
+++ b/evoformer/msa.py
@@ -0,0 +1,95 @@
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange
+from torch.nn import LayerNorm
+
+from .kernel import bias_dropout_add
+from .ops import SelfAttention, Transition
+
+
+class MSARowAttentionWithPairBias(nn.Module):
+
+    def __init__(self, d_node, d_pair, c=32, n_head=8, p_drop=0.15):
+        super(MSARowAttentionWithPairBias, self).__init__()
+        self.d_node = d_node
+        self.d_pair = d_pair
+        self.c = c
+        self.n_head = n_head
+        self.p_drop = p_drop
+
+        self.layernormM = LayerNorm(d_node)
+        self.layernormZ = LayerNorm(d_pair)
+
+        _init_weights = torch.nn.init.normal_(torch.zeros([n_head, d_pair]),
+                                              std=1.0 / math.sqrt(d_pair))
+        self.linear_b_weights = nn.parameter.Parameter(data=_init_weights, requires_grad=True)
+
+        self.attention = SelfAttention(qkv_dim=d_node,
+                                       c=c,
+                                       n_head=n_head,
+                                       out_dim=d_node,
+                                       gating=True,
+                                       last_bias_fuse=True)
+
+        self.out_bias = nn.parameter.Parameter(data=torch.zeros((d_node,)), requires_grad=True)
+
+    def forward(self, M_raw, Z):
+        ## Input projections
+        M = self.layernormM(M_raw)
+        Z = self.layernormZ(Z)
+        b = F.linear(Z, self.linear_b_weights)
+        b = b.permute(0, 3, 1, 2)
+        # b = rearrange(b, 'b q k h -> b h q k')
+
+        M = self.attention(M, b)
+        dropout_mask = torch.ones_like(M[:, 0:1, :, :], device=M.device, dtype=M.dtype)
+
+        return bias_dropout_add(M, self.out_bias, dropout_mask, M_raw, prob=self.p_drop)
+
+
+class MSAColumnAttention(nn.Module):
+
+    def __init__(self, d_node, c=32, n_head=8):
+        super(MSAColumnAttention, self).__init__()
+        self.d_node = d_node
+        self.c = c
+        self.n_head = n_head
+
+        self.layernormM = LayerNorm(d_node)
+        self.attention = SelfAttention(qkv_dim=d_node,
+                                       c=c,
+                                       n_head=n_head,
+                                       out_dim=d_node,
+                                       gating=True)
+
+    def forward(self, M_raw):
+        M = M_raw.transpose(-2, -3)
+        M = self.layernormM(M)
+
+        M = self.attention(M)
+
+        M = M.transpose(-2, -3)
+        return M_raw + M
+
+
+class 
MSAStack(nn.Module): + + def __init__(self, d_node, d_pair, p_drop=0.15): + super(MSAStack, self).__init__() + + self.MSARowAttentionWithPairBias = MSARowAttentionWithPairBias(d_node=d_node, + d_pair=d_pair, + p_drop=p_drop) + + self.MSAColumnAttention = MSAColumnAttention(d_node=d_node) + self.MSATransition = Transition(d=d_node) + + def forward(self, node, pair): + node = self.MSARowAttentionWithPairBias(node, pair) + node = self.MSAColumnAttention(node) + node = self.MSATransition(node) + + return node diff --git a/evoformer/ops.py b/evoformer/ops.py new file mode 100755 index 000000000..ddbba441d --- /dev/null +++ b/evoformer/ops.py @@ -0,0 +1,176 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from torch.nn import LayerNorm + +from .initializer import glorot_uniform_af +from .kernel import bias_sigmod_ele + + +class DropoutRowwise(nn.Module): + + def __init__(self, p): + super(DropoutRowwise, self).__init__() + self.p = p + self.dropout = nn.Dropout(p=p) + + def forward(self, x): + dropout_mask = torch.ones_like(x[:, 0:1, :, :]) + dropout_mask = self.dropout(dropout_mask) + return dropout_mask * x + + +class DropoutColumnwise(nn.Module): + + def __init__(self, p): + super(DropoutColumnwise, self).__init__() + self.p = p + self.dropout = nn.Dropout(p=p) + + def forward(self, x): + dropout_mask = torch.ones_like(x[:, :, 0:1, :]) + dropout_mask = self.dropout(dropout_mask) + return dropout_mask * x + + +class Transition(nn.Module): + + def __init__(self, d, n=4): + super(Transition, self).__init__() + self.norm = LayerNorm(d) + self.linear1 = Linear(d, n * d, initializer='relu') + self.linear2 = Linear(n * d, d, initializer='zeros') + + def forward(self, src): + x = self.norm(src) + x = self.linear2(F.relu(self.linear1(x))) + return src + x + + +class OutProductMean(nn.Module): + + def __init__(self, n_feat=64, n_feat_out=128, n_feat_proj=32): + super(OutProductMean, self).__init__() + + self.layernormM = LayerNorm(n_feat) + self.linear_a = Linear(n_feat, n_feat_proj) + self.linear_b = Linear(n_feat, n_feat_proj) + + self.o_linear = Linear(n_feat_proj * n_feat_proj, + n_feat_out, + initializer='zero', + use_bias=True) + + def forward(self, M): + M = self.layernormM(M) + left_act = self.linear_a(M) + right_act = self.linear_b(M) + + O = torch.einsum('bsid,bsje->bijde', left_act, right_act).contiguous() + # O = rearrange(O, 'b i j d e -> b i j (d e)') + O = O.reshape(O.shape[0], O.shape[1], O.shape[2], -1) + Z = self.o_linear(O) + + return Z + + +class Linear(nn.Linear): + """ + A Linear layer with built-in nonstandard initializations. Called just + like torch.nn.Linear. + Implements the initializers in 1.11.4, plus some additional ones found + in the code. 
+    """
+
+    def __init__(
+        self,
+        feature_in: int,
+        feature_out: int,
+        initializer: str = 'linear',
+        use_bias: bool = True,
+        bias_init: float = 0.,
+    ):
+        super(Linear, self).__init__(feature_in, feature_out, bias=use_bias)
+
+        self.use_bias = use_bias
+        if initializer == 'linear':
+            glorot_uniform_af(self.weight, gain=1.0)
+        elif initializer == 'relu':
+            glorot_uniform_af(self.weight, gain=2.0)
+        elif initializer in ('zeros', 'zero'):    # accept both spellings used in this file
+            nn.init.zeros_(self.weight)
+        if self.use_bias:
+            with torch.no_grad():
+                self.bias.fill_(bias_init)
+
+
+class SelfAttention(nn.Module):
+    """
+    Multi-Head SelfAttention dealing with [batch_size1, batch_size2, len, dim] tensors
+    """
+
+    def __init__(self, qkv_dim, c, n_head, out_dim, gating=True, last_bias_fuse=False):
+        super(SelfAttention, self).__init__()
+        self.qkv_dim = qkv_dim
+        self.c = c
+        self.n_head = n_head
+        self.out_dim = out_dim
+        self.gating = gating
+        self.last_bias_fuse = last_bias_fuse
+
+        self.scaling = self.c**(-0.5)
+
+        # self.to_qkv = Linear(qkv_dim, 3 * n_head * c, initializer='linear')
+        self.to_q = Linear(qkv_dim, n_head * c, initializer='linear', use_bias=False)
+        self.to_k = Linear(qkv_dim, n_head * c, initializer='linear', use_bias=False)
+        self.to_v = Linear(qkv_dim, n_head * c, initializer='linear', use_bias=False)
+
+        if gating:
+            self.gating_bias = nn.parameter.Parameter(data=torch.ones((n_head * c,)))
+            self.gating_linear = Linear(qkv_dim, n_head * c, initializer='zero', use_bias=False)
+
+        self.o_linear = Linear(n_head * c,
+                               out_dim,
+                               initializer='zero',
+                               use_bias=(not last_bias_fuse))
+
+    def forward(self, in_data, nonbatched_bias=None):
+        """
+        :param in_data: [batch_size1, batch_size2, len_qkv, qkv_dim]
+        :param nonbatched_bias: None or [batch_size1, n_head, len_q, len_kv], added to the attention logits
+        :return: [batch_size1, batch_size2, len_qkv, out_dim]
+        """
+
+        # qkv = self.to_qkv(in_data).chunk(3, dim=-1)
+        # q, k, v = map(lambda t: rearrange(t, 'b1 b2 n (h d) -> b1 b2 h n d', h=self.n_head), qkv)
+
+        q = self.to_q(in_data)
+        k = self.to_k(in_data)
+        v = self.to_v(in_data)
+
+        # q, k, v = map(lambda t: rearrange(t, 'b1 b2 n (h d) -> b1 b2 h n d', h=self.n_head),
+        #               [q, k, v])
+        q, k, v = map(lambda t: t.view(t.shape[0], t.shape[1], t.shape[2], self.n_head, -1).permute(0, 1, 3, 2, 4),
+                      [q, k, v])
+
+        q = q * self.scaling    # scale queries by 1/sqrt(c)
+
+        logits = torch.matmul(q, k.transpose(-1, -2))
+
+        if nonbatched_bias is not None:
+            logits += nonbatched_bias.unsqueeze(1)    # broadcast the bias over the second batch dim
+        weights = torch.softmax(logits, dim=-1)
+        # weights = softmax(logits)
+
+        weighted_avg = torch.matmul(weights, v)
+        # weighted_avg = rearrange(weighted_avg, 'b1 b2 h n d -> b1 b2 n (h d)')
+        weighted_avg = weighted_avg.permute(0, 1, 3, 2, 4)
+        weighted_avg = weighted_avg.reshape(weighted_avg.shape[0], weighted_avg.shape[1], weighted_avg.shape[2], -1)
+
+        if self.gating:
+            gate_values = self.gating_linear(in_data)
+            weighted_avg = bias_sigmod_ele(gate_values, self.gating_bias, weighted_avg)
+
+        output = self.o_linear(weighted_avg)
+        return output
diff --git a/evoformer/triangle.py b/evoformer/triangle.py
new file mode 100644
index 000000000..7db0482f5
--- /dev/null
+++ b/evoformer/triangle.py
@@ -0,0 +1,192 @@
+import math
+
+import torch
+import torch.nn as nn
+from torch.nn import LayerNorm
+
+from .kernel import bias_dropout_add, bias_ele_dropout_residual
+from .ops import Linear, SelfAttention, Transition
+
+
+def permute_final_dims(tensor, inds):
+    zero_index = -1 * len(inds)
+    first_inds = list(range(len(tensor.shape[:zero_index])))
+    return tensor.permute(first_inds + 
[zero_index + i for i in inds])
+
+
+class TriangleMultiplicationOutgoing(nn.Module):
+
+    def __init__(self, d_pair, p_drop, c=128):
+        super(TriangleMultiplicationOutgoing, self).__init__()
+        self.d_pair = d_pair
+        self.c = c
+
+        self.layernorm1 = LayerNorm(d_pair)
+        self.left_projection = Linear(d_pair, c)
+        self.right_projection = Linear(d_pair, c)
+        self.left_gate = Linear(d_pair, c, initializer='zeros', bias_init=1.)
+        self.right_gate = Linear(d_pair, c, initializer='zeros', bias_init=1.)
+
+        self.output_gate = Linear(d_pair, d_pair, initializer='zeros', bias_init=1.)
+        self.layernorm2 = LayerNorm(c)
+        self.output_projection = Linear(c, d_pair, initializer='zeros', use_bias=False)    # project the update back from c to d_pair
+        self.output_bias = nn.parameter.Parameter(data=torch.zeros((d_pair,)), requires_grad=True)
+
+        self.p_drop = p_drop
+
+    def forward(self, Z_raw):
+        Z = self.layernorm1(Z_raw)
+        left_proj_act = self.left_projection(Z)
+        right_proj_act = self.right_projection(Z)
+
+        left_proj_act = left_proj_act * torch.sigmoid(self.left_gate(Z))
+        right_proj_act = right_proj_act * torch.sigmoid(self.right_gate(Z))
+
+        g = torch.sigmoid(self.output_gate(Z))
+        # p = torch.matmul(
+        #     permute_final_dims(left_proj_act, (2, 0, 1)),
+        #     permute_final_dims(right_proj_act, (2, 1, 0)),
+        # )
+        # ab = permute_final_dims(p, (1, 2, 0))
+
+        ab = torch.einsum('bikd,bjkd->bijd', left_proj_act, right_proj_act)
+        ab = self.output_projection(self.layernorm2(ab))
+        dropout_mask = torch.ones_like(Z[:, 0:1, :, :], device=Z.device, dtype=Z.dtype)
+        return bias_ele_dropout_residual(ab,
+                                         self.output_bias,
+                                         g,
+                                         dropout_mask,
+                                         Z_raw,
+                                         prob=self.p_drop)
+
+
+class TriangleMultiplicationIncoming(nn.Module):
+
+    def __init__(self, d_pair, p_drop, c=128):
+        super(TriangleMultiplicationIncoming, self).__init__()
+        self.d_pair = d_pair
+        self.c = c
+
+        self.layernorm1 = LayerNorm(d_pair)
+        self.left_projection = Linear(d_pair, c)
+        self.right_projection = Linear(d_pair, c)
+        self.left_gate = Linear(d_pair, c, initializer='zeros', bias_init=1.)
+        self.right_gate = Linear(d_pair, c, initializer='zeros', bias_init=1.)
+
+        self.output_gate = Linear(d_pair, d_pair, initializer='zeros', bias_init=1.) 
+        self.layernorm2 = LayerNorm(c)
+        self.output_projection = Linear(c, d_pair, initializer='zeros', use_bias=False)    # project the update back from c to d_pair
+        self.output_bias = nn.parameter.Parameter(data=torch.zeros((d_pair,)), requires_grad=True)
+
+        self.p_drop = p_drop
+
+    def forward(self, Z_raw):
+        Z = self.layernorm1(Z_raw)
+        left_proj_act = self.left_projection(Z)
+        right_proj_act = self.right_projection(Z)
+
+        left_proj_act = left_proj_act * torch.sigmoid(self.left_gate(Z))
+        right_proj_act = right_proj_act * torch.sigmoid(self.right_gate(Z))
+
+        g = torch.sigmoid(self.output_gate(Z))
+        # p = torch.matmul(
+        #     permute_final_dims(left_proj_act, (2, 1, 0)),
+        #     permute_final_dims(right_proj_act, (2, 0, 1)),
+        # )
+        # ab = permute_final_dims(p, (1, 2, 0))
+
+        ab = torch.einsum('bkid,bkjd->bijd', left_proj_act, right_proj_act)
+        ab = self.output_projection(self.layernorm2(ab))
+        dropout_mask = torch.ones_like(Z[:, 0:1, :, :], device=Z.device, dtype=Z.dtype)
+        return bias_ele_dropout_residual(ab,
+                                         self.output_bias,
+                                         g,
+                                         dropout_mask,
+                                         Z_raw,
+                                         prob=self.p_drop)
+
+
+class TriangleAttentionStartingNode(nn.Module):
+
+    def __init__(self, d_pair, p_drop, c=32, n_head=4):
+        super(TriangleAttentionStartingNode, self).__init__()
+        self.d_pair = d_pair
+        self.c = c
+        self.n_head = n_head
+        self.p_drop = p_drop
+
+        self.layernorm1 = LayerNorm(d_pair)
+        _init_weights = torch.nn.init.normal_(torch.zeros([d_pair, n_head]),
+                                              std=1.0 / math.sqrt(d_pair))
+        self.linear_b_weights = nn.parameter.Parameter(data=_init_weights)
+        self.attention = SelfAttention(qkv_dim=d_pair,
+                                       c=c,
+                                       n_head=n_head,
+                                       out_dim=d_pair,
+                                       gating=True,
+                                       last_bias_fuse=True)
+
+        self.out_bias = nn.parameter.Parameter(data=torch.zeros((d_pair,)), requires_grad=True)
+
+    def forward(self, Z_raw):
+        Z = self.layernorm1(Z_raw)
+        b = torch.einsum('bqkc,ch->bhqk', Z, self.linear_b_weights)
+
+        Z = self.attention(Z, b)
+
+        dropout_mask = torch.ones_like(Z[:, 0:1, :, :], device=Z.device, dtype=Z.dtype)
+        return bias_dropout_add(Z, self.out_bias, dropout_mask, Z_raw, prob=self.p_drop)
+
+
+class TriangleAttentionEndingNode(nn.Module):
+
+    def __init__(self, d_pair, p_drop, c=32, n_head=4):
+        super(TriangleAttentionEndingNode, self).__init__()
+        self.d_pair = d_pair
+        self.c = c
+        self.n_head = n_head
+        self.p_drop = p_drop
+
+        self.layernorm1 = LayerNorm(d_pair)
+        _init_weights = torch.nn.init.normal_(torch.zeros([d_pair, n_head]),
+                                              std=1.0 / math.sqrt(d_pair))
+        self.linear_b_weights = nn.parameter.Parameter(data=_init_weights)
+        self.attention = SelfAttention(qkv_dim=d_pair,
+                                       c=c,
+                                       n_head=n_head,
+                                       out_dim=d_pair,
+                                       gating=True,
+                                       last_bias_fuse=True)
+
+        self.out_bias = nn.parameter.Parameter(data=torch.zeros((d_pair,)), requires_grad=True)
+
+    def forward(self, Z_raw):
+        Z = Z_raw.transpose(-2, -3)
+        Z = self.layernorm1(Z)
+        b = torch.einsum('bqkc,ch->bhqk', Z, self.linear_b_weights)
+
+        Z = self.attention(Z, b)
+
+        Z = Z.transpose(-2, -3)
+        dropout_mask = torch.ones_like(Z[:, :, 0:1, :], device=Z.device, dtype=Z.dtype)
+        return bias_dropout_add(Z, self.out_bias, dropout_mask, Z_raw, prob=self.p_drop)
+
+
+class PairStack(nn.Module):
+
+    def __init__(self, d_pair, p_drop=0.25):
+        super(PairStack, self).__init__()
+
+        self.TriangleMultiplicationOutgoing = TriangleMultiplicationOutgoing(d_pair, p_drop=p_drop)
+        self.TriangleMultiplicationIncoming = TriangleMultiplicationIncoming(d_pair, p_drop=p_drop)
+        self.TriangleAttentionStartingNode = TriangleAttentionStartingNode(d_pair, p_drop=p_drop)
+        self.TriangleAttentionEndingNode = 
TriangleAttentionEndingNode(d_pair, p_drop=p_drop) + self.PairTransition = Transition(d=d_pair) + + def forward(self, pair): + pair = self.TriangleMultiplicationOutgoing(pair) + pair = self.TriangleMultiplicationIncoming(pair) + pair = self.TriangleAttentionStartingNode(pair) + pair = self.TriangleAttentionEndingNode(pair) + pair = self.PairTransition(pair) + return pair
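
A minimal usage sketch, not part of the patch: it assumes the directory above is importable as an `evoformer` package (e.g. it has an __init__.py) and that inputs follow the shape conventions used in the code, i.e. node is [batch, n_seq, n_res, d_node] and pair is [batch, n_res, n_res, d_pair]. The concrete sizes below are illustrative only.

    import torch

    from evoformer.evoformer import evoformer_base

    model = evoformer_base()                # d_node=256, d_pair=128, 3 blocks
    node = torch.randn(1, 8, 64, 256)       # [batch, n_seq, n_res, d_node]
    pair = torch.randn(1, 64, 64, 128)      # [batch, n_res, n_res, d_pair]

    with torch.no_grad():
        node, pair = model(node, pair)

    # Both tensors keep their shapes: node [1, 8, 64, 256], pair [1, 64, 64, 128].
    print(node.shape, pair.shape)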