# mirror of https://github.com/hpcaitech/ColossalAI
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import torch
import torch.nn as nn
from typing import Tuple, Optional
from functools import partial

from .primitives import Linear, LayerNorm
from .dropout import DropoutRowwise, DropoutColumnwise
from .msa import (
    MSARowAttentionWithPairBias,
    MSAColumnAttention,
    MSAColumnGlobalAttention,
)
from .outer_product_mean import OuterProductMean
from .pair_transition import PairTransition
from .triangular_attention import (
    TriangleAttentionStartingNode,
    TriangleAttentionEndingNode,
)
from .triangular_multiplicative_update import (
    TriangleMultiplicationOutgoing,
    TriangleMultiplicationIncoming,
)
from .checkpointing import checkpoint_blocks, get_checkpoint_fn
from .tensor_utils import chunk_layer


class MSATransition(nn.Module):
    """
    Feed-forward network applied to MSA activations after attention.

    Implements Algorithm 9.
    """

    def __init__(self, c_m, n):
        """
        Args:
            c_m:
                MSA channel dimension
            n:
                Factor multiplied to c_m to obtain the hidden channel
                dimension
        """
        super(MSATransition, self).__init__()

        self.c_m = c_m
        self.n = n

        self.layer_norm = LayerNorm(self.c_m)
        self.linear_1 = Linear(self.c_m, self.n * self.c_m, init="relu")
        self.relu = nn.ReLU()
        self.linear_2 = Linear(self.n * self.c_m, self.c_m, init="final")

    def _transition(self, m, mask):
        m = self.linear_1(m)
        m = self.relu(m)
        m = self.linear_2(m) * mask
        return m

    @torch.jit.ignore
    def _chunk(
        self,
        m: torch.Tensor,
        mask: torch.Tensor,
        chunk_size: int,
    ) -> torch.Tensor:
        return chunk_layer(
            self._transition,
            {"m": m, "mask": mask},
            chunk_size=chunk_size,
            no_batch_dims=len(m.shape[:-2]),
        )

    def forward(
        self,
        m: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        chunk_size: Optional[int] = None,
    ) -> torch.Tensor:
        """
        Args:
            m:
                [*, N_seq, N_res, C_m] MSA activation
            mask:
                [*, N_seq, N_res] MSA mask
            chunk_size:
                Inference-time subbatch size for the transition
        Returns:
            m:
                [*, N_seq, N_res, C_m] MSA activation update
        """
        # DISCREPANCY: DeepMind forgets to apply the MSA mask here.
        if mask is None:
            mask = m.new_ones(m.shape[:-1])

        # [*, N_seq, N_res, 1]
        mask = mask.unsqueeze(-1)

        m = self.layer_norm(m)

        if chunk_size is not None:
            m = self._chunk(m, mask, chunk_size)
        else:
            m = self._transition(m, mask)

        return m
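
# Usage sketch with arbitrary toy dimensions: shows the tensor shapes
# MSATransition expects and how `chunk_size` splits the work into sub-batches
# via `chunk_layer`. It assumes the `Linear`/`LayerNorm` primitives behave like
# their torch counterparts; the helper name below is purely illustrative.
def _example_msa_transition() -> torch.Tensor:
    transition = MSATransition(c_m=32, n=4)
    m = torch.rand(1, 8, 16, 32)   # [*, N_seq, N_res, C_m]
    mask = torch.ones(1, 8, 16)    # [*, N_seq, N_res]

    out = transition(m, mask=mask)                         # unchunked
    out_chunked = transition(m, mask=mask, chunk_size=4)   # chunked over the leading dims

    # The transition preserves the MSA shape; chunking only trades speed for
    # memory, so the chunked result is expected to match the unchunked one.
    assert out.shape == m.shape
    assert torch.allclose(out, out_chunked, atol=1e-6)
    return out
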
class EvoformerBlockCore(nn.Module):
    def __init__(
        self,
        c_m: int,
        c_z: int,
        c_hidden_opm: int,
        c_hidden_mul: int,
        c_hidden_pair_att: int,
        no_heads_msa: int,
        no_heads_pair: int,
        transition_n: int,
        pair_dropout: float,
        inf: float,
        eps: float,
        _is_extra_msa_stack: bool = False,
        is_multimer: bool = False,
    ):
        super(EvoformerBlockCore, self).__init__()
        self.is_multimer = is_multimer
        self.msa_transition = MSATransition(
            c_m=c_m,
            n=transition_n,
        )

        self.outer_product_mean = OuterProductMean(
            c_m,
            c_z,
            c_hidden_opm,
        )

        self.tri_mul_out = TriangleMultiplicationOutgoing(
            c_z,
            c_hidden_mul,
        )
        self.tri_mul_in = TriangleMultiplicationIncoming(
            c_z,
            c_hidden_mul,
        )

        self.tri_att_start = TriangleAttentionStartingNode(
            c_z,
            c_hidden_pair_att,
            no_heads_pair,
            inf=inf,
        )
        self.tri_att_end = TriangleAttentionEndingNode(
            c_z,
            c_hidden_pair_att,
            no_heads_pair,
            inf=inf,
        )

        self.pair_transition = PairTransition(
            c_z,
            transition_n,
        )

        self.ps_dropout_row_layer = DropoutRowwise(pair_dropout)
        self.ps_dropout_col_layer = DropoutColumnwise(pair_dropout)

    def forward(
        self,
        m: torch.Tensor,
        z: torch.Tensor,
        chunk_size: Optional[int] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # DeepMind doesn't mask these transitions in the source, so
        # _mask_trans should be disabled to better approximate the exact
        # activations of the original.

        m = m + self.msa_transition(m, chunk_size=chunk_size)
        z = z + self.outer_product_mean(m, chunk_size=chunk_size)
        z = z + self.ps_dropout_row_layer(self.tri_mul_out(z))
        z = z + self.ps_dropout_row_layer(self.tri_mul_in(z))
        z = z + self.ps_dropout_row_layer(
            self.tri_att_start(z, chunk_size=chunk_size)
        )
        z = z + self.ps_dropout_col_layer(
            self.tri_att_end(z, chunk_size=chunk_size)
        )
        z = z + self.pair_transition(z, chunk_size=chunk_size)

        return m, z
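
# Usage sketch for EvoformerBlockCore with arbitrary toy dimensions: the core
# applies the MSA transition, communicates MSA -> pair via the outer product
# mean, and then runs the pair stack (triangle multiplicative updates,
# triangle attention, pair transition), each as a residual update wrapped in
# row-/column-wise dropout. It assumes the imported submodules follow the
# OpenFold-style constructor signatures used above.
def _example_evoformer_block_core() -> Tuple[torch.Tensor, torch.Tensor]:
    core = EvoformerBlockCore(
        c_m=32,
        c_z=16,
        c_hidden_opm=8,
        c_hidden_mul=16,
        c_hidden_pair_att=8,
        no_heads_msa=4,
        no_heads_pair=2,
        transition_n=4,
        pair_dropout=0.25,
        inf=1e9,
        eps=1e-8,
    )
    m = torch.rand(1, 8, 16, 32)   # [*, N_seq, N_res, C_m] MSA embedding
    z = torch.rand(1, 16, 16, 16)  # [*, N_res, N_res, C_z] pair embedding

    # Both representations come back with their shapes unchanged.
    m_out, z_out = core(m, z, chunk_size=None)
    assert m_out.shape == m.shape and z_out.shape == z.shape
    return m_out, z_out
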
class EvoformerBlock(nn.Module):
    def __init__(
        self,
        c_m: int,
        c_z: int,
        c_hidden_msa_att: int,
        c_hidden_opm: int,
        c_hidden_mul: int,
        c_hidden_pair_att: int,
        no_heads_msa: int,
        no_heads_pair: int,
        transition_n: int,
        msa_dropout: float,
        pair_dropout: float,
        inf: float,
        eps: float,
        is_multimer: bool,
    ):
        super(EvoformerBlock, self).__init__()

        self.msa_att_row = MSARowAttentionWithPairBias(
            c_m=c_m,
            c_z=c_z,
            c_hidden=c_hidden_msa_att,
            no_heads=no_heads_msa,
            inf=inf,
        )

        self.msa_att_col = MSAColumnAttention(
            c_m,
            c_hidden_msa_att,
            no_heads_msa,
            inf=inf,
        )

        self.msa_dropout_layer = DropoutRowwise(msa_dropout)

        self.core = EvoformerBlockCore(
            c_m=c_m,
            c_z=c_z,
            c_hidden_opm=c_hidden_opm,
            c_hidden_mul=c_hidden_mul,
            c_hidden_pair_att=c_hidden_pair_att,
            no_heads_msa=no_heads_msa,
            no_heads_pair=no_heads_pair,
            transition_n=transition_n,
            pair_dropout=pair_dropout,
            inf=inf,
            eps=eps,
        )

        self.outer_product_mean = OuterProductMean(
            c_m,
            c_z,
            c_hidden_opm,
        )
        self.is_multimer = is_multimer

    def forward(
        self,
        m: torch.Tensor,
        z: torch.Tensor,
        chunk_size: Optional[int] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        m = m + self.msa_dropout_layer(
            self.msa_att_row(m, z=z, chunk_size=chunk_size)
        )
        m = m + self.msa_att_col(m, chunk_size=chunk_size)
        m, z = self.core(
            m,
            z,
            chunk_size=chunk_size,
        )

        return m, z
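
# Usage sketch for a single EvoformerBlock, again with toy dimensions: the
# block prepends row-wise MSA attention (biased by the pair embedding z) and
# column-wise MSA attention to the EvoformerBlockCore above. Calling .eval()
# disables the row-/column-wise dropout so repeated calls are deterministic.
def _example_evoformer_block() -> Tuple[torch.Tensor, torch.Tensor]:
    block = EvoformerBlock(
        c_m=32,
        c_z=16,
        c_hidden_msa_att=8,
        c_hidden_opm=8,
        c_hidden_mul=16,
        c_hidden_pair_att=8,
        no_heads_msa=4,
        no_heads_pair=2,
        transition_n=4,
        msa_dropout=0.15,
        pair_dropout=0.25,
        inf=1e9,
        eps=1e-8,
        is_multimer=False,
    ).eval()
    m = torch.rand(1, 8, 16, 32)   # [*, N_seq, N_res, C_m]
    z = torch.rand(1, 16, 16, 16)  # [*, N_res, N_res, C_z]
    with torch.no_grad():
        m, z = block(m, z, chunk_size=None)
    return m, z
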
class EvoformerStack(nn.Module):
    """
    Main Evoformer trunk.

    Implements Algorithm 6.
    """

    def __init__(
        self,
        c_m: int,
        c_z: int,
        c_hidden_msa_att: int,
        c_hidden_opm: int,
        c_hidden_mul: int,
        c_hidden_pair_att: int,
        c_s: int,
        no_heads_msa: int,
        no_heads_pair: int,
        no_blocks: int,
        transition_n: int,
        msa_dropout: float,
        pair_dropout: float,
        blocks_per_ckpt: int,
        inf: float,
        eps: float,
        clear_cache_between_blocks: bool = False,
        is_multimer: bool = False,
        **kwargs,
    ):
        """
        Args:
            c_m:
                MSA channel dimension
            c_z:
                Pair channel dimension
            c_hidden_msa_att:
                Hidden dimension in MSA attention
            c_hidden_opm:
                Hidden dimension in outer product mean module
            c_hidden_mul:
                Hidden dimension in multiplicative updates
            c_hidden_pair_att:
                Hidden dimension in triangular attention
            c_s:
                Channel dimension of the output "single" embedding
            no_heads_msa:
                Number of heads used for MSA attention
            no_heads_pair:
                Number of heads used for pair attention
            no_blocks:
                Number of Evoformer blocks in the stack
            transition_n:
                Factor by which to multiply c_m to obtain the MSATransition
                hidden dimension
            msa_dropout:
                Dropout rate for MSA activations
            pair_dropout:
                Dropout used for pair activations
            blocks_per_ckpt:
                Number of Evoformer blocks in each activation checkpoint
            clear_cache_between_blocks:
                Whether to clear CUDA's GPU memory cache between blocks of the
                stack. Slows down each block but can reduce fragmentation
        """
        super(EvoformerStack, self).__init__()

        self.blocks_per_ckpt = blocks_per_ckpt
        self.clear_cache_between_blocks = clear_cache_between_blocks

        self.blocks = nn.ModuleList()

        for _ in range(no_blocks):
            block = EvoformerBlock(
                c_m=c_m,
                c_z=c_z,
                c_hidden_msa_att=c_hidden_msa_att,
                c_hidden_opm=c_hidden_opm,
                c_hidden_mul=c_hidden_mul,
                c_hidden_pair_att=c_hidden_pair_att,
                no_heads_msa=no_heads_msa,
                no_heads_pair=no_heads_pair,
                transition_n=transition_n,
                msa_dropout=msa_dropout,
                pair_dropout=pair_dropout,
                inf=inf,
                eps=eps,
                is_multimer=is_multimer,
            )
            self.blocks.append(block)

        self.linear = Linear(c_m, c_s)

    def forward(
        self,
        m: torch.Tensor,
        z: torch.Tensor,
        msa_mask: torch.Tensor,
        pair_mask: torch.Tensor,
        chunk_size: int,
        _mask_trans: bool = True,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """
        Args:
            m:
                [*, N_seq, N_res, C_m] MSA embedding
            z:
                [*, N_res, N_res, C_z] pair embedding
            msa_mask:
                [*, N_seq, N_res] MSA mask
            pair_mask:
                [*, N_res, N_res] pair mask
        Returns:
            m:
                [*, N_seq, N_res, C_m] MSA embedding
            z:
                [*, N_res, N_res, C_z] pair embedding
            s:
                [*, N_res, C_s] single embedding (or None if extra MSA stack)
        """
        # NOTE: the simplified block forwards in this file only accept
        # (m, z, chunk_size), so only chunk_size is bound here; msa_mask,
        # pair_mask and _mask_trans are kept in the signature for interface
        # compatibility but are not consumed by the blocks.
        blocks = [
            partial(
                b,
                chunk_size=chunk_size,
            )
            for b in self.blocks
        ]

        if self.clear_cache_between_blocks:

            def block_with_cache_clear(block, *args):
                torch.cuda.empty_cache()
                return block(*args)

            blocks = [partial(block_with_cache_clear, b) for b in blocks]

        m, z = checkpoint_blocks(
            blocks,
            args=(m, z),
            blocks_per_ckpt=self.blocks_per_ckpt if self.training else None,
        )

        s = self.linear(m[..., 0, :, :])

        return m, z, s
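
# Usage sketch for the full EvoformerStack with toy hyperparameters: the stack
# wraps `no_blocks` EvoformerBlocks in `checkpoint_blocks`, so activation
# checkpointing (`blocks_per_ckpt`) only takes effect in training mode, and the
# "single" embedding s is read off the first MSA row. This assumes
# `checkpoint_blocks` simply runs the blocks sequentially when
# `blocks_per_ckpt` is None (eval mode); the masks are required by the
# signature above even though the simplified blocks here do not consume them.
def _example_evoformer_stack() -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    stack = EvoformerStack(
        c_m=32,
        c_z=16,
        c_hidden_msa_att=8,
        c_hidden_opm=8,
        c_hidden_mul=16,
        c_hidden_pair_att=8,
        c_s=64,
        no_heads_msa=4,
        no_heads_pair=2,
        no_blocks=2,
        transition_n=4,
        msa_dropout=0.15,
        pair_dropout=0.25,
        blocks_per_ckpt=1,
        inf=1e9,
        eps=1e-8,
    ).eval()
    m = torch.rand(1, 8, 16, 32)       # [*, N_seq, N_res, C_m]
    z = torch.rand(1, 16, 16, 16)      # [*, N_res, N_res, C_z]
    msa_mask = torch.ones(1, 8, 16)    # [*, N_seq, N_res]
    pair_mask = torch.ones(1, 16, 16)  # [*, N_res, N_res]

    with torch.no_grad():
        m, z, s = stack(m, z, msa_mask=msa_mask, pair_mask=pair_mask, chunk_size=None)
    assert s.shape == (1, 16, 64)      # [*, N_res, C_s]
    return m, z, s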