mirror of https://github.com/hpcaitech/ColossalAI

update openfold

commit 5c4df01af3
parent 289f3a45c2
@@ -182,33 +182,28 @@ class EvoformerBlockCore(nn.Module):
         self,
         m: torch.Tensor,
         z: torch.Tensor,
         msa_mask: torch.Tensor,
         pair_mask: torch.Tensor,
         chunk_size: Optional[int] = None,
         _mask_trans: bool = True,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        # DeepMind doesn't mask these transitions in the source, so _mask_trans
-        # should be disabled to better approximate the exact activations of
-        # the original.
-        msa_trans_mask = msa_mask if _mask_trans else None
-        pair_trans_mask = pair_mask if _mask_trans else None

         m = m + self.msa_transition(
-            m, mask=msa_trans_mask, chunk_size=chunk_size
+            m, chunk_size=chunk_size
         )
         z = z + self.outer_product_mean(
-            m, mask=msa_mask, chunk_size=chunk_size
+            m, chunk_size=chunk_size
         )
-        z = z + self.ps_dropout_row_layer(self.tri_mul_out(z, mask=pair_mask))
-        z = z + self.ps_dropout_row_layer(self.tri_mul_in(z, mask=pair_mask))
+        z = z + self.ps_dropout_row_layer(self.tri_mul_out(z))
+        z = z + self.ps_dropout_row_layer(self.tri_mul_in(z))
         z = z + self.ps_dropout_row_layer(
-            self.tri_att_start(z, mask=pair_mask, chunk_size=chunk_size)
+            self.tri_att_start(z, chunk_size=chunk_size)
         )
         z = z + self.ps_dropout_col_layer(
-            self.tri_att_end(z, mask=pair_mask, chunk_size=chunk_size)
+            self.tri_att_end(z, chunk_size=chunk_size)
         )
         z = z + self.pair_transition(
-            z, mask=pair_trans_mask, chunk_size=chunk_size
+            z, chunk_size=chunk_size
         )

         return m, z
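
Aside (illustrative, not part of this commit): with the mask= arguments removed from the transition, outer-product-mean, and triangle modules above, any padding has to be handled before the core is called. A minimal sketch of such call-site pre-masking, assuming openfold's usual shapes (msa_mask: [*, N_seq, N_res], pair_mask: [*, N_res, N_res]); the helper name is hypothetical:

    import torch

    def premask_evoformer_inputs(m, z, msa_mask, pair_mask):
        # Zero out padded MSA positions and pair entries so that the
        # now-unmasked modules never see values coming from padding.
        m = m * msa_mask.unsqueeze(-1)   # [*, N_seq, N_res, C_m]
        z = z * pair_mask.unsqueeze(-1)  # [*, N_res, N_res, C_z]
        return m, z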
@@ -274,22 +269,16 @@ class EvoformerBlock(nn.Module):
     def forward(self,
         m: torch.Tensor,
         z: torch.Tensor,
         msa_mask: torch.Tensor,
         pair_mask: torch.Tensor,
         chunk_size: Optional[int] = None,
         _mask_trans: bool = True,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         m = m + self.msa_dropout_layer(
-            self.msa_att_row(m, z=z, mask=msa_mask, chunk_size=chunk_size)
+            self.msa_att_row(m, z=z, chunk_size=chunk_size)
         )
-        m = m + self.msa_att_col(m, mask=msa_mask, chunk_size=chunk_size)
+        m = m + self.msa_att_col(m, chunk_size=chunk_size)
         m, z = self.core(
             m,
             z,
             msa_mask=msa_mask,
             pair_mask=pair_mask,
             chunk_size=chunk_size,
             _mask_trans=_mask_trans,
         )

         return m, z

@@ -136,45 +136,6 @@ class MSAAttention(nn.Module):

         return m, mask_bias, z

-    @torch.jit.ignore
-    def _chunked_msa_attn(self,
-        m: torch.Tensor,
-        z: Optional[torch.Tensor],
-        mask: Optional[torch.Tensor],
-        chunk_logits: int,
-        checkpoint: bool,
-    ) -> torch.Tensor:
-        MSA_DIM = -4
-
-        def _get_qkv(m, z):
-            m, mask_bias, z = self._prep_inputs(m, z, mask)
-            q, k, v = self.mha._prep_qkv(m, m)
-            return m, q, k, v, mask_bias, z
-
-        checkpoint_fn = get_checkpoint_fn()
-
-        if(torch.is_grad_enabled() and checkpoint):
-            m, q, k, v, mask_bias, z = checkpoint_fn(_get_qkv, m, z)
-        else:
-            m, q, k, v, mask_bias, z = _get_qkv(m, z)
-
-        o = _attention_chunked_trainable(
-            query=q,
-            key=k,
-            value=v,
-            biases=[mask_bias, z],
-            chunk_size=chunk_logits,
-            chunk_dim=MSA_DIM,
-            checkpoint=checkpoint,
-        )
-
-        if(torch.is_grad_enabled() and checkpoint):
-            # Storing an additional m here is far from ideal
-            m = checkpoint_fn(self.mha._wrap_up, o, m)
-        else:
-            m = self.mha._wrap_up(o, m)
-
-        return m

     def forward(self,
         m: torch.Tensor,

@@ -199,12 +160,6 @@ class MSAAttention(nn.Module):
                 cost of slower execution. Chunking is not performed by default.

         """
-        if(_chunk_logits is not None):
-            return self._chunked_msa_attn(
-                m=m, z=z, mask=mask,
-                chunk_logits=_chunk_logits, checkpoint=_checkpoint_chunks
-            )
-
         m, mask_bias, z = self._prep_inputs(m, z, mask)

         biases = [mask_bias]
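
Aside (rough standalone sketch, not the openfold API): the deleted _chunked_msa_attn and its _chunk_logits entry point split a memory-heavy attention call into slices along one dimension, optionally recomputing activations under gradient checkpointing, trading speed for a lower peak-memory footprint. The names below are illustrative only:

    import torch
    from torch.utils.checkpoint import checkpoint

    def chunked_apply(fn, x, chunk_size, dim, use_checkpoint=False):
        # Run fn on slices of x along `dim` and stitch the results back
        # together; fn must preserve the size of `dim` (as attention does).
        outputs = []
        for piece in torch.split(x, chunk_size, dim=dim):
            if use_checkpoint and torch.is_grad_enabled():
                # Recompute activations for this slice in the backward pass.
                outputs.append(checkpoint(fn, piece))
            else:
                outputs.append(fn(piece))
        return torch.cat(outputs, dim=dim)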
@@ -306,15 +261,11 @@ class MSAColumnAttention(nn.Module):
         """
         # [*, N_res, N_seq, C_in]
         m = m.transpose(-2, -3)
-        if mask is not None:
-            mask = mask.transpose(-1, -2)

-        m = self._msa_att(m, mask=mask, chunk_size=chunk_size)
+        m = self._msa_att(m, chunk_size=chunk_size)

         # [*, N_seq, N_res, C_in]
         m = m.transpose(-2, -3)
-        if mask is not None:
-            mask = mask.transpose(-1, -2)

         return m

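
Aside (illustrative sketch, not the module itself): MSAColumnAttention still relies on the transpose trick visible above; only the mask plumbing around it is gone. Attention over MSA columns is row attention applied to the tensor with its N_seq and N_res axes swapped, then swapped back:

    def attend_over_columns(row_attention, m):
        # m: [*, N_seq, N_res, C]
        m = m.transpose(-2, -3)       # [*, N_res, N_seq, C]
        m = row_attention(m)          # mixes entries within each original column
        return m.transpose(-2, -3)    # back to [*, N_seq, N_res, C]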
@@ -344,12 +295,10 @@ class MSAColumnGlobalAttention(nn.Module):
     @torch.jit.ignore
     def _chunk(self,
         m: torch.Tensor,
-        mask: torch.Tensor,
         chunk_size: int,
     ) -> torch.Tensor:
         mha_input = {
             "m": m,
-            "mask": mask,
         }
         return chunk_layer(
             self.global_attention,
@@ -361,30 +310,20 @@ class MSAColumnGlobalAttention(nn.Module):
     def forward(
         self,
         m: torch.Tensor,
-        mask: Optional[torch.Tensor] = None,
         chunk_size: Optional[int] = None,
     ) -> torch.Tensor:
         n_seq, n_res, c_in = m.shape[-3:]

-        if mask is None:
-            # [*, N_seq, N_res]
-            mask = torch.ones(
-                m.shape[:-1],
-                dtype=m.dtype,
-                device=m.device,
-            ).detach()
-
         # [*, N_res, N_seq, C_in]
         m = m.transpose(-2, -3)
-        mask = mask.transpose(-1, -2)

         # [*, N_res, N_seq, C_in]
         m = self.layer_norm_m(m)

         if chunk_size is not None:
-            m = self._chunk(m, mask, chunk_size)
+            m = self._chunk(m, chunk_size)
         else:
-            m = self.global_attention(m=m, mask=mask)
+            m = self.global_attention(m=m)

         # [*, N_seq, N_res, C_in]
         m = m.transpose(-2, -3)
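
Aside (illustrative check, under an assumed masking scheme): the deleted branch built an all-ones mask when none was supplied. If the attention applies masking as an additive bias of the usual inf * (mask - 1) form, an all-ones mask contributes a zero bias, so dropping the default only changes behaviour when a real padding mask was being passed in:

    import torch

    mask = torch.ones(1, 8)        # every position valid, as the deleted default built
    bias = 1e9 * (mask - 1)        # additive attention bias of the assumed masked form
    assert torch.all(bias == 0)    # all-ones mask is equivalent to no masking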