From b38ba5dad2c73eb511b59cf2fecbb1d3d6cccc5b Mon Sep 17 00:00:00 2001
From: Pryest <495945214@qq.com>
Date: Mon, 9 Oct 2023 21:03:16 +0800
Subject: [PATCH] Fit to flash attention 1.0.5.

---
 internlm/model/multi_head_attention.py | 16 +++++++++++++---
 1 file changed, 13 insertions(+), 3 deletions(-)

diff --git a/internlm/model/multi_head_attention.py b/internlm/model/multi_head_attention.py
index ae4de68..6533dd7 100644
--- a/internlm/model/multi_head_attention.py
+++ b/internlm/model/multi_head_attention.py
@@ -10,9 +10,19 @@ import torch.nn.functional as F
 from einops import rearrange
 
 try:
-    from flash_attn import flash_attn_unpadded_kvpacked_func
+    from flash_attn import flash_attn_unpadded_func
 except ImportError:
-    from flash_attn import flash_attn_varlen_kvpacked_func as flash_attn_unpadded_kvpacked_func
+    try:
+        from flash_attn import (
+            flash_attn_unpadded_kvpacked_func as flash_attn_unpadded_func,
+        )
+    except ImportError:
+        try:
+            from flash_attn import (
+                flash_attn_varlen_kvpacked_func as flash_attn_unpadded_func,
+            )
+        except ImportError:
+            raise ImportError("Please check your flash_attn version >= 1.0.5.")
 
 from flash_attn.modules.mha import (
     CrossAttention,
@@ -285,7 +295,7 @@ class MHA(nn.Module):
         if total_kv.dtype not in [torch.float16, torch.bfloat16]:
             total_kv = total_kv.to(torch.bfloat16)
 
-        output = flash_attn_unpadded_kvpacked_func(
+        output = flash_attn_unpadded_func(
             total_q, total_kv, cu_seqlens, cu_seqlens, max_seqlen_q, max_seqlen_k, 0.0, None, True, False
         ).to(x.dtype)