diff --git a/examples/language/palm/palm_pytorch/palm_pytorch.py b/examples/language/palm/palm_pytorch/palm_pytorch.py index 1509dd84e..105991967 100644 --- a/examples/language/palm/palm_pytorch/palm_pytorch.py +++ b/examples/language/palm/palm_pytorch/palm_pytorch.py @@ -1,7 +1,7 @@ import torch import torch.nn.functional as F from einops import rearrange -from torch import einsum, nn +from torch import einsum, nn, matmul # normalization # they use layernorm without bias, something that pytorch does not offer @@ -46,7 +46,8 @@ class RotaryEmbedding(nn.Module): def forward(self, max_seq_len, *, device): seq = torch.arange(max_seq_len, device=device) - freqs = einsum("i , j -> i j", seq.type_as(self.inv_freq), self.inv_freq) + #freqs = einsum("i , j -> i j", seq.type_as(self.inv_freq), self.inv_freq) + freqs = torch.outer(seq.type_as(self.inv_freq), self.inv_freq) return torch.cat((freqs, freqs), dim=-1) @@ -139,6 +140,8 @@ class Attention(nn.Module): q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim=-1)) + + # split heads # they use multi-query single-key-value attention, yet another Noam Shazeer paper # they found no performance loss past a certain scale, and more efficient decoding obviously @@ -155,9 +158,13 @@ class Attention(nn.Module): q = q * self.scale + b, h, i, d, j = q.size(0), q.size(1), q.size(2), q.size(3), k.size(1) + # similarity - sim = einsum("b h i d, b j d -> b h i j", q, k) + #sim = einsum("b h i d, b j d -> b h i j", q, k) + sim = matmul(q.reshape(b, h*i, d), k.transpose(1,2)) + sim = sim.reshape(b, h, i, j) # causal mask @@ -169,9 +176,13 @@ class Attention(nn.Module): sim = sim - sim.amax(dim=-1, keepdim=True).detach() attn = sim.softmax(dim=-1) + b_, h_, i_, j_, d_ = attn.size(0), attn.size(1), attn.size(2), attn.size(3), v.size(2) + # aggregate values - out = einsum("b h i j, b j d -> b h i d", attn, v) + #out = einsum("b h i j, b j d -> b h i d", attn, v) + out = matmul(attn.reshape(b_, h_*i_, j_), v) + out = out.reshape(b_, h_, i_, d_) # merge heads