[examples] replace einsum with matmul (#2210)

pull/2226/head
ZijianYY 2022-12-28 19:03:06 +08:00 committed by GitHub
parent 7675792100
commit 92de90dfb3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 15 additions and 4 deletions

View File

@@ -1,7 +1,7 @@
 import torch
 import torch.nn.functional as F
 from einops import rearrange
-from torch import einsum, nn
+from torch import einsum, nn, matmul
 # normalization
 # they use layernorm without bias, something that pytorch does not offer
@@ -46,7 +46,8 @@ class RotaryEmbedding(nn.Module):
     def forward(self, max_seq_len, *, device):
         seq = torch.arange(max_seq_len, device=device)
-        freqs = einsum("i , j -> i j", seq.type_as(self.inv_freq), self.inv_freq)
+        #freqs = einsum("i , j -> i j", seq.type_as(self.inv_freq), self.inv_freq)
+        freqs = torch.outer(seq.type_as(self.inv_freq), self.inv_freq)
         return torch.cat((freqs, freqs), dim=-1)
@@ -139,6 +140,8 @@ class Attention(nn.Module):
         q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim=-1))
         # split heads
         # they use multi-query single-key-value attention, yet another Noam Shazeer paper
         # they found no performance loss past a certain scale, and more efficient decoding obviously
@@ -155,9 +158,13 @@ class Attention(nn.Module):
         q = q * self.scale
+        b, h, i, d, j = q.size(0), q.size(1), q.size(2), q.size(3), k.size(1)
         # similarity
-        sim = einsum("b h i d, b j d -> b h i j", q, k)
+        #sim = einsum("b h i d, b j d -> b h i j", q, k)
+        sim = matmul(q.reshape(b, h*i, d), k.transpose(1,2))
+        sim = sim.reshape(b, h, i, j)
         # causal mask
@@ -169,9 +176,13 @@ class Attention(nn.Module):
         sim = sim - sim.amax(dim=-1, keepdim=True).detach()
         attn = sim.softmax(dim=-1)
+        b_, h_, i_, j_, d_ = attn.size(0), attn.size(1), attn.size(2), attn.size(3), v.size(2)
         # aggregate values
-        out = einsum("b h i j, b j d -> b h i d", attn, v)
+        #out = einsum("b h i j, b j d -> b h i d", attn, v)
+        out = matmul(attn.reshape(b_, h_*i_, j_), v)
+        out = out.reshape(b_, h_, i_, d_)
         # merge heads