From 66eba48c9f66f9aa6e1f17220c7b9d2a919def57 Mon Sep 17 00:00:00 2001
From: Pryest <495945214@qq.com>
Date: Mon, 9 Oct 2023 21:15:40 +0800
Subject: [PATCH] Fit to flash attention 1.0.5.

---
 internlm/model/multi_head_attention.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/internlm/model/multi_head_attention.py b/internlm/model/multi_head_attention.py
index 6533dd7..6017dbc 100644
--- a/internlm/model/multi_head_attention.py
+++ b/internlm/model/multi_head_attention.py
@@ -10,15 +10,15 @@ import torch.nn.functional as F
 from einops import rearrange
 
 try:
-    from flash_attn import flash_attn_unpadded_func
+    from flash_attn.flash_attn_interface import flash_attn_unpadded_func
 except ImportError:
     try:
-        from flash_attn import (
+        from flash_attn.flash_attn_interface import (
             flash_attn_unpadded_kvpacked_func as flash_attn_unpadded_func,
         )
     except ImportError:
         try:
-            from flash_attn import (
+            from flash_attn.flash_attn_interface import (
                 flash_attn_varlen_kvpacked_func as flash_attn_unpadded_func,
             )
         except ImportError:
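
Note: below is a minimal sketch of how the import cascade reads after applying this
patch, assuming flash-attn 1.0.5, where these helpers are exposed from the
flash_attn.flash_attn_interface submodule rather than the package root. The body of
the final except clause sits outside the hunk above and is left elided here.

try:
    # preferred symbol when available
    from flash_attn.flash_attn_interface import flash_attn_unpadded_func
except ImportError:
    try:
        # fall back to the kvpacked variant, aliased to the same symbol
        from flash_attn.flash_attn_interface import (
            flash_attn_unpadded_kvpacked_func as flash_attn_unpadded_func,
        )
    except ImportError:
        try:
            # fall back to the varlen-named kvpacked variant, aliased likewise
            from flash_attn.flash_attn_interface import (
                flash_attn_varlen_kvpacked_func as flash_attn_unpadded_func,
            )
        except ImportError:
            ...  # final fallback handling is outside this hunk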