Fit to flash attention 1.0

pull/396/head
Pryest 2023-10-09 20:43:21 +08:00
parent 78353e12cf
commit a35ce4c888
1 changed file with 4 additions and 5 deletions

@@ -8,7 +8,7 @@ from typing import Optional
 import torch
 import torch.nn.functional as F
 from einops import rearrange
-from flash_attn import flash_attn_varlen_kvpacked_func
+from flash_attn import flash_attn_unpadded_kvpacked_func
 from flash_attn.modules.mha import (
     CrossAttention,
     FlashCrossAttention,
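
Note on the rename: flash-attn 1.x exposes this kernel as flash_attn_unpadded_kvpacked_func, while flash-attn 2.x renamed it to flash_attn_varlen_kvpacked_func. A minimal import shim, assuming only the name (not the call signature) differs between versions, could keep both working:

    # Sketch only: prefer the 2.x name, fall back to the 1.x name on older installs.
    try:
        from flash_attn import flash_attn_varlen_kvpacked_func  # flash-attn >= 2.x
    except ImportError:  # flash-attn 1.x
        from flash_attn import flash_attn_unpadded_kvpacked_func as flash_attn_varlen_kvpacked_func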
@@ -280,7 +280,7 @@ class MHA(nn.Module):
 if total_kv.dtype not in [torch.float16, torch.bfloat16]:
     total_kv = total_kv.to(torch.bfloat16)
-output = flash_attn_varlen_kvpacked_func(
+output = flash_attn_unpadded_kvpacked_func(
     total_q, total_kv, cu_seqlens, cu_seqlens, max_seqlen_q, max_seqlen_k, 0.0, None, True, False
 ).to(x.dtype)
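
For reference, a hedged reading of the positional arguments in this call, assuming the flash-attn 1.x signature (q, kv, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, causal, return_attn_probs); the argument names below are taken from that assumption, not from this repository:

    # Sketch: the same call with each positional argument annotated.
    output = flash_attn_unpadded_kvpacked_func(
        total_q,      # (total_q_tokens, n_heads, head_dim), fp16/bf16
        total_kv,     # (total_kv_tokens, 2, n_heads, head_dim), packed K and V
        cu_seqlens,   # cumulative query sequence lengths
        cu_seqlens,   # cumulative key/value sequence lengths
        max_seqlen_q,
        max_seqlen_k,
        0.0,          # dropout_p
        None,         # softmax_scale, defaults to 1/sqrt(head_dim)
        True,         # causal
        False,        # return_attn_probs
    ).to(x.dtype)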
@@ -294,18 +294,17 @@ class MHA(nn.Module):
     k = k.squeeze(2)
     v = v.squeeze(2)
     sp = k.shape
-    expansion = q.size(2) // k.size(2)
     scores = torch.einsum(
         "blhd,bnhd->bhln",
         q,
-        k.unsqueeze(3).expand(-1, -1, -1, expansion, -1).reshape(sp[0], sp[1], q.size(2), sp[3]),
+        k.reshape(sp[0], sp[1], q.size(2), sp[3]),
     ) / math.sqrt(q.size(-1))
     scores = scores.masked_fill(attn_mask, -65000.0)
     scores = F.softmax(scores, dim=-1)  # bsz x h x L x L
     context = torch.einsum(
         "bhmn,bnhd->bmhd",
         scores,
-        v.unsqueeze(3).expand(-1, -1, -1, expansion, -1).reshape(sp[0], sp[1], q.size(2), sp[3]),
+        v.reshape(sp[0], sp[1], q.size(2), sp[3]),
     )
 else:
     context = self.inner_cross_attn(q, kv, causal=True)
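
For context, dropping `expansion` means k.reshape(sp[0], sp[1], q.size(2), sp[3]) only works when q and k carry the same number of heads (sp[2] must equal q.size(2) for the element count to match), i.e. no grouped/multi-query KV sharing in this branch. What the non-flash branch computes is plain masked scaled-dot-product attention written with einsum; a self-contained sketch (the function name and the mask convention, True = position masked out, are assumptions matching the masked_fill above):

    import math
    import torch
    import torch.nn.functional as F

    def einsum_attention(q, k, v, attn_mask):
        # q, k, v: (batch, seqlen, heads, head_dim); attn_mask broadcasts to
        # (batch, heads, q_len, k_len) and is True where attention is disallowed.
        scores = torch.einsum("blhd,bnhd->bhln", q, k) / math.sqrt(q.size(-1))
        scores = scores.masked_fill(attn_mask, -65000.0)
        probs = F.softmax(scores, dim=-1)  # bsz x h x L x L
        return torch.einsum("bhmn,bnhd->bmhd", probs, v)  # (batch, seqlen, heads, head_dim)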