|
|
|
@ -1,7 +1,7 @@
|
|
|
|
|
import torch |
|
|
|
|
import torch.nn.functional as F |
|
|
|
|
from einops import rearrange |
|
|
|
|
from torch import einsum, nn |
|
|
|
|
from torch import einsum, nn, matmul |
|
|
|
|
|
|
|
|
|
# normalization |
|
|
|
|
# they use layernorm without bias, something that pytorch does not offer |
|
|
|
@ -46,7 +46,8 @@ class RotaryEmbedding(nn.Module):
|
|
|
|
|
|
|
|
|
|
def forward(self, max_seq_len, *, device): |
|
|
|
|
seq = torch.arange(max_seq_len, device=device) |
|
|
|
|
freqs = einsum("i , j -> i j", seq.type_as(self.inv_freq), self.inv_freq) |
|
|
|
|
#freqs = einsum("i , j -> i j", seq.type_as(self.inv_freq), self.inv_freq) |
|
|
|
|
freqs = torch.outer(seq.type_as(self.inv_freq), self.inv_freq) |
|
|
|
|
return torch.cat((freqs, freqs), dim=-1) |
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -139,6 +140,8 @@ class Attention(nn.Module):
|
|
|
|
|
|
|
|
|
|
q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim=-1)) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# split heads |
|
|
|
|
# they use multi-query single-key-value attention, yet another Noam Shazeer paper |
|
|
|
|
# they found no performance loss past a certain scale, and more efficient decoding obviously |
|
|
|
@ -155,9 +158,13 @@ class Attention(nn.Module):
|
|
|
|
|
|
|
|
|
|
q = q * self.scale |
|
|
|
|
|
|
|
|
|
b, h, i, d, j = q.size(0), q.size(1), q.size(2), q.size(3), k.size(1) |
|
|
|
|
|
|
|
|
|
# similarity |
|
|
|
|
|
|
|
|
|
sim = einsum("b h i d, b j d -> b h i j", q, k) |
|
|
|
|
#sim = einsum("b h i d, b j d -> b h i j", q, k) |
|
|
|
|
sim = matmul(q.reshape(b, h*i, d), k.transpose(1,2)) |
|
|
|
|
sim = sim.reshape(b, h, i, j) |
|
|
|
|
|
|
|
|
|
# causal mask |
|
|
|
|
|
|
|
|
@ -169,9 +176,13 @@ class Attention(nn.Module):
|
|
|
|
|
sim = sim - sim.amax(dim=-1, keepdim=True).detach() |
|
|
|
|
attn = sim.softmax(dim=-1) |
|
|
|
|
|
|
|
|
|
b_, h_, i_, j_, d_ = attn.size(0), attn.size(1), attn.size(2), attn.size(3), v.size(2) |
|
|
|
|
|
|
|
|
|
# aggregate values |
|
|
|
|
|
|
|
|
|
out = einsum("b h i j, b j d -> b h i d", attn, v) |
|
|
|
|
#out = einsum("b h i j, b j d -> b h i d", attn, v) |
|
|
|
|
out = matmul(attn.reshape(b_, h_*i_, j_), v) |
|
|
|
|
out = out.reshape(b_, h_, i_, d_) |
|
|
|
|
|
|
|
|
|
# merge heads |
|
|
|
|
|
|
|
|
|