diff --git a/internlm/model/multi_head_attention.py b/internlm/model/multi_head_attention.py
index 608b281..ae4de68 100644
--- a/internlm/model/multi_head_attention.py
+++ b/internlm/model/multi_head_attention.py
@@ -8,7 +8,12 @@ from typing import Optional
 import torch
 import torch.nn.functional as F
 from einops import rearrange
-from flash_attn import flash_attn_unpadded_kvpacked_func
+
+try:
+    from flash_attn import flash_attn_unpadded_kvpacked_func
+except ImportError:
+    from flash_attn import flash_attn_varlen_kvpacked_func as flash_attn_unpadded_kvpacked_func
+
 from flash_attn.modules.mha import (
     CrossAttention,
     FlashCrossAttention,
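
For context, a minimal sketch of how the aliased name stays call-compatible across flash-attn releases: `flash_attn_unpadded_kvpacked_func` (1.x) and `flash_attn_varlen_kvpacked_func` (2.x) share the packed varlen calling convention `(q, kv, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, ...)`. The shapes and sequence lengths below are hypothetical; a CUDA device and fp16/bf16 tensors are assumed, as flash-attn requires.

```python
import torch

try:
    from flash_attn import flash_attn_unpadded_kvpacked_func  # flash-attn 1.x name
except ImportError:
    # flash-attn 2.x renamed the unpadded API to "varlen"; alias it back
    from flash_attn import flash_attn_varlen_kvpacked_func as flash_attn_unpadded_kvpacked_func

# Hypothetical example: two sequences of lengths 3 and 5 packed into one batch.
total_tokens, nheads, headdim = 8, 4, 64
q = torch.randn(total_tokens, nheads, headdim, dtype=torch.float16, device="cuda")
kv = torch.randn(total_tokens, 2, nheads, headdim, dtype=torch.float16, device="cuda")
cu_seqlens = torch.tensor([0, 3, 8], dtype=torch.int32, device="cuda")  # cumulative lengths
max_seqlen = 5

out = flash_attn_unpadded_kvpacked_func(
    q, kv, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen,
    dropout_p=0.0, softmax_scale=None, causal=True,
)
print(out.shape)  # (total_tokens, nheads, headdim)
```

Because both versions accept the same positional arguments here, existing call sites in the module should work unchanged under either flash-attn release.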