diff --git a/internlm/model/multi_head_attention.py b/internlm/model/multi_head_attention.py
index 6533dd7..6017dbc 100644
--- a/internlm/model/multi_head_attention.py
+++ b/internlm/model/multi_head_attention.py
@@ -10,15 +10,15 @@ import torch.nn.functional as F
 from einops import rearrange
 
 try:
-    from flash_attn import flash_attn_unpadded_func
+    from flash_attn.flash_attn_interface import flash_attn_unpadded_func
 except ImportError:
     try:
-        from flash_attn import (
+        from flash_attn.flash_attn_interface import (
             flash_attn_unpadded_kvpacked_func as flash_attn_unpadded_func,
         )
     except ImportError:
         try:
-            from flash_attn import (
+            from flash_attn.flash_attn_interface import (
                 flash_attn_varlen_kvpacked_func as flash_attn_unpadded_func,
             )
         except ImportError:
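
Note (not part of the patch): flash-attn's unpadded entry points live under `flash_attn.flash_attn_interface` rather than the package top level in 1.x, and 2.x renamed the unpadded API to `flash_attn_varlen_*`, which is why the patch imports from the submodule and tries each name in turn, aliasing whichever resolves to `flash_attn_unpadded_func`. Below is a hedged usage sketch of the kvpacked calling convention, assuming one of the two kvpacked fallbacks was the one imported; all shapes and variable names are illustrative, not taken from InternLM's code.

```python
# Hedged sketch: calling the resolved alias in the kvpacked, "unpadded"
# (variable-length) layout. Assumes flash-attn 2.x here; the 1.x kvpacked
# variant takes the same arguments.
import torch
from flash_attn.flash_attn_interface import (
    flash_attn_varlen_kvpacked_func as flash_attn_unpadded_func,
)

batch, seqlen, nheads, headdim = 2, 128, 8, 64
total = batch * seqlen  # all sequences packed into one flat axis, no padding

# q:  (total_q, nheads, headdim)
# kv: (total_k, 2, nheads, headdim), K and V stacked along dim 1.
# flash-attn kernels require fp16/bf16 tensors on GPU.
q = torch.randn(total, nheads, headdim, dtype=torch.float16, device="cuda")
kv = torch.randn(total, 2, nheads, headdim, dtype=torch.float16, device="cuda")

# cu_seqlens: (batch + 1,) int32 cumulative sequence lengths, so sequence i
# occupies rows cu_seqlens[i]:cu_seqlens[i + 1] of the packed tensors.
cu_seqlens = torch.arange(
    0, (batch + 1) * seqlen, seqlen, dtype=torch.int32, device="cuda"
)

out = flash_attn_unpadded_func(
    q, kv, cu_seqlens, cu_seqlens, seqlen, seqlen,
    dropout_p=0.0, softmax_scale=None, causal=True,
)
# out: (total_q, nheads, headdim)
```

One caveat about the aliasing: if the first branch resolves, `flash_attn_unpadded_func` is the non-kvpacked variant, which takes separate `q, k, v` tensors rather than a packed `kv`, so call sites must match whichever variant was actually imported.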