import torch import torch.nn as nn import torch.nn.functional as F from einops import rearrange from torch.nn import LayerNorm from .initializer import glorot_uniform_af from .kernel import bias_sigmod_ele class DropoutRowwise(nn.Module): def __init__(self, p): super(DropoutRowwise, self).__init__() self.p = p self.dropout = nn.Dropout(p=p) def forward(self, x): dropout_mask = torch.ones_like(x[:, 0:1, :, :]) dropout_mask = self.dropout(dropout_mask) return dropout_mask * x class DropoutColumnwise(nn.Module): def __init__(self, p): super(DropoutColumnwise, self).__init__() self.p = p self.dropout = nn.Dropout(p=p) def forward(self, x): dropout_mask = torch.ones_like(x[:, :, 0:1, :]) dropout_mask = self.dropout(dropout_mask) return dropout_mask * x class Transition(nn.Module): def __init__(self, d, n=4): super(Transition, self).__init__() self.norm = LayerNorm(d) self.linear1 = Linear(d, n * d, initializer='relu') self.linear2 = Linear(n * d, d, initializer='zeros') def forward(self, src): x = self.norm(src) x = self.linear2(F.relu(self.linear1(x))) return src + x class OutProductMean(nn.Module): def __init__(self, n_feat=64, n_feat_out=128, n_feat_proj=32): super(OutProductMean, self).__init__() self.layernormM = LayerNorm(n_feat) self.linear_a = Linear(n_feat, n_feat_proj) self.linear_b = Linear(n_feat, n_feat_proj) self.o_linear = Linear(n_feat_proj * n_feat_proj, n_feat_out, initializer='zero', use_bias=True) def forward(self, M): M = self.layernormM(M) left_act = self.linear_a(M) right_act = self.linear_b(M) o = torch.einsum('bsid,bsje->bijde', left_act, right_act).contiguous() # O = rearrange(O, 'b i j d e -> b i j (d e)') o = o.reshape(o.shape[0], o.shape[1], o.shape[2], -1) Z = self.o_linear(o) return Z class Linear(nn.Linear): """ A Linear layer with built-in nonstandard initializations. Called just like torch.nn.Linear. Implements the initializers in 1.11.4, plus some additional ones found in the code. """ def __init__( self, feature_in: int, feature_out: int, initializer: str = 'linear', use_bias: bool = True, bias_init: float = 0., ): super(Linear, self).__init__(feature_in, feature_out, bias=use_bias) self.use_bias = use_bias if initializer == 'linear': glorot_uniform_af(self.weight, gain=1.0) elif initializer == 'relu': glorot_uniform_af(self.weight, gain=2.0) elif initializer == 'zeros': nn.init.zeros_(self.weight) if self.use_bias: with torch.no_grad(): self.bias.fill_(bias_init) class SelfAttention(nn.Module): """ Multi-Head SelfAttention dealing with [batch_size1, batch_size2, len, dim] tensors """ def __init__(self, qkv_dim, c, n_head, out_dim, gating=True, last_bias_fuse=False): super(SelfAttention, self).__init__() self.qkv_dim = qkv_dim self.c = c self.n_head = n_head self.out_dim = out_dim self.gating = gating self.last_bias_fuse = last_bias_fuse self.scaling = self.c**(-0.5) # self.to_qkv = Linear(qkv_dim, 3 * n_head * c, initializer='linear') self.to_q = Linear(qkv_dim, n_head * c, initializer='linear', use_bias=False) self.to_k = Linear(qkv_dim, n_head * c, initializer='linear', use_bias=False) self.to_v = Linear(qkv_dim, n_head * c, initializer='linear', use_bias=False) if gating: self.gating_bias = nn.parameter.Parameter(data=torch.ones((n_head * c,))) self.gating_linear = Linear(qkv_dim, n_head * c, initializer='zero', use_bias=False) self.o_linear = Linear(n_head * c, out_dim, initializer='zero', use_bias=(not last_bias_fuse)) def forward(self, in_data, nonbatched_bias=None): """ :param in_data: [batch_size1, batch_size2, len_qkv, qkv_dim] :param bias: None or [batch_size1, batch_size2, n_head, len_q, len_kv] :param nonbatched_bias: None or [batch_size1, n_head, len_q, len_kv] """ # qkv = self.to_qkv(in_data).chunk(3, dim=-1) # q, k, v = map(lambda t: rearrange(t, 'b1 b2 n (h d) -> b1 b2 h n d', h=self.n_head), qkv) q = self.to_q(in_data) k = self.to_k(in_data) v = self.to_v(in_data) # q, k, v = map(lambda t: rearrange(t, 'b1 b2 n (h d) -> b1 b2 h n d', h=self.n_head), # [q, k, v]) q, k, v = map(lambda t: t.view(t.shape[0], t.shape[1], t.shape[2], self.n_head, -1).permute(0, 1, 3, 2, 4), [q, k, v]) q = q * self.scaling logits = torch.matmul(q, k.transpose(-1, -2)) if nonbatched_bias is not None: logits += nonbatched_bias.unsqueeze(1) weights = torch.softmax(logits, dim=-1) # weights = softmax(logits) weighted_avg = torch.matmul(weights, v) # weighted_avg = rearrange(weighted_avg, 'b1 b2 h n d -> b1 b2 n (h d)') weighted_avg = weighted_avg.permute(0, 1, 3, 2, 4) weighted_avg = weighted_avg.reshape(weighted_avg.shape[0], weighted_avg.shape[1], weighted_avg.shape[2], -1) if self.gating: gate_values = self.gating_linear(in_data) weighted_avg = bias_sigmod_ele(gate_values, self.gating_bias, weighted_avg) output = self.o_linear(weighted_avg) return output