mirror of https://github.com/hpcaitech/ColossalAI
128 lines
3.7 KiB
Python
128 lines
3.7 KiB
Python
|
# Copyright 2021 AlQuraishi Laboratory
|
||
|
# Copyright 2021 DeepMind Technologies Limited
|
||
|
#
|
||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
# you may not use this file except in compliance with the License.
|
||
|
# You may obtain a copy of the License at
|
||
|
#
|
||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||
|
#
|
||
|
# Unless required by applicable law or agreed to in writing, software
|
||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
# See the License for the specific language governing permissions and
|
||
|
# limitations under the License.
|
||
|
|
||
|
from functools import partialmethod
|
||
|
from typing import Optional
|
||
|
|
||
|
import torch
|
||
|
import torch.nn as nn
|
||
|
|
||
|
from .primitives import Linear, LayerNorm
|
||
|
from .tensor_utils import permute_final_dims
|
||
|
|
||
|
|
||
|
class TriangleMultiplicativeUpdate(nn.Module):
|
||
|
"""
|
||
|
Implements Algorithms 11 and 12.
|
||
|
"""
|
||
|
def __init__(self, c_z, c_hidden, _outgoing=True):
|
||
|
"""
|
||
|
Args:
|
||
|
c_z:
|
||
|
Input channel dimension
|
||
|
c:
|
||
|
Hidden channel dimension
|
||
|
"""
|
||
|
super(TriangleMultiplicativeUpdate, self).__init__()
|
||
|
self.c_z = c_z
|
||
|
self.c_hidden = c_hidden
|
||
|
self._outgoing = _outgoing
|
||
|
|
||
|
self.linear_a_p = Linear(self.c_z, self.c_hidden)
|
||
|
self.linear_a_g = Linear(self.c_z, self.c_hidden, init="gating")
|
||
|
self.linear_b_p = Linear(self.c_z, self.c_hidden)
|
||
|
self.linear_b_g = Linear(self.c_z, self.c_hidden, init="gating")
|
||
|
self.linear_g = Linear(self.c_z, self.c_z, init="gating")
|
||
|
self.linear_z = Linear(self.c_hidden, self.c_z, init="final")
|
||
|
|
||
|
self.layer_norm_in = LayerNorm(self.c_z)
|
||
|
self.layer_norm_out = LayerNorm(self.c_hidden)
|
||
|
|
||
|
self.sigmoid = nn.Sigmoid()
|
||
|
|
||
|
def _combine_projections(self,
|
||
|
a: torch.Tensor,
|
||
|
b: torch.Tensor,
|
||
|
) -> torch.Tensor:
|
||
|
raise NotImplementedError("This method needs to be overridden")
|
||
|
|
||
|
def forward(self,
|
||
|
z: torch.Tensor,
|
||
|
mask: Optional[torch.Tensor] = None
|
||
|
) -> torch.Tensor:
|
||
|
"""
|
||
|
Args:
|
||
|
x:
|
||
|
[*, N_res, N_res, C_z] input tensor
|
||
|
mask:
|
||
|
[*, N_res, N_res] input mask
|
||
|
Returns:
|
||
|
[*, N_res, N_res, C_z] output tensor
|
||
|
"""
|
||
|
if mask is None:
|
||
|
mask = z.new_ones(z.shape[:-1])
|
||
|
|
||
|
mask = mask.unsqueeze(-1)
|
||
|
|
||
|
z = self.layer_norm_in(z)
|
||
|
a = self.linear_a_p(z) * self.sigmoid(self.linear_a_g(z))
|
||
|
a = a * mask
|
||
|
b = self.linear_b_p(z) * self.sigmoid(self.linear_b_g(z))
|
||
|
b = b * mask
|
||
|
x = self._combine_projections(a, b)
|
||
|
x = self.layer_norm_out(x)
|
||
|
x = self.linear_z(x)
|
||
|
g = self.sigmoid(self.linear_g(z))
|
||
|
z = x * g
|
||
|
|
||
|
return z
|
||
|
|
||
|
|
||
|
class TriangleMultiplicationOutgoing(TriangleMultiplicativeUpdate):
|
||
|
"""
|
||
|
Implements Algorithm 11.
|
||
|
"""
|
||
|
def _combine_projections(self,
|
||
|
a: torch.Tensor, # [*, N_i, N_k, C]
|
||
|
b: torch.Tensor, # [*, N_j, N_k, C]
|
||
|
):
|
||
|
# [*, C, N_i, N_j]
|
||
|
p = torch.matmul(
|
||
|
permute_final_dims(a, (2, 0, 1)),
|
||
|
permute_final_dims(b, (2, 1, 0)),
|
||
|
)
|
||
|
|
||
|
# [*, N_i, N_j, C]
|
||
|
return permute_final_dims(p, (1, 2, 0))
|
||
|
|
||
|
|
||
|
class TriangleMultiplicationIncoming(TriangleMultiplicativeUpdate):
|
||
|
"""
|
||
|
Implements Algorithm 12.
|
||
|
"""
|
||
|
def _combine_projections(self,
|
||
|
a: torch.Tensor, # [*, N_k, N_i, C]
|
||
|
b: torch.Tensor, # [*, N_k, N_j, C]
|
||
|
):
|
||
|
# [*, C, N_i, N_j]
|
||
|
p = torch.matmul(
|
||
|
permute_final_dims(a, (2, 1, 0)),
|
||
|
permute_final_dims(b, (2, 0, 1)),
|
||
|
)
|
||
|
|
||
|
# [*, N_i, N_j, C]
|
||
|
return permute_final_dims(p, (1, 2, 0))
|
||
|
|