2022-12-22 02:15:34 +00:00
|
|
|
import torch
|
|
|
|
import torch.nn.functional as F
|
|
|
|
from einops import rearrange
|
2023-09-19 06:20:26 +00:00
|
|
|
from torch import matmul, nn
|
2022-12-22 02:15:34 +00:00
|
|
|
|
|
|
|
# normalization
|
|
|
|
# they use layernorm without bias, something that pytorch does not offer
|
|
|
|
|
|
|
|
|
|
|
|
class LayerNorm(nn.Module):
    """Layer normalization over the last dimension with a scale but no learned bias.

    ``gamma`` is a trainable parameter initialised to ones. ``beta`` is
    registered as a buffer of zeros, so it is part of the module state but is
    not a parameter and therefore receives no gradient updates.

    Args:
        dim: size of the last dimension of the tensors to normalize.
        eps: constant added to the variance for numerical stability.
    """

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(dim))
        self.register_buffer("beta", torch.zeros(dim))

    def forward(self, x):
        """Return ``x`` normalized over its last dimension and scaled by ``gamma``."""
        # eps has to be forwarded explicitly: without it F.layer_norm falls
        # back to its own default and the constructor argument has no effect.
        return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta, eps=self.eps)
|
|
|
|
|
|
|
|
|
|
|
|
# parallel with residual
|
|
|
|
# discovered by Wang et al + EleutherAI from GPT-J fame
|
|
|
|
|
|
|
|
|
|
|
|
class ParallelResidual(nn.Module):
    """Run several branches on the same input and add their outputs to it.

    Args:
        *fns: modules that are each called with the block input; they are
            kept in an ``nn.ModuleList`` under ``self.fns``.
    """

    def __init__(self, *fns):
        super().__init__()
        self.fns = nn.ModuleList(fns)

    def forward(self, x):
        """Return ``x`` plus the sum of every branch applied to ``x``."""
        # Start from 0 so that an empty branch list leaves the input untouched.
        branch_total = 0
        for branch in self.fns:
            branch_total = branch_total + branch(x)
        return x + branch_total
|
|
|
|
|
|
|
|
|
|
|
|
# rotary positional embedding
|
|
|
|
# https://arxiv.org/abs/2104.09864
|
|
|
|
|
|
|
|
|
|
|
|
class RotaryEmbedding(nn.Module):
    """Table of rotary position angles (https://arxiv.org/abs/2104.09864).

    Args:
        dim: feature size the angles are generated for; one inverse frequency
            is created for every second index in ``range(dim)``.
    """

    def __init__(self, dim):
        super().__init__()
        exponents = torch.arange(0, dim, 2).float() / dim
        self.register_buffer("inv_freq", 1.0 / (10000 ** exponents))

    def forward(self, max_seq_len, *, device):
        """Return the angle table for positions ``0 .. max_seq_len - 1``.

        Args:
            max_seq_len: number of positions to generate.
            device: device on which the position indices are created.

        Returns:
            Tensor with one row per position; each row holds the
            position-times-frequency products twice, concatenated along the
            last dimension.
        """
        positions = torch.arange(max_seq_len, device=device).type_as(self.inv_freq)
        # Outer product of positions and inverse frequencies, written as a
        # (n, 1) @ (1, f) matrix multiplication.
        column = positions.reshape(len(positions), 1)
        row = self.inv_freq.reshape(1, len(self.inv_freq))
        freqs = matmul(column, row)
        return torch.cat((freqs, freqs), dim=-1)
|
|
|
|
|
|
|
|
|
|
|
|
def rotate_half(x):
    """Split the last dimension into halves ``(a, b)`` and return ``(-b, a)``.

    The last dimension of ``x`` must have an even size; the result has the
    same shape as ``x``.
    """
    half = x.shape[-1] // 2
    # View the last dimension as (2, half) so each half can be picked by index.
    pairs = x.reshape(*x.shape[:-1], 2, half)
    front = pairs[..., 0, :]
    back = pairs[..., 1, :]
    return torch.cat((-back, front), dim=-1)
|
|
|
|
|
|
|
|
|
|
|
|
def apply_rotary_pos_emb(pos, t):
    """Rotate ``t`` by the angles in ``pos``.

    Computes ``t * cos(pos) + rotate_half(t) * sin(pos)``; ``pos`` must be
    broadcastable against ``t``.
    """
    cos_term = t * pos.cos()
    sin_term = rotate_half(t) * pos.sin()
    return cos_term + sin_term
|
|
|
|
|
|
|
|
|
|
|
|
# feedforward
|
|
|
|
# classic Noam Shazeer paper, except here they use SwiGLU instead of the more popular GEGLU
|
|
|
|
# https://arxiv.org/abs/2002.05202
|
|
|
|
|
|
|
|
|
|
|
|
class SwiGLU(nn.Module):
    """SwiGLU gating (https://arxiv.org/abs/2002.05202).

    The input is split in two along its last dimension; the second half is
    passed through SiLU and multiplied with the first half, so the output's
    last dimension is half as wide as the input's.
    """

    def forward(self, x):
        """Return ``silu(gate) * value`` for the two halves of ``x``."""
        value, gate = torch.chunk(x, 2, dim=-1)
        activated_gate = F.silu(gate)
        return activated_gate * value
|
|
|
|
|
|
|
|
|
|
|
|
def FeedForward(dim, mult=4):
    """Build the feedforward branch: LayerNorm -> Linear -> SwiGLU -> Linear.

    Args:
        dim: feature size of the input and output.
        mult: expansion factor; the hidden width is ``int(dim * mult)``.

    Returns:
        An ``nn.Sequential`` holding the four layers; both linear layers are
        created without bias.
    """
    hidden_dim = int(dim * mult)
    layers = [
        LayerNorm(dim),
        # SwiGLU halves the last dimension, so project to twice the hidden width.
        nn.Linear(dim, hidden_dim * 2, bias=False),
        SwiGLU(),
        nn.Linear(hidden_dim, dim, bias=False),
    ]
    return nn.Sequential(*layers)
|
|
|
|
|
|
|
|
|
|
|
|
# attention
|
|
|
|
class Attention(nn.Module):
    """Causal multi-query self-attention with rotary position embeddings.

    Queries are projected into ``heads`` heads of size ``dim_head`` while a
    single key/value head of size ``dim_head`` is shared by all query heads
    (https://arxiv.org/abs/1911.02150). The input is layer-normalized before
    the projections.

    Args:
        dim: feature size of the input and output.
        dim_head: feature size of each attention head.
        heads: number of query heads.
    """

    def __init__(self, dim, dim_head=64, heads=8):
        super().__init__()
        inner_dim = dim_head * heads
        self.norm = LayerNorm(dim)
        self.heads = heads
        self.scale = dim_head**-0.5
        self.rotary_emb = RotaryEmbedding(dim_head)

        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        # one shared key head and one shared value head, hence dim_head * 2
        self.to_kv = nn.Linear(dim, dim_head * 2, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

        # for caching causal mask and rotary embeddings; non-persistent so
        # neither ends up in the state dict
        self.register_buffer("mask", None, persistent=False)
        self.register_buffer("pos_emb", None, persistent=False)

    def get_mask(self, n, device):
        """Return an ``(n, n)`` boolean mask that is True above the diagonal.

        A previously built mask is sliced and reused when it is at least
        ``n`` wide; otherwise a new one is created on ``device`` and cached.
        """
        if self.mask is not None and self.mask.shape[-1] >= n:
            return self.mask[:n, :n]

        mask = torch.ones((n, n), device=device, dtype=torch.bool).triu(1)
        self.register_buffer("mask", mask, persistent=False)
        return mask

    def get_rotary_embedding(self, n, device):
        """Return the rotary angle table for the first ``n`` positions.

        A previously built table is sliced and reused when it covers at least
        ``n`` positions; otherwise a new one is created on ``device`` and
        cached.
        """
        if self.pos_emb is not None and self.pos_emb.shape[-2] >= n:
            return self.pos_emb[:n]

        pos_emb = self.rotary_emb(n, device=device)
        # The cache must be stored under the same name the lookup above
        # reads ("pos_emb"); any other name means the cache is never hit.
        self.register_buffer("pos_emb", pos_emb, persistent=False)
        return pos_emb

    def forward(self, x):
        """
        Apply causal self-attention to ``x`` of shape (b, n, dim).

        einstein notation
        b - batch
        h - heads
        n, i, j - sequence length (base sequence length, source, target)
        d - feature dimension
        """
        n, device, h = x.shape[1], x.device, self.heads

        # pre layernorm
        x = self.norm(x)

        # queries, keys, values
        q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim=-1))

        # split heads
        # they use multi-query single-key-value attention, yet another Noam Shazeer paper
        # they found no performance loss past a certain scale, and more efficient decoding obviously
        # https://arxiv.org/abs/1911.02150
        q = rearrange(q, "b n (h d) -> b h n d", h=h)

        # rotary embeddings
        positions = self.get_rotary_embedding(n, device)
        q, k = map(lambda t: apply_rotary_pos_emb(positions, t), (q, k))

        # scale
        q = q * self.scale

        b, _, i, d = q.shape
        j = k.size(1)

        # similarity, equivalent to einsum("b h i d, b j d -> b h i j", q, k):
        # heads are folded into the query axis so the shared keys can be
        # applied with a single batched matmul
        sim = matmul(q.reshape(b, h * i, d), k.transpose(1, 2))
        sim = sim.reshape(b, h, i, j)

        # causal mask
        causal_mask = self.get_mask(n, device)
        sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)

        # attention; the row maximum is subtracted before the softmax
        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
        attn = sim.softmax(dim=-1)

        # aggregate values, equivalent to einsum("b h i j, b j d -> b h i d", attn, v)
        out = matmul(attn.reshape(b, h * i, j), v)
        out = out.reshape(b, h, i, v.size(2))

        # merge heads
        out = rearrange(out, "b h n d -> b n (h d)")
        return self.to_out(out)
|
|
|
|
|
|
|
|
|
|
|
|
# transformer
|
|
|
|
|
|
|
|
|
|
|
|
def PaLM(*, dim, num_tokens, depth, dim_head=64, heads=8, ff_mult=4):
    """Assemble the PaLM network as an ``nn.Sequential``.

    The stack is: token embedding, ``depth`` parallel attention/feedforward
    residual blocks, a final LayerNorm, and a bias-free projection to
    ``num_tokens`` logits.

    Args:
        dim: model feature size.
        num_tokens: vocabulary size.
        depth: number of residual blocks.
        dim_head: feature size of each attention head.
        heads: number of attention heads.
        ff_mult: expansion factor of the feedforward branch.

    Returns:
        The assembled ``nn.Sequential``.
    """
    layers = [nn.Embedding(num_tokens, dim)]
    for _ in range(depth):
        layers.append(
            ParallelResidual(
                Attention(dim=dim, dim_head=dim_head, heads=heads),
                FeedForward(dim=dim, mult=ff_mult),
            )
        )
    layers.append(LayerNorm(dim))
    layers.append(nn.Linear(dim, num_tokens, bias=False))
    net = nn.Sequential(*layers)

    # they used embedding weight tied projection out to logits, not common, but works
    token_emb, to_logits = net[0], net[-1]
    to_logits.weight = token_emb.weight

    nn.init.normal_(token_emb.weight, std=0.02)
    return net
|