Fit to flash attention 1.0

pull/396/head
Pryest 2023-10-09 20:46:17 +08:00
parent a35ce4c888
commit a3580acb6c
1 changed file with 6 additions and 1 deletion

@@ -8,7 +8,12 @@ from typing import Optional
 import torch
 import torch.nn.functional as F
 from einops import rearrange
-from flash_attn import flash_attn_unpadded_kvpacked_func
+try:
+    from flash_attn import flash_attn_unpadded_kvpacked_func
+except ImportError:
+    from flash_attn import flash_attn_varlen_kvpacked_func as flash_attn_unpadded_kvpacked_func
 from flash_attn.modules.mha import (
     CrossAttention,
     FlashCrossAttention,
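
The import fallback works because older flash-attn releases expose the variable-length, packed-KV kernel as flash_attn_unpadded_kvpacked_func, while newer releases rename it to flash_attn_varlen_kvpacked_func with the same call interface. The sketch below (not part of the diff) shows how the aliased function is typically invoked on concatenated, unpadded sequences; the tensor sizes are illustrative assumptions and the exact argument defaults should be checked against the installed flash-attn version.

    # Minimal usage sketch for the aliased kvpacked attention function.
    # Assumes flash-attn is installed and a CUDA device is available.
    import torch

    try:
        from flash_attn import flash_attn_unpadded_kvpacked_func  # older name
    except ImportError:
        from flash_attn import flash_attn_varlen_kvpacked_func as flash_attn_unpadded_kvpacked_func  # newer name

    batch, seqlen, nheads, headdim = 2, 128, 8, 64  # illustrative sizes

    # Unpadded ("varlen") layout: all sequences concatenated along the token dimension.
    q = torch.randn(batch * seqlen, nheads, headdim, dtype=torch.float16, device="cuda")
    kv = torch.randn(batch * seqlen, 2, nheads, headdim, dtype=torch.float16, device="cuda")

    # Cumulative sequence lengths mark each sequence's start offset in the packed tensors.
    cu_seqlens = torch.arange(0, (batch + 1) * seqlen, seqlen, dtype=torch.int32, device="cuda")

    out = flash_attn_unpadded_kvpacked_func(
        q, kv,
        cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
        max_seqlen_q=seqlen, max_seqlen_k=seqlen,
        dropout_p=0.0, causal=False,
    )
    print(out.shape)  # (batch * seqlen, nheads, headdim)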