diff --git a/internlm/model/multi_head_attention.py b/internlm/model/multi_head_attention.py
index ae4de68..6533dd7 100644
--- a/internlm/model/multi_head_attention.py
+++ b/internlm/model/multi_head_attention.py
@@ -10,9 +10,19 @@ import torch.nn.functional as F
 from einops import rearrange
 
 try:
-    from flash_attn import flash_attn_unpadded_kvpacked_func
+    from flash_attn import flash_attn_unpadded_func
 except ImportError:
-    from flash_attn import flash_attn_varlen_kvpacked_func as flash_attn_unpadded_kvpacked_func
+    try:
+        from flash_attn import (
+            flash_attn_unpadded_kvpacked_func as flash_attn_unpadded_func,
+        )
+    except ImportError:
+        try:
+            from flash_attn import (
+                flash_attn_varlen_kvpacked_func as flash_attn_unpadded_func,
+            )
+        except ImportError:
+            raise ImportError("Please check your flash_attn version >= 1.0.5.")
 
 from flash_attn.modules.mha import (
     CrossAttention,
@@ -285,7 +295,7 @@ class MHA(nn.Module):
 
         if total_kv.dtype not in [torch.float16, torch.bfloat16]:
             total_kv = total_kv.to(torch.bfloat16)
-        output = flash_attn_unpadded_kvpacked_func(
+        output = flash_attn_unpadded_func(
             total_q, total_kv, cu_seqlens, cu_seqlens, max_seqlen_q, max_seqlen_k, 0.0, None, True, False
         ).to(x.dtype)
 
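A minimal sketch (not part of the patch), assuming flash_attn is importable in the local environment: it probes the same three entry points the new fallback chain tries and reports which one the installed version exposes, which is a quick way to sanity-check a flash_attn upgrade before running training.

import importlib

# Entry points, listed in the same order the patch falls back through them.
_CANDIDATES = (
    "flash_attn_unpadded_func",
    "flash_attn_unpadded_kvpacked_func",
    "flash_attn_varlen_kvpacked_func",
)

flash_attn = importlib.import_module("flash_attn")
for name in _CANDIDATES:
    if hasattr(flash_attn, name):
        print(f"flash_attn exposes: {name}")
        break
else:
    # Mirrors the patch's final error path when no known entry point is found.
    raise ImportError("Please check your flash_attn version >= 1.0.5.")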