From a3580acb6c9ec7294d5b01fb2c77678db2d9a18b Mon Sep 17 00:00:00 2001
From: Pryest <495945214@qq.com>
Date: Mon, 9 Oct 2023 20:46:17 +0800
Subject: [PATCH] Fit to flash attention 1.0

---
 internlm/model/multi_head_attention.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/internlm/model/multi_head_attention.py b/internlm/model/multi_head_attention.py
index 608b281..ae4de68 100644
--- a/internlm/model/multi_head_attention.py
+++ b/internlm/model/multi_head_attention.py
@@ -8,7 +8,12 @@ from typing import Optional
 import torch
 import torch.nn.functional as F
 from einops import rearrange
-from flash_attn import flash_attn_unpadded_kvpacked_func
+
+try:
+    from flash_attn import flash_attn_unpadded_kvpacked_func
+except ImportError:
+    from flash_attn import flash_attn_varlen_kvpacked_func as flash_attn_unpadded_kvpacked_func
+
 from flash_attn.modules.mha import (
     CrossAttention,
     FlashCrossAttention,
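Note (not part of the patch): the hunk above relies on plain import aliasing, so the
rest of multi_head_attention.py can keep calling flash_attn_unpadded_kvpacked_func
regardless of which flash-attn naming the installed build exposes. A minimal
standalone sketch of the same fallback, assuming flash-attn is installed; the
final print is only illustrative and not part of the change:

    # Minimal sketch: resolve whichever kvpacked varlen API this flash-attn
    # build exposes, mirroring the try/except added by the patch.
    try:
        # older flash-attn naming
        from flash_attn import flash_attn_unpadded_kvpacked_func
    except ImportError:
        # newer naming, aliased back to the old symbol so call sites stay unchanged
        from flash_attn import (
            flash_attn_varlen_kvpacked_func as flash_attn_unpadded_kvpacked_func,
        )

    # Show which function was actually bound in this environment.
    print(flash_attn_unpadded_kvpacked_func.__module__,
          flash_attn_unpadded_kvpacked_func.__qualname__)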