Fit to flash attention 1.0.5.

pull/396/head
Pryest 2023-10-09 21:03:16 +08:00
parent a3580acb6c
commit b38ba5dad2
1 changed file with 13 additions and 3 deletions

@@ -10,9 +10,19 @@ import torch.nn.functional as F
 from einops import rearrange
 try:
-    from flash_attn import flash_attn_unpadded_kvpacked_func
+    from flash_attn import flash_attn_unpadded_func
 except ImportError:
-    from flash_attn import flash_attn_varlen_kvpacked_func as flash_attn_unpadded_kvpacked_func
+    try:
+        from flash_attn import (
+            flash_attn_unpadded_kvpacked_func as flash_attn_unpadded_func,
+        )
+    except ImportError:
+        try:
+            from flash_attn import (
+                flash_attn_varlen_kvpacked_func as flash_attn_unpadded_func,
+            )
+        except ImportError:
+            raise ImportError("Please check your flash_attn version >= 1.0.5.")
 from flash_attn.modules.mha import (
     CrossAttention,
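
The new import block tries the candidate public names in order and binds whichever one exists to the single alias flash_attn_unpadded_func (flash-attn 2.x renamed the flash_attn_unpadded_* functions to flash_attn_varlen_*). An equivalent way to read the nested try/except, written as a loop, is sketched below; this is not part of the patch and only assumes that the flash_attn package itself is importable:

import importlib

# Candidate names, in the same order the patch tries them.
_CANDIDATES = (
    "flash_attn_unpadded_func",           # layout this commit targets (flash-attn 1.0.5 per the commit title)
    "flash_attn_unpadded_kvpacked_func",  # older flash-attn 1.x kv-packed name
    "flash_attn_varlen_kvpacked_func",    # flash-attn 2.x kv-packed name
)

_flash_attn = importlib.import_module("flash_attn")
flash_attn_unpadded_func = None
for _name in _CANDIDATES:
    if hasattr(_flash_attn, _name):
        # Bind the first available implementation under one alias.
        flash_attn_unpadded_func = getattr(_flash_attn, _name)
        break
if flash_attn_unpadded_func is None:
    raise ImportError("Please check your flash_attn version >= 1.0.5.")
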
@@ -285,7 +295,7 @@ class MHA(nn.Module):
         if total_kv.dtype not in [torch.float16, torch.bfloat16]:
             total_kv = total_kv.to(torch.bfloat16)
-        output = flash_attn_unpadded_kvpacked_func(
+        output = flash_attn_unpadded_func(
             total_q, total_kv, cu_seqlens, cu_seqlens, max_seqlen_q, max_seqlen_k, 0.0, None, True, False
         ).to(x.dtype)
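
The second hunk only renames the symbol at the call site; the positional arguments still follow the variable-length ("unpadded"), kv-packed calling convention (q, kv, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, causal, return_attn_probs). A minimal usage sketch of that convention follows; it is illustrative only (tensor names and shapes are not from the repository) and assumes flash-attn 2.x, a CUDA device, and fp16/bf16 inputs, which the flash-attention kernels require:

import torch
from flash_attn import flash_attn_varlen_kvpacked_func  # flash-attn 2.x name for the kv-packed variant

# Illustrative sizes only.
seqlens = torch.tensor([5, 12, 20], dtype=torch.int32, device="cuda")  # per-sequence token counts
total_tokens = int(seqlens.sum())          # 37 tokens, concatenated with no padding
num_heads, head_dim = 8, 64

# "Unpadded" layout: all sequences are concatenated along dim 0.
q = torch.randn(total_tokens, num_heads, head_dim, dtype=torch.float16, device="cuda")
kv = torch.randn(total_tokens, 2, num_heads, head_dim, dtype=torch.float16, device="cuda")  # K and V packed

# Cumulative sequence lengths mark each sequence boundary: [0, 5, 17, 37].
cu_seqlens = torch.zeros(len(seqlens) + 1, dtype=torch.int32, device="cuda")
cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)
max_seqlen = int(seqlens.max())

out = flash_attn_varlen_kvpacked_func(
    q, kv, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen,
    dropout_p=0.0, softmax_scale=None, causal=True,
)
# out has shape (total_tokens, num_heads, head_dim), the same unpadded layout as q.

Under flash-attn 1.x the kv-packed variant is named flash_attn_unpadded_kvpacked_func, which is the second fallback in the patched import block.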