mirror of https://github.com/hpcaitech/ColossalAI
100 lines
2.7 KiB
Python
100 lines
2.7 KiB
Python
|
# Copyright 2021 AlQuraishi Laboratory
|
||
|
# Copyright 2021 DeepMind Technologies Limited
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
from typing import Optional
|
||
|
|
||
|
import torch
|
||
|
import torch.nn as nn
|
||
|
|
||
|
from .primitives import Linear, LayerNorm
|
||
|
from .tensor_utils import chunk_layer
|
||
|
|
||
|
|
||
|
class PairTransition(nn.Module):
    """
    Pair transition block (Algorithm 15).

    Layer-normalises the pair embedding, passes it through a two-layer
    feed-forward network with a ReLU in between, and multiplies the
    result by a mask.
    """

    def __init__(self, c_z, n):
        """
        Args:
            c_z:
                Channel dimension of the pair embedding
            n:
                Expansion factor; the hidden layer has n * c_z channels
        """
        super().__init__()

        self.c_z = c_z
        self.n = n

        hidden_dim = self.n * self.c_z

        # Submodules are created in the same order as before so that
        # parameter ordering and initialisation order are unchanged.
        self.layer_norm = LayerNorm(self.c_z)
        self.linear_1 = Linear(self.c_z, hidden_dim, init="relu")
        self.relu = nn.ReLU()
        self.linear_2 = Linear(hidden_dim, c_z, init="final")

    def _transition(self, z, mask):
        """
        Apply the feed-forward update to an already-normalised input.

        Args:
            z:
                [*, N_res, N_res, C_z] normalised pair embedding
            mask:
                Mask multiplied into the output; forward() passes it
                as [*, N_res, N_res, 1]
        Returns:
            [*, N_res, N_res, C_z] masked update
        """
        # [*, N_res, N_res, C_hidden]
        hidden = self.relu(self.linear_1(z))

        # [*, N_res, N_res, C_z]
        return self.linear_2(hidden) * mask

    @torch.jit.ignore
    def _chunk(
        self,
        z: torch.Tensor,
        mask: torch.Tensor,
        chunk_size: int,
    ) -> torch.Tensor:
        """
        Run _transition through chunk_layer instead of in one call.

        Every dimension of z except the last two is reported to
        chunk_layer as a batch dimension. The dictionary keys must match
        the keyword names of _transition.
        """
        inputs = {"z": z, "mask": mask}
        batch_dims = len(z.shape[:-2])

        return chunk_layer(
            self._transition,
            inputs,
            chunk_size=chunk_size,
            no_batch_dims=batch_dims,
        )

    def forward(
        self,
        z: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        chunk_size: Optional[int] = None,
    ) -> torch.Tensor:
        """
        Args:
            z:
                [*, N_res, N_res, C_z] pair embedding
            mask:
                Optional [*, N_res, N_res] pair mask. When omitted, a
                mask of ones is used, so nothing is masked out
            chunk_size:
                Optional chunk size handed to chunk_layer. When omitted,
                the transition is computed in a single call
        Returns:
            [*, N_res, N_res, C_z] pair embedding update
        """
        # DISCREPANCY: DeepMind forgets to apply the mask in this module.
        if mask is None:
            mask = z.new_ones(z.shape[:-1])

        # [*, N_res, N_res, 1]
        pair_mask = mask.unsqueeze(-1)

        # [*, N_res, N_res, C_z]
        normed = self.layer_norm(z)

        if chunk_size is None:
            update = self._transition(z=normed, mask=pair_mask)
        else:
            update = self._chunk(normed, pair_mask, chunk_size)

        return update
|