mirror of https://github.com/hpcaitech/ColossalAI
[examples] replace einsum with matmul (#2210)
parent
7675792100
commit
92de90dfb3
|
@ -1,7 +1,7 @@
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
from torch import einsum, nn
|
from torch import einsum, nn, matmul
|
||||||
|
|
||||||
# normalization
|
# normalization
|
||||||
# they use layernorm without bias, something that pytorch does not offer
|
# they use layernorm without bias, something that pytorch does not offer
|
||||||
|
@ -46,7 +46,8 @@ class RotaryEmbedding(nn.Module):
|
||||||
|
|
||||||
def forward(self, max_seq_len, *, device):
|
def forward(self, max_seq_len, *, device):
|
||||||
seq = torch.arange(max_seq_len, device=device)
|
seq = torch.arange(max_seq_len, device=device)
|
||||||
freqs = einsum("i , j -> i j", seq.type_as(self.inv_freq), self.inv_freq)
|
#freqs = einsum("i , j -> i j", seq.type_as(self.inv_freq), self.inv_freq)
|
||||||
|
freqs = torch.outer(seq.type_as(self.inv_freq), self.inv_freq)
|
||||||
return torch.cat((freqs, freqs), dim=-1)
|
return torch.cat((freqs, freqs), dim=-1)
|
||||||
|
|
||||||
|
|
||||||
|
@ -139,6 +140,8 @@ class Attention(nn.Module):
|
||||||
|
|
||||||
q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim=-1))
|
q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim=-1))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# split heads
|
# split heads
|
||||||
# they use multi-query single-key-value attention, yet another Noam Shazeer paper
|
# they use multi-query single-key-value attention, yet another Noam Shazeer paper
|
||||||
# they found no performance loss past a certain scale, and more efficient decoding obviously
|
# they found no performance loss past a certain scale, and more efficient decoding obviously
|
||||||
|
@ -155,9 +158,13 @@ class Attention(nn.Module):
|
||||||
|
|
||||||
q = q * self.scale
|
q = q * self.scale
|
||||||
|
|
||||||
|
b, h, i, d, j = q.size(0), q.size(1), q.size(2), q.size(3), k.size(1)
|
||||||
|
|
||||||
# similarity
|
# similarity
|
||||||
|
|
||||||
sim = einsum("b h i d, b j d -> b h i j", q, k)
|
#sim = einsum("b h i d, b j d -> b h i j", q, k)
|
||||||
|
sim = matmul(q.reshape(b, h*i, d), k.transpose(1,2))
|
||||||
|
sim = sim.reshape(b, h, i, j)
|
||||||
|
|
||||||
# causal mask
|
# causal mask
|
||||||
|
|
||||||
|
@ -169,9 +176,13 @@ class Attention(nn.Module):
|
||||||
sim = sim - sim.amax(dim=-1, keepdim=True).detach()
|
sim = sim - sim.amax(dim=-1, keepdim=True).detach()
|
||||||
attn = sim.softmax(dim=-1)
|
attn = sim.softmax(dim=-1)
|
||||||
|
|
||||||
|
b_, h_, i_, j_, d_ = attn.size(0), attn.size(1), attn.size(2), attn.size(3), v.size(2)
|
||||||
|
|
||||||
# aggregate values
|
# aggregate values
|
||||||
|
|
||||||
out = einsum("b h i j, b j d -> b h i d", attn, v)
|
#out = einsum("b h i j, b j d -> b h i d", attn, v)
|
||||||
|
out = matmul(attn.reshape(b_, h_*i_, j_), v)
|
||||||
|
out = out.reshape(b_, h_, i_, d_)
|
||||||
|
|
||||||
# merge heads
|
# merge heads
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue