Fit to flash attention 1.0.5.

pull/396/head
Pryest 2023-10-09 21:15:40 +08:00
parent b38ba5dad2
commit 66eba48c9f
1 changed file with 3 additions and 3 deletions

@@ -10,15 +10,15 @@ import torch.nn.functional as F
 from einops import rearrange
 try:
-    from flash_attn import flash_attn_unpadded_func
+    from flash_attn.flash_attn_interface import flash_attn_unpadded_func
 except ImportError:
     try:
-        from flash_attn import (
+        from flash_attn.flash_attn_interface import (
             flash_attn_unpadded_kvpacked_func as flash_attn_unpadded_func,
         )
     except ImportError:
         try:
-            from flash_attn import (
+            from flash_attn.flash_attn_interface import (
                 flash_attn_varlen_kvpacked_func as flash_attn_unpadded_func,
             )
         except ImportError:
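
For reference, a minimal sketch of how the import fallback reads after this patch. Each branch binds one of the flash-attn variable-length kernels to the single name flash_attn_unpadded_func, importing from flash_attn.flash_attn_interface so that flash-attn 1.0.5 (which does not re-export these names from the top-level package) is also covered. The final except branch and the None sentinel are assumptions added for illustration; they are not part of the hunk shown above.

# Sketch of the resulting fallback chain (assumes flash-attn is optional).
try:
    # preferred: the plain unpadded (variable-length) attention function
    from flash_attn.flash_attn_interface import flash_attn_unpadded_func
except ImportError:
    try:
        # fallback: KV-packed variant under the older "unpadded" naming
        from flash_attn.flash_attn_interface import (
            flash_attn_unpadded_kvpacked_func as flash_attn_unpadded_func,
        )
    except ImportError:
        try:
            # fallback: KV-packed variant under the newer "varlen" naming
            from flash_attn.flash_attn_interface import (
                flash_attn_varlen_kvpacked_func as flash_attn_unpadded_func,
            )
        except ImportError:
            # assumed final branch (not shown in the hunk): mark flash-attn as
            # unavailable so callers can fall back to standard attention
            flash_attn_unpadded_func = None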