update openfold

pull/2364/head
oahzxl 2022-12-29 15:54:08 +08:00
parent 289f3a45c2
commit 5c4df01af3
2 changed files with 12 additions and 84 deletions

View File

@@ -182,33 +182,28 @@ class EvoformerBlockCore(nn.Module):
         self,
         m: torch.Tensor,
         z: torch.Tensor,
-        msa_mask: torch.Tensor,
-        pair_mask: torch.Tensor,
         chunk_size: Optional[int] = None,
-        _mask_trans: bool = True,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # DeepMind doesn't mask these transitions in the source, so _mask_trans
         # should be disabled to better approximate the exact activations of
         # the original.
-        msa_trans_mask = msa_mask if _mask_trans else None
-        pair_trans_mask = pair_mask if _mask_trans else None

         m = m + self.msa_transition(
-            m, mask=msa_trans_mask, chunk_size=chunk_size
+            m, chunk_size=chunk_size
         )
         z = z + self.outer_product_mean(
-            m, mask=msa_mask, chunk_size=chunk_size
+            m, chunk_size=chunk_size
         )
-        z = z + self.ps_dropout_row_layer(self.tri_mul_out(z, mask=pair_mask))
-        z = z + self.ps_dropout_row_layer(self.tri_mul_in(z, mask=pair_mask))
+        z = z + self.ps_dropout_row_layer(self.tri_mul_out(z))
+        z = z + self.ps_dropout_row_layer(self.tri_mul_in(z))
         z = z + self.ps_dropout_row_layer(
-            self.tri_att_start(z, mask=pair_mask, chunk_size=chunk_size)
+            self.tri_att_start(z, chunk_size=chunk_size)
         )
         z = z + self.ps_dropout_col_layer(
-            self.tri_att_end(z, mask=pair_mask, chunk_size=chunk_size)
+            self.tri_att_end(z, chunk_size=chunk_size)
         )
         z = z + self.pair_transition(
-            z, mask=pair_trans_mask, chunk_size=chunk_size
+            z, chunk_size=chunk_size
         )

         return m, z
@@ -274,22 +269,16 @@ class EvoformerBlock(nn.Module):
     def forward(self,
         m: torch.Tensor,
         z: torch.Tensor,
-        msa_mask: torch.Tensor,
-        pair_mask: torch.Tensor,
         chunk_size: Optional[int] = None,
-        _mask_trans: bool = True,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         m = m + self.msa_dropout_layer(
-            self.msa_att_row(m, z=z, mask=msa_mask, chunk_size=chunk_size)
+            self.msa_att_row(m, z=z, chunk_size=chunk_size)
         )
-        m = m + self.msa_att_col(m, mask=msa_mask, chunk_size=chunk_size)
+        m = m + self.msa_att_col(m, chunk_size=chunk_size)
         m, z = self.core(
             m,
             z,
-            msa_mask=msa_mask,
-            pair_mask=pair_mask,
             chunk_size=chunk_size,
-            _mask_trans=_mask_trans,
         )

         return m, z
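For reference, a minimal sketch of the shapes behind the arguments this commit removes from the Evoformer forward passes. The sizes are hypothetical and the old/new calls appear only in comments, so nothing below depends on the classes in this file.

    import torch

    # Hypothetical sizes: m is the MSA embedding [*, N_seq, N_res, C_m],
    # z is the pair embedding [*, N_res, N_res, C_z].
    n_seq, n_res, c_m, c_z = 8, 16, 64, 32
    m = torch.randn(1, n_seq, n_res, c_m)
    z = torch.randn(1, n_res, n_res, c_z)

    # Before this commit, callers also supplied masks over the sequence and
    # residue dimensions (all-ones here, purely for illustration):
    msa_mask = torch.ones(1, n_seq, n_res)    # [*, N_seq, N_res]
    pair_mask = torch.ones(1, n_res, n_res)   # [*, N_res, N_res]

    # Old call: m, z = block(m, z, msa_mask=msa_mask, pair_mask=pair_mask, chunk_size=None)
    # New call: m, z = block(m, z, chunk_size=None)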

View File

@@ -136,45 +136,6 @@ class MSAAttention(nn.Module):
         return m, mask_bias, z

-    @torch.jit.ignore
-    def _chunked_msa_attn(self,
-        m: torch.Tensor,
-        z: Optional[torch.Tensor],
-        mask: Optional[torch.Tensor],
-        chunk_logits: int,
-        checkpoint: bool,
-    ) -> torch.Tensor:
-        MSA_DIM = -4
-
-        def _get_qkv(m, z):
-            m, mask_bias, z = self._prep_inputs(m, z, mask)
-            q, k, v = self.mha._prep_qkv(m, m)
-            return m, q, k, v, mask_bias, z
-
-        checkpoint_fn = get_checkpoint_fn()
-
-        if(torch.is_grad_enabled() and checkpoint):
-            m, q, k, v, mask_bias, z = checkpoint_fn(_get_qkv, m, z)
-        else:
-            m, q, k, v, mask_bias, z = _get_qkv(m, z)
-
-        o = _attention_chunked_trainable(
-            query=q,
-            key=k,
-            value=v,
-            biases=[mask_bias, z],
-            chunk_size=chunk_logits,
-            chunk_dim=MSA_DIM,
-            checkpoint=checkpoint,
-        )
-
-        if(torch.is_grad_enabled() and checkpoint):
-            # Storing an additional m here is far from ideal
-            m = checkpoint_fn(self.mha._wrap_up, o, m)
-        else:
-            m = self.mha._wrap_up(o, m)
-
-        return m
-
     def forward(self,
         m: torch.Tensor,
@@ -199,12 +160,6 @@ class MSAAttention(nn.Module):
                 cost of slower execution. Chunking is not performed by default.
         """
-        if(_chunk_logits is not None):
-            return self._chunked_msa_attn(
-                m=m, z=z, mask=mask,
-                chunk_logits=_chunk_logits, checkpoint=_checkpoint_chunks
-            )

         m, mask_bias, z = self._prep_inputs(m, z, mask)

         biases = [mask_bias]
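The docstring above describes the chunk_size trade-off (lower memory overhead at the cost of slower execution). A minimal, self-contained sketch of that idea follows; chunked_apply is an illustrative stand-in, not OpenFold's chunk_layer.

    import torch

    def chunked_apply(fn, x, chunk_size, dim=0):
        # Apply `fn` to slices of at most `chunk_size` along `dim` and
        # re-concatenate: peak memory drops, at the cost of more, smaller
        # kernel launches.
        pieces = [fn(piece) for piece in torch.split(x, chunk_size, dim=dim)]
        return torch.cat(pieces, dim=dim)

    x = torch.randn(4, 128, 64)
    out = chunked_apply(torch.relu, x, chunk_size=32, dim=1)
    assert torch.allclose(out, torch.relu(x))  # exact for an element-wise fn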
@@ -306,15 +261,11 @@ class MSAColumnAttention(nn.Module):
         """
         # [*, N_res, N_seq, C_in]
         m = m.transpose(-2, -3)
-        if mask is not None:
-            mask = mask.transpose(-1, -2)

-        m = self._msa_att(m, mask=mask, chunk_size=chunk_size)
+        m = self._msa_att(m, chunk_size=chunk_size)

         # [*, N_seq, N_res, C_in]
         m = m.transpose(-2, -3)
-        if mask is not None:
-            mask = mask.transpose(-1, -2)

         return m
@@ -344,12 +295,10 @@ class MSAColumnGlobalAttention(nn.Module):
     @torch.jit.ignore
     def _chunk(self,
         m: torch.Tensor,
-        mask: torch.Tensor,
         chunk_size: int,
     ) -> torch.Tensor:
         mha_input = {
             "m": m,
-            "mask": mask,
         }
         return chunk_layer(
             self.global_attention,
@@ -361,30 +310,20 @@ class MSAColumnGlobalAttention(nn.Module):
     def forward(
         self,
         m: torch.Tensor,
-        mask: Optional[torch.Tensor] = None,
         chunk_size: Optional[int] = None,
     ) -> torch.Tensor:
         n_seq, n_res, c_in = m.shape[-3:]

-        if mask is None:
-            # [*, N_seq, N_res]
-            mask = torch.ones(
-                m.shape[:-1],
-                dtype=m.dtype,
-                device=m.device,
-            ).detach()

         # [*, N_res, N_seq, C_in]
         m = m.transpose(-2, -3)
-        mask = mask.transpose(-1, -2)

         # [*, N_res, N_seq, C_in]
         m = self.layer_norm_m(m)

         if chunk_size is not None:
-            m = self._chunk(m, mask, chunk_size)
+            m = self._chunk(m, chunk_size)
         else:
-            m = self.global_attention(m=m, mask=mask)
+            m = self.global_attention(m=m)

         # [*, N_seq, N_res, C_in]
         m = m.transpose(-2, -3)
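The fallback removed from MSAColumnGlobalAttention.forward built an all-ones mask and transposed it alongside the activations. A small sketch of that bookkeeping, with hypothetical sizes, for reference now that callers no longer pass a mask:

    import torch

    # m follows the commented shape [*, N_seq, N_res, C_in]; sizes are hypothetical.
    m = torch.randn(1, 8, 16, 32)

    # The removed default: an all-ones [*, N_seq, N_res] mask derived from m.
    mask = torch.ones(m.shape[:-1], dtype=m.dtype, device=m.device).detach()

    # Column-wise attention works in the [*, N_res, N_seq, C_in] layout, so the
    # activations (and, before this commit, the mask) were transposed into that
    # layout and back afterwards.
    m_col = m.transpose(-2, -3)        # [*, N_res, N_seq, C_in]
    mask_col = mask.transpose(-1, -2)  # [*, N_res, N_seq]
    m_out = m_col.transpose(-2, -3)    # back to [*, N_seq, N_res, C_in]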