mirror of https://github.com/InternLM/InternLM
Fit to flash attention 1.0
parent 78353e12cf
commit a35ce4c888
@@ -8,7 +8,7 @@ from typing import Optional
 import torch
 import torch.nn.functional as F
 from einops import rearrange
-from flash_attn import flash_attn_varlen_kvpacked_func
+from flash_attn import flash_attn_unpadded_kvpacked_func
 from flash_attn.modules.mha import (
     CrossAttention,
     FlashCrossAttention,
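The only change in this hunk is the renamed import: flash-attn 2.x exposes the packed-KV variable-length kernel as flash_attn_varlen_kvpacked_func, while flash-attn 1.x names the same entry point flash_attn_unpadded_kvpacked_func. A minimal version-tolerant import sketch, assuming only the name differs between the two releases (the alias attn_kvpacked_func is illustrative, not part of the repository):

    # Sketch: pick whichever name the installed flash-attn provides.
    try:
        # flash-attn >= 2.0 naming
        from flash_attn import flash_attn_varlen_kvpacked_func as attn_kvpacked_func
    except ImportError:
        # flash-attn 1.x naming, as targeted by this commit
        from flash_attn import flash_attn_unpadded_kvpacked_func as attn_kvpacked_func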
@@ -280,7 +280,7 @@ class MHA(nn.Module):
         if total_kv.dtype not in [torch.float16, torch.bfloat16]:
             total_kv = total_kv.to(torch.bfloat16)
 
-        output = flash_attn_varlen_kvpacked_func(
+        output = flash_attn_unpadded_kvpacked_func(
             total_q, total_kv, cu_seqlens, cu_seqlens, max_seqlen_q, max_seqlen_k, 0.0, None, True, False
         ).to(x.dtype)
 
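Both the old and new entry points consume the same "unpadded" layout: the real tokens of every sequence in the batch are packed along one dimension, with cu_seqlens recording where each sequence starts; the cast to torch.bfloat16 beforehand reflects that the flash-attention kernels accept only fp16/bf16 inputs, and the result is cast back to x.dtype. A small self-contained sketch of that layout (tensor names and sizes here are illustrative assumptions, not taken from the repository):

    import torch

    # Illustrative sizes; only the packed layout matters here.
    batch, nheads, headdim = 2, 8, 64
    lengths = torch.tensor([3, 4])                   # real (unpadded) tokens per sequence

    # Cumulative sequence lengths, shape (batch + 1,), e.g. tensor([0, 3, 7]).
    cu_seqlens = torch.zeros(batch + 1, dtype=torch.int32)
    cu_seqlens[1:] = torch.cumsum(lengths, dim=0)
    max_seqlen = int(lengths.max())

    # Packed tensors: queries as (total_tokens, nheads, headdim),
    # keys and values stacked together as (total_tokens, 2, nheads, headdim).
    total_q = torch.randn(int(cu_seqlens[-1]), nheads, headdim, dtype=torch.bfloat16)
    total_kv = torch.randn(int(cu_seqlens[-1]), 2, nheads, headdim, dtype=torch.bfloat16)

In the call above, the trailing positional arguments 0.0, None, True, False correspond to the dropout probability, the softmax scale (None falls back to 1/sqrt(headdim)), causal masking, and whether to return attention probabilities, following the flash-attn 1.x signature.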
@@ -294,18 +294,17 @@ class MHA(nn.Module):
             k = k.squeeze(2)
             v = v.squeeze(2)
             sp = k.shape
-            expansion = q.size(2) // k.size(2)
             scores = torch.einsum(
                 "blhd,bnhd->bhln",
                 q,
-                k.unsqueeze(3).expand(-1, -1, -1, expansion, -1).reshape(sp[0], sp[1], q.size(2), sp[3]),
+                k.reshape(sp[0], sp[1], q.size(2), sp[3]),
             ) / math.sqrt(q.size(-1))
             scores = scores.masked_fill(attn_mask, -65000.0)
             scores = F.softmax(scores, dim=-1)  # bsz x h x L x L
             context = torch.einsum(
                 "bhmn,bnhd->bmhd",
                 scores,
-                v.unsqueeze(3).expand(-1, -1, -1, expansion, -1).reshape(sp[0], sp[1], q.size(2), sp[3]),
+                v.reshape(sp[0], sp[1], q.size(2), sp[3]),
             )
         else:
             context = self.inner_cross_attn(q, kv, causal=True)
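The non-flash branch above computes attention explicitly with torch.einsum. A self-contained sketch of the same pattern, with illustrative shapes and a causal attn_mask where True marks positions that must not be attended to (these choices are assumptions for the example, not taken from the repository):

    import math
    import torch
    import torch.nn.functional as F

    bsz, seqlen, nheads, headdim = 2, 5, 4, 16
    q = torch.randn(bsz, seqlen, nheads, headdim)
    k = torch.randn(bsz, seqlen, nheads, headdim)
    v = torch.randn(bsz, seqlen, nheads, headdim)

    # Causal mask: True above the diagonal, i.e. future positions are masked.
    attn_mask = torch.triu(torch.ones(seqlen, seqlen, dtype=torch.bool), diagonal=1)

    # (b, l, h, d) x (b, n, h, d) -> (b, h, l, n) attention scores.
    scores = torch.einsum("blhd,bnhd->bhln", q, k) / math.sqrt(headdim)
    scores = scores.masked_fill(attn_mask, -65000.0)
    probs = F.softmax(scores, dim=-1)                    # bsz x h x L x L
    # Weighted sum of values, back to (b, l, h, d).
    context = torch.einsum("bhmn,bnhd->bmhd", probs, v)

Filling masked positions with -65000.0 rather than -inf keeps the scores representable in float16 before the softmax.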