mirror of https://github.com/hpcaitech/ColossalAI
[examples] replace einsum with matmul (#2210)
parent 7675792100
commit 92de90dfb3
@@ -1,7 +1,7 @@
 import torch
 import torch.nn.functional as F
 from einops import rearrange
-from torch import einsum, nn
+from torch import einsum, nn, matmul
 
 # normalization
 # they use layernorm without bias, something that pytorch does not offer
@@ -46,7 +46,8 @@ class RotaryEmbedding(nn.Module):
     def forward(self, max_seq_len, *, device):
         seq = torch.arange(max_seq_len, device=device)
-        freqs = einsum("i , j -> i j", seq.type_as(self.inv_freq), self.inv_freq)
+        #freqs = einsum("i , j -> i j", seq.type_as(self.inv_freq), self.inv_freq)
+        freqs = torch.outer(seq.type_as(self.inv_freq), self.inv_freq)
         return torch.cat((freqs, freqs), dim=-1)
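Note on the RotaryEmbedding change: torch.outer(a, b) computes exactly the outer product that einsum("i , j -> i j", a, b) spelled out, so the swap is a drop-in replacement. A minimal sketch of the equivalence; the sizes below are invented stand-ins for the real sequence length and frequency count:

import torch

seq = torch.arange(8, dtype=torch.float32)    # stand-in for the position indices
inv_freq = torch.rand(4)                      # stand-in for self.inv_freq

a = torch.einsum("i,j->ij", seq, inv_freq)    # original einsum formulation
b = torch.outer(seq, inv_freq)                # replacement used by this commit
assert torch.allclose(a, b)                   # identical outer products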
@@ -139,6 +140,8 @@ class Attention(nn.Module):
         q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim=-1))
 
+
+
         # split heads
         # they use multi-query single-key-value attention, yet another Noam Shazeer paper
         # they found no performance loss past a certain scale, and more efficient decoding obviously
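The multi-query layout mentioned in these comments is what shapes the rewrites below: q keeps a per-head dimension while the single shared k/v do not, which also shrinks the key/value tensors by a factor of the head count. A sketch under assumed sizes (b, h, n, d are invented for illustration):

import torch

b, h, n, d = 2, 8, 128, 64              # hypothetical batch, heads, seq len, head dim
k_multi_head = torch.zeros(b, h, n, d)  # standard multi-head keys: one set per head
k_multi_query = torch.zeros(b, n, d)    # multi-query: a single key set shared by all heads
print(k_multi_head.numel() // k_multi_query.numel())  # prints 8, i.e. h-times smaller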
@@ -155,9 +158,13 @@ class Attention(nn.Module):
         q = q * self.scale
 
+        b, h, i, d, j = q.size(0), q.size(1), q.size(2), q.size(3), k.size(1)
+
         # similarity
 
-        sim = einsum("b h i d, b j d -> b h i j", q, k)
+        #sim = einsum("b h i d, b j d -> b h i j", q, k)
+        sim = matmul(q.reshape(b, h*i, d), k.transpose(1,2))
+        sim = sim.reshape(b, h, i, j)
 
         # causal mask
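A quick check that the reshape-plus-matmul path produces the same similarity tensor as the original einsum for these multi-query shapes; the dimensions below are invented for the test, not taken from the example:

import torch

b, h, i, j, d = 2, 8, 16, 16, 64   # assumed batch, heads, query len, key len, head dim
q = torch.randn(b, h, i, d)        # per-head queries
k = torch.randn(b, j, d)           # single shared key head (multi-query attention)

sim_einsum = torch.einsum("b h i d, b j d -> b h i j", q, k)
sim_matmul = torch.matmul(q.reshape(b, h * i, d), k.transpose(1, 2)).reshape(b, h, i, j)
assert torch.allclose(sim_einsum, sim_matmul, atol=1e-5)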
@@ -169,9 +176,13 @@ class Attention(nn.Module):
         sim = sim - sim.amax(dim=-1, keepdim=True).detach()
         attn = sim.softmax(dim=-1)
 
+        b_, h_, i_, j_, d_ = attn.size(0), attn.size(1), attn.size(2), attn.size(3), v.size(2)
+
         # aggregate values
 
-        out = einsum("b h i j, b j d -> b h i d", attn, v)
+        #out = einsum("b h i j, b j d -> b h i d", attn, v)
+        out = matmul(attn.reshape(b_, h_*i_, j_), v)
+        out = out.reshape(b_, h_, i_, d_)
 
         # merge heads
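The same equivalence holds for the value aggregation: flattening the head and query dimensions turns the einsum into a plain batched matmul against the shared value head. A hedged verification with the same invented sizes as the sketch above:

import torch

b, h, i, j, d = 2, 8, 16, 16, 64                 # assumed sizes, matching the sketch above
attn = torch.randn(b, h, i, j).softmax(dim=-1)   # stand-in attention weights
v = torch.randn(b, j, d)                         # single shared value head

out_einsum = torch.einsum("b h i j, b j d -> b h i d", attn, v)
out_matmul = torch.matmul(attn.reshape(b, h * i, j), v).reshape(b, h, i, d)
assert torch.allclose(out_einsum, out_matmul, atol=1e-5)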