# ColossalAI/tests/test_autochunk/evoformer/triangle.py

import math
import torch
import torch.nn as nn
from torch.nn import LayerNorm
from .kernel import bias_dropout_add, bias_ele_dropout_residual
from .ops import Linear, SelfAttention, Transition
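
# Overview (added comment): these modules implement the four triangle updates
# and the pair stack of an AlphaFold2-style Evoformer. Pair tensors have shape
# [batch, N_res, N_res, d_pair]; in the einsum equations below, i/j index the
# two residue axes, k is the contracted residue axis, and d/c are channel axes.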


def permute_final_dims(tensor, inds):
    # Permute only the last len(inds) dimensions of `tensor`, leaving all
    # leading (batch-like) dimensions in place.
    zero_index = -1 * len(inds)
    first_inds = list(range(len(tensor.shape[:zero_index])))
    return tensor.permute(first_inds + [zero_index + i for i in inds])


class TriangleMultiplicationOutgoing(nn.Module):
    """Triangle multiplicative update using "outgoing" edges of the pair representation."""

    def __init__(self, d_pair, p_drop, c=128):
        super(TriangleMultiplicationOutgoing, self).__init__()
self.d_pair = d_pair
self.c = c
self.layernorm1 = LayerNorm(d_pair)
self.left_projection = Linear(d_pair, c)
self.right_projection = Linear(d_pair, c)
self.left_gate = Linear(d_pair, c, initializer='zeros', bias_init=1.)
self.right_gate = Linear(d_pair, c, initializer='zeros', bias_init=1.)
self.output_gate = Linear(d_pair, d_pair, initializer='zeros', bias_init=1.)
self.layernorm2 = LayerNorm(c)
self.output_projection = Linear(d_pair, d_pair, initializer='zeros', use_bias=False)
self.output_bias = nn.parameter.Parameter(data=torch.zeros((d_pair,)), requires_grad=True)
self.p_drop = p_drop

    def forward(self, Z_raw):
Z = self.layernorm1(Z_raw)
left_proj_act = self.left_projection(Z)
right_proj_act = self.right_projection(Z)
left_proj_act = left_proj_act * torch.sigmoid(self.left_gate(Z))
right_proj_act = right_proj_act * torch.sigmoid(self.right_gate(Z))
g = torch.sigmoid(self.output_gate(Z))
        # Equivalent permute + matmul formulation, kept for reference:
        # p = torch.matmul(
        #     permute_final_dims(left_proj_act, (2, 0, 1)),
        #     permute_final_dims(right_proj_act, (2, 1, 0)),
        # )
        # ab = permute_final_dims(p, (1, 2, 0))
        # "Outgoing" update: combine edges (i, k) and (j, k) over the shared
        # residue k.
        ab = torch.einsum('bikd,bjkd->bijd', left_proj_act, right_proj_act)
ab = self.output_projection(self.layernorm2(ab))
        # ones_like already matches Z's device and dtype, so no extra .to()
        # casts are needed; the all-ones mask is shared across the row axis.
        dropout_mask = torch.ones_like(Z[:, 0:1, :, :])
return bias_ele_dropout_residual(ab,
self.output_bias,
g,
dropout_mask,
Z_raw,
prob=self.p_drop)


class TriangleMultiplicationIncoming(nn.Module):
    """Triangle multiplicative update using "incoming" edges; identical to the
    outgoing variant except for the einsum contraction pattern."""

    def __init__(self, d_pair, p_drop, c=128):
        super(TriangleMultiplicationIncoming, self).__init__()
self.d_pair = d_pair
self.c = c
self.layernorm1 = LayerNorm(d_pair)
self.left_projection = Linear(d_pair, c)
self.right_projection = Linear(d_pair, c)
self.left_gate = Linear(d_pair, c, initializer='zeros', bias_init=1.)
self.right_gate = Linear(d_pair, c, initializer='zeros', bias_init=1.)
self.output_gate = Linear(d_pair, d_pair, initializer='zeros', bias_init=1.)
self.layernorm2 = LayerNorm(c)
self.output_projection = Linear(d_pair, d_pair, initializer='zeros', use_bias=False)
self.output_bias = nn.parameter.Parameter(data=torch.zeros((d_pair,)), requires_grad=True)
self.p_drop = p_drop

    def forward(self, Z_raw):
Z = self.layernorm1(Z_raw)
left_proj_act = self.left_projection(Z)
right_proj_act = self.right_projection(Z)
left_proj_act = left_proj_act * torch.sigmoid(self.left_gate(Z))
right_proj_act = right_proj_act * torch.sigmoid(self.right_gate(Z))
g = torch.sigmoid(self.output_gate(Z))
        # Equivalent permute + matmul formulation, kept for reference:
        # p = torch.matmul(
        #     permute_final_dims(left_proj_act, (2, 1, 0)),
        #     permute_final_dims(right_proj_act, (2, 0, 1)),
        # )
        # ab = permute_final_dims(p, (1, 2, 0))
        # "Incoming" update: combine edges (k, i) and (k, j) over the shared
        # residue k.
        ab = torch.einsum('bkid,bkjd->bijd', left_proj_act, right_proj_act)
ab = self.output_projection(self.layernorm2(ab))
        # All-ones row-wise mask; see TriangleMultiplicationOutgoing.
        dropout_mask = torch.ones_like(Z[:, 0:1, :, :])
return bias_ele_dropout_residual(ab,
self.output_bias,
g,
dropout_mask,
Z_raw,
prob=self.p_drop)


class TriangleAttentionStartingNode(nn.Module):
    """Row-wise gated self-attention over the pair representation, with a
    per-head pair bias projected from Z itself."""

    def __init__(self, d_pair, p_drop, c=32, n_head=4):
        super(TriangleAttentionStartingNode, self).__init__()
self.d_pair = d_pair
self.c = c
self.n_head = n_head
self.p_drop = p_drop
self.layernorm1 = LayerNorm(d_pair)
_init_weights = torch.nn.init.normal_(torch.zeros([d_pair, n_head]),
std=1.0 / math.sqrt(d_pair))
self.linear_b_weights = nn.parameter.Parameter(data=_init_weights)
self.attention = SelfAttention(qkv_dim=d_pair,
c=c,
n_head=n_head,
out_dim=d_pair,
gating=True,
last_bias_fuse=True)
self.out_bias = nn.parameter.Parameter(data=torch.zeros((d_pair,)), requires_grad=True)

    def forward(self, Z_raw):
        Z = self.layernorm1(Z_raw)
        # Project Z to a per-head bias for the attention logits.
        b = torch.einsum('bqkc,ch->bhqk', Z, self.linear_b_weights)
        Z = self.attention(Z, b)
        dropout_mask = torch.ones_like(Z[:, 0:1, :, :])
return bias_dropout_add(Z, self.out_bias, dropout_mask, Z_raw, prob=self.p_drop)


class TriangleAttentionEndingNode(nn.Module):
    """Column-wise gated self-attention over the pair representation,
    implemented by transposing the residue axes and reusing row-wise attention."""

    def __init__(self, d_pair, p_drop, c=32, n_head=4):
        super(TriangleAttentionEndingNode, self).__init__()
self.d_pair = d_pair
self.c = c
self.n_head = n_head
self.p_drop = p_drop
self.layernorm1 = LayerNorm(d_pair)
_init_weights = torch.nn.init.normal_(torch.zeros([d_pair, n_head]),
std=1.0 / math.sqrt(d_pair))
self.linear_b_weights = nn.parameter.Parameter(data=_init_weights)
self.attention = SelfAttention(qkv_dim=d_pair,
c=c,
n_head=n_head,
out_dim=d_pair,
gating=True,
last_bias_fuse=True)
self.out_bias = nn.parameter.Parameter(data=torch.zeros((d_pair,)), requires_grad=True)

    def forward(self, Z_raw):
        # Transpose the residue axes so column-wise attention becomes row-wise.
        Z = Z_raw.transpose(-2, -3)
        Z = self.layernorm1(Z)
        b = torch.einsum('bqkc,ch->bhqk', Z, self.linear_b_weights)
Z = self.attention(Z, b)
Z = Z.transpose(-2, -3)
        # Column-wise shared mask (note the slice is on the other residue axis).
        dropout_mask = torch.ones_like(Z[:, :, 0:1, :])
return bias_dropout_add(Z, self.out_bias, dropout_mask, Z_raw, prob=self.p_drop)


class PairStack(nn.Module):
    """One Evoformer pair-stack block: the two triangle multiplications, the
    two triangle attentions, and a final pair transition."""

    def __init__(self, d_pair, p_drop=0.25):
        super(PairStack, self).__init__()
self.TriangleMultiplicationOutgoing = TriangleMultiplicationOutgoing(d_pair, p_drop=p_drop)
self.TriangleMultiplicationIncoming = TriangleMultiplicationIncoming(d_pair, p_drop=p_drop)
self.TriangleAttentionStartingNode = TriangleAttentionStartingNode(d_pair, p_drop=p_drop)
self.TriangleAttentionEndingNode = TriangleAttentionEndingNode(d_pair, p_drop=p_drop)
self.PairTransition = Transition(d=d_pair)

    def forward(self, pair):
pair = self.TriangleMultiplicationOutgoing(pair)
pair = self.TriangleMultiplicationIncoming(pair)
pair = self.TriangleAttentionStartingNode(pair)
pair = self.TriangleAttentionEndingNode(pair)
pair = self.PairTransition(pair)
return pair
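

# A minimal smoke-test sketch (added, not part of the original module). It
# assumes random inputs are acceptable and that the module is imported in
# package context, since the relative .kernel/.ops imports above mean this
# file cannot be executed directly by path (use `python -m ...` instead).
if __name__ == "__main__":
    d_pair = 64
    model = PairStack(d_pair=d_pair)
    # Hypothetical sizes: batch of 1, 16 residues.
    pair = torch.randn(1, 16, 16, d_pair)
    out = model(pair)
    # Every sub-module is residual, so the pair shape must be preserved.
    assert out.shape == pair.shape, out.shape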