diff --git a/openfold/evoformer.py b/openfold/evoformer.py
index 21e422b04..7fbcd8a76 100644
--- a/openfold/evoformer.py
+++ b/openfold/evoformer.py
@@ -182,33 +182,28 @@ class EvoformerBlockCore(nn.Module):
         self,
         m: torch.Tensor,
         z: torch.Tensor,
-        msa_mask: torch.Tensor,
-        pair_mask: torch.Tensor,
         chunk_size: Optional[int] = None,
-        _mask_trans: bool = True,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # DeepMind doesn't mask these transitions in the source, so _mask_trans
         # should be disabled to better approximate the exact activations of
         # the original.
-        msa_trans_mask = msa_mask if _mask_trans else None
-        pair_trans_mask = pair_mask if _mask_trans else None
 
         m = m + self.msa_transition(
-            m, mask=msa_trans_mask, chunk_size=chunk_size
+            m, chunk_size=chunk_size
         )
         z = z + self.outer_product_mean(
-            m, mask=msa_mask, chunk_size=chunk_size
+            m, chunk_size=chunk_size
         )
-        z = z + self.ps_dropout_row_layer(self.tri_mul_out(z, mask=pair_mask))
-        z = z + self.ps_dropout_row_layer(self.tri_mul_in(z, mask=pair_mask))
+        z = z + self.ps_dropout_row_layer(self.tri_mul_out(z))
+        z = z + self.ps_dropout_row_layer(self.tri_mul_in(z))
         z = z + self.ps_dropout_row_layer(
-            self.tri_att_start(z, mask=pair_mask, chunk_size=chunk_size)
+            self.tri_att_start(z, chunk_size=chunk_size)
         )
         z = z + self.ps_dropout_col_layer(
-            self.tri_att_end(z, mask=pair_mask, chunk_size=chunk_size)
+            self.tri_att_end(z, chunk_size=chunk_size)
         )
         z = z + self.pair_transition(
-            z, mask=pair_trans_mask, chunk_size=chunk_size
+            z, chunk_size=chunk_size
         )
 
         return m, z
@@ -274,22 +269,16 @@ class EvoformerBlock(nn.Module):
     def forward(self,
         m: torch.Tensor,
         z: torch.Tensor,
-        msa_mask: torch.Tensor,
-        pair_mask: torch.Tensor,
         chunk_size: Optional[int] = None,
-        _mask_trans: bool = True,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         m = m + self.msa_dropout_layer(
-            self.msa_att_row(m, z=z, mask=msa_mask, chunk_size=chunk_size)
+            self.msa_att_row(m, z=z, chunk_size=chunk_size)
         )
-        m = m + self.msa_att_col(m, mask=msa_mask, chunk_size=chunk_size)
+        m = m + self.msa_att_col(m, chunk_size=chunk_size)
         m, z = self.core(
             m,
             z,
-            msa_mask=msa_mask,
-            pair_mask=pair_mask,
             chunk_size=chunk_size,
-            _mask_trans=_mask_trans,
         )
 
         return m, z
diff --git a/openfold/msa.py b/openfold/msa.py
index 172b26def..00b822e7f 100644
--- a/openfold/msa.py
+++ b/openfold/msa.py
@@ -136,45 +136,6 @@ class MSAAttention(nn.Module):
 
         return m, mask_bias, z
 
-    @torch.jit.ignore
-    def _chunked_msa_attn(self,
-        m: torch.Tensor,
-        z: Optional[torch.Tensor],
-        mask: Optional[torch.Tensor],
-        chunk_logits: int,
-        checkpoint: bool,
-    ) -> torch.Tensor:
-        MSA_DIM = -4
-
-        def _get_qkv(m, z):
-            m, mask_bias, z = self._prep_inputs(m, z, mask)
-            q, k, v = self.mha._prep_qkv(m, m)
-            return m, q, k, v, mask_bias, z
-
-        checkpoint_fn = get_checkpoint_fn()
-
-        if(torch.is_grad_enabled() and checkpoint):
-            m, q, k, v, mask_bias, z = checkpoint_fn(_get_qkv, m, z)
-        else:
-            m, q, k, v, mask_bias, z = _get_qkv(m, z)
-
-        o = _attention_chunked_trainable(
-            query=q,
-            key=k,
-            value=v,
-            biases=[mask_bias, z],
-            chunk_size=chunk_logits,
-            chunk_dim=MSA_DIM,
-            checkpoint=checkpoint,
-        )
-
-        if(torch.is_grad_enabled() and checkpoint):
-            # Storing an additional m here is far from ideal
-            m = checkpoint_fn(self.mha._wrap_up, o, m)
-        else:
-            m = self.mha._wrap_up(o, m)
-
-        return m
 
     def forward(self,
         m: torch.Tensor,
@@ -199,12 +160,6 @@ class MSAAttention(nn.Module):
                 cost of slower execution. Chunking is not performed by default.
""" - if(_chunk_logits is not None): - return self._chunked_msa_attn( - m=m, z=z, mask=mask, - chunk_logits=_chunk_logits, checkpoint=_checkpoint_chunks - ) - m, mask_bias, z = self._prep_inputs(m, z, mask) biases = [mask_bias] @@ -306,15 +261,11 @@ class MSAColumnAttention(nn.Module): """ # [*, N_res, N_seq, C_in] m = m.transpose(-2, -3) - if mask is not None: - mask = mask.transpose(-1, -2) - m = self._msa_att(m, mask=mask, chunk_size=chunk_size) + m = self._msa_att(m, chunk_size=chunk_size) # [*, N_seq, N_res, C_in] m = m.transpose(-2, -3) - if mask is not None: - mask = mask.transpose(-1, -2) return m @@ -344,12 +295,10 @@ class MSAColumnGlobalAttention(nn.Module): @torch.jit.ignore def _chunk(self, m: torch.Tensor, - mask: torch.Tensor, chunk_size: int, ) -> torch.Tensor: mha_input = { "m": m, - "mask": mask, } return chunk_layer( self.global_attention, @@ -361,30 +310,20 @@ class MSAColumnGlobalAttention(nn.Module): def forward( self, m: torch.Tensor, - mask: Optional[torch.Tensor] = None, chunk_size: Optional[int] = None, ) -> torch.Tensor: n_seq, n_res, c_in = m.shape[-3:] - if mask is None: - # [*, N_seq, N_res] - mask = torch.ones( - m.shape[:-1], - dtype=m.dtype, - device=m.device, - ).detach() - # [*, N_res, N_seq, C_in] m = m.transpose(-2, -3) - mask = mask.transpose(-1, -2) # [*, N_res, N_seq, C_in] m = self.layer_norm_m(m) if chunk_size is not None: - m = self._chunk(m, mask, chunk_size) + m = self._chunk(m, chunk_size) else: - m = self.global_attention(m=m, mask=mask) + m = self.global_attention(m=m) # [*, N_seq, N_res, C_in] m = m.transpose(-2, -3)