# ColossalAI/tests/test_autochunk/evoformer/ops.py
#
# (177 lines, 5.9 KiB, Python, executable file)

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch.nn import LayerNorm
from .initializer import glorot_uniform_af
from .kernel import bias_sigmod_ele
class DropoutRowwise(nn.Module):
    """Dropout whose mask is shared along dimension 1 of the input.

    One mask is sampled from a size-1 slice of dimension 1 and broadcast back
    over that dimension, so every slice ``x[:, i]`` is dropped (and rescaled
    by ``nn.Dropout``) at exactly the same positions. ``x`` must have at least
    four dimensions.

    Args:
        p: Drop probability, forwarded to ``nn.Dropout``.
    """

    def __init__(self, p):
        super().__init__()
        self.p = p
        self.dropout = nn.Dropout(p=p)

    def forward(self, x):
        # Sample the mask on a [d0, 1, d2, d3] slice so the multiplication
        # below broadcasts it across dimension 1.
        shared_mask = self.dropout(torch.ones_like(x[:, 0:1, :, :]))
        return shared_mask * x
class DropoutColumnwise(nn.Module):
    """Dropout whose mask is shared along dimension 2 of the input.

    One mask is sampled from a size-1 slice of dimension 2 and broadcast back
    over that dimension, so every slice ``x[:, :, j]`` is dropped (and
    rescaled by ``nn.Dropout``) at exactly the same positions. ``x`` must have
    at least four dimensions.

    Args:
        p: Drop probability, forwarded to ``nn.Dropout``.
    """

    def __init__(self, p):
        super().__init__()
        self.p = p
        self.dropout = nn.Dropout(p=p)

    def forward(self, x):
        # Sample the mask on a [d0, d1, 1, d3] slice so the multiplication
        # below broadcasts it across dimension 2.
        shared_mask = self.dropout(torch.ones_like(x[:, :, 0:1, :]))
        return shared_mask * x
class Transition(nn.Module):
    """Pre-norm residual feed-forward block.

    Computes ``src + linear2(relu(linear1(layer_norm(src))))``, using a hidden
    width of ``n * d``.

    Args:
        d: Size of the last dimension of the input (and of the output).
        n: Expansion factor of the hidden layer (default 4).
    """

    def __init__(self, d, n=4):
        super().__init__()
        hidden_width = n * d
        # Submodules are created in a fixed order (norm, linear1, linear2) so
        # any random initialization draws from the RNG in the same sequence.
        self.norm = LayerNorm(d)
        self.linear1 = Linear(d, hidden_width, initializer='relu')
        # NOTE: with Linear's 'zeros' initializer (zero weight, bias filled
        # with the default 0.) this branch contributes nothing at init.
        self.linear2 = Linear(hidden_width, d, initializer='zeros')

    def forward(self, src):
        # Op order (norm -> linear1 -> relu -> linear2 -> add) is kept as-is;
        # this module is traced by the autochunk tests.
        activated = F.relu(self.linear1(self.norm(src)))
        return src + self.linear2(activated)
class OutProductMean(nn.Module):
    """Outer-product projection from a ``[b, s, r, n_feat]`` tensor to ``[b, r, r, n_feat_out]``.

    The input is layer-normalized and projected twice to ``n_feat_proj``
    channels. For each pair of positions ``(i, j)``, the outer product of the
    two projections is taken and summed over the ``s`` axis. The resulting
    ``n_feat_proj * n_feat_proj`` features are then projected to
    ``n_feat_out``.

    NOTE: despite the name, there is no division by ``s`` (nor any mask-based
    normalization): the outer products are summed, not averaged.

    Args:
        n_feat: Last-dimension size of the input.
        n_feat_out: Last-dimension size of the output.
        n_feat_proj: Width of each of the two intermediate projections.
    """

    def __init__(self, n_feat=64, n_feat_out=128, n_feat_proj=32):
        super().__init__()
        self.layernormM = LayerNorm(n_feat)
        self.linear_a = Linear(n_feat, n_feat_proj)
        self.linear_b = Linear(n_feat, n_feat_proj)
        # NOTE(review): Linear only recognizes 'zeros'; the singular 'zero'
        # falls through to nn.Linear's default init. Left as-is on purpose:
        # a truly zero-initialized projection (with bias 0.) would make this
        # module output all zeros at init. Confirm intent before changing.
        self.o_linear = Linear(n_feat_proj * n_feat_proj,
                               n_feat_out,
                               initializer='zero',
                               use_bias=True)

    def forward(self, M):
        normed = self.layernormM(M)
        left = self.linear_a(normed)
        right = self.linear_b(normed)
        # [b, s, i, d] x [b, s, j, e] -> [b, i, j, d, e], summed over s.
        pair = torch.einsum('bsid,bsje->bijde', left, right).contiguous()
        # Fold the (d, e) pair into one feature axis: [b, i, j, d * e].
        pair = pair.reshape(pair.shape[0], pair.shape[1], pair.shape[2], -1)
        return self.o_linear(pair)
class Linear(nn.Linear):
    """``nn.Linear`` with a selectable weight initializer and a constant bias.

    Constructed just like ``torch.nn.Linear``. Implements the initializers of
    section 1.11.4 (presumably of the AlphaFold 2 supplementary material —
    confirm), plus some additional ones found in reference code.

    Args:
        feature_in: Input feature size.
        feature_out: Output feature size.
        initializer: Weight initialization scheme:
            ``'linear'`` -> ``glorot_uniform_af(weight, gain=1.0)``;
            ``'relu'``   -> ``glorot_uniform_af(weight, gain=2.0)``;
            ``'zeros'``  -> all-zero weight.
            Any other value (including the singular ``'zero'``) is silently
            ignored, and the weight keeps ``nn.Linear``'s default
            initialization.
        use_bias: Whether to create a bias; also stored as ``self.use_bias``.
        bias_init: Constant the bias is filled with when ``use_bias`` is True.
    """

    def __init__(
        self,
        feature_in: int,
        feature_out: int,
        initializer: str = 'linear',
        use_bias: bool = True,
        bias_init: float = 0.,
    ):
        super().__init__(feature_in, feature_out, bias=use_bias)
        self.use_bias = use_bias

        if initializer in ('linear', 'relu'):
            gain = 1.0 if initializer == 'linear' else 2.0
            glorot_uniform_af(self.weight, gain=gain)
        elif initializer == 'zeros':
            nn.init.zeros_(self.weight)

        if use_bias:
            # constant_ fills in place under torch.no_grad().
            nn.init.constant_(self.bias, bias_init)
class SelfAttention(nn.Module):
    """Multi-head self-attention over ``[batch_size1, batch_size2, len, dim]`` tensors.

    Attention runs along the ``len`` axis independently for every
    ``(batch_size1, batch_size2)`` index, with optional gating of the
    attended values before the output projection.

    Args:
        qkv_dim: Last-dimension size of the input.
        c: Per-head channel size; queries are scaled by ``c ** -0.5``.
        n_head: Number of attention heads.
        out_dim: Last-dimension size of the output.
        gating: If True, gate the attention result using ``gating_linear(in_data)``
            and ``gating_bias`` (via ``bias_sigmod_ele``).
        last_bias_fuse: If True, ``o_linear`` is built without a bias
            (presumably the caller adds it elsewhere — confirm).
    """

    def __init__(self, qkv_dim, c, n_head, out_dim, gating=True, last_bias_fuse=False):
        super().__init__()
        self.qkv_dim = qkv_dim
        self.c = c
        self.n_head = n_head
        self.out_dim = out_dim
        self.gating = gating
        self.last_bias_fuse = last_bias_fuse
        self.scaling = c**(-0.5)

        inner_dim = n_head * c
        # Submodules are created in the order q, k, v, [gating], out so any
        # random initialization draws from the RNG in the same sequence.
        self.to_q = Linear(qkv_dim, inner_dim, initializer='linear', use_bias=False)
        self.to_k = Linear(qkv_dim, inner_dim, initializer='linear', use_bias=False)
        self.to_v = Linear(qkv_dim, inner_dim, initializer='linear', use_bias=False)

        if gating:
            self.gating_bias = nn.parameter.Parameter(data=torch.ones((inner_dim,)))
            self.gating_linear = Linear(qkv_dim, inner_dim, initializer='zero', use_bias=False)

        self.o_linear = Linear(inner_dim, out_dim, initializer='zero', use_bias=not last_bias_fuse)

    def _split_heads(self, t):
        """Reshape ``[b1, b2, n, n_head * c]`` into ``[b1, b2, n_head, n, c]``."""
        # Explicit view/permute (rather than einops) keeps the traced graph
        # identical to what the autochunk tests expect.
        return t.view(t.shape[0], t.shape[1], t.shape[2], self.n_head, -1).permute(0, 1, 3, 2, 4)

    def forward(self, in_data, nonbatched_bias=None):
        """Apply self-attention to ``in_data``.

        :param in_data: [batch_size1, batch_size2, len_qkv, qkv_dim]
        :param nonbatched_bias: None or [batch_size1, n_head, len_q, len_kv];
            added to the attention logits, broadcast over batch_size2.
        :return: [batch_size1, batch_size2, len_qkv, out_dim]
        """
        # All three projections are computed before any head splitting, in the
        # same op order as the traced reference graph.
        query = self.to_q(in_data)
        key = self.to_k(in_data)
        value = self.to_v(in_data)
        query, key, value = self._split_heads(query), self._split_heads(key), self._split_heads(value)

        query = query * self.scaling
        scores = torch.matmul(query, key.transpose(-1, -2))
        if nonbatched_bias is not None:
            # [b1, h, q, kv] -> [b1, 1, h, q, kv], broadcast over batch_size2.
            scores += nonbatched_bias.unsqueeze(1)
        probs = torch.softmax(scores, dim=-1)

        # Attend, then merge heads: [b1, b2, h, n, c] -> [b1, b2, n, h * c].
        attended = torch.matmul(probs, value)
        attended = attended.permute(0, 1, 3, 2, 4)
        attended = attended.reshape(attended.shape[0], attended.shape[1], attended.shape[2], -1)

        if self.gating:
            gate = self.gating_linear(in_data)
            # bias_sigmod_ele is a project kernel; by its name it presumably
            # computes sigmoid(gate + gating_bias) * attended — confirm in .kernel.
            attended = bias_sigmod_ele(gate, self.gating_bias, attended)
        return self.o_linear(attended)