From a35ce4c888d3f3982649c569546c8cfd558134cf Mon Sep 17 00:00:00 2001
From: Pryest <495945214@qq.com>
Date: Mon, 9 Oct 2023 20:43:21 +0800
Subject: [PATCH] Fit to flash attention 1.0

---
 internlm/model/multi_head_attention.py | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/internlm/model/multi_head_attention.py b/internlm/model/multi_head_attention.py
index 6c611d8..608b281 100644
--- a/internlm/model/multi_head_attention.py
+++ b/internlm/model/multi_head_attention.py
@@ -8,7 +8,7 @@ from typing import Optional
 import torch
 import torch.nn.functional as F
 from einops import rearrange
-from flash_attn import flash_attn_varlen_kvpacked_func
+from flash_attn import flash_attn_unpadded_kvpacked_func
 from flash_attn.modules.mha import (
     CrossAttention,
     FlashCrossAttention,
@@ -280,7 +280,7 @@ class MHA(nn.Module):
             if total_kv.dtype not in [torch.float16, torch.bfloat16]:
                 total_kv = total_kv.to(torch.bfloat16)
-            output = flash_attn_varlen_kvpacked_func(
+            output = flash_attn_unpadded_kvpacked_func(
                 total_q, total_kv, cu_seqlens, cu_seqlens, max_seqlen_q, max_seqlen_k, 0.0, None, True, False
             ).to(x.dtype)
@@ -294,18 +294,17 @@ class MHA(nn.Module):
                 k = k.squeeze(2)
                 v = v.squeeze(2)
                 sp = k.shape
-                expansion = q.size(2) // k.size(2)
                 scores = torch.einsum(
                     "blhd,bnhd->bhln",
                     q,
-                    k.unsqueeze(3).expand(-1, -1, -1, expansion, -1).reshape(sp[0], sp[1], q.size(2), sp[3]),
+                    k.reshape(sp[0], sp[1], q.size(2), sp[3]),
                 ) / math.sqrt(q.size(-1))
                 scores = scores.masked_fill(attn_mask, -65000.0)
                 scores = F.softmax(scores, dim=-1)  # bsz x h x L x L
                 context = torch.einsum(
                     "bhmn,bnhd->bmhd",
                     scores,
-                    v.unsqueeze(3).expand(-1, -1, -1, expansion, -1).reshape(sp[0], sp[1], q.size(2), sp[3]),
+                    v.reshape(sp[0], sp[1], q.size(2), sp[3]),
                 )
             else:
                 context = self.inner_cross_attn(q, kv, causal=True)
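
Note on the rename above: flash-attn exports this packed varlen kernel as flash_attn_unpadded_kvpacked_func in the 1.0.x releases and as flash_attn_varlen_kvpacked_func in 2.x. Instead of hard-switching the import as the patch does, a version-tolerant shim is another option. Below is a minimal sketch, assuming the positional arguments used at this call site line up across both versions; the alias _kvpacked_attn is illustrative and not part of the patch.

    # Sketch only: tolerate either flash-attn major version at import time.
    # The alias name _kvpacked_attn is hypothetical, not used by the patch itself.
    try:
        # flash-attn 2.x naming
        from flash_attn import flash_attn_varlen_kvpacked_func as _kvpacked_attn
    except ImportError:
        # flash-attn 1.0.x naming (the name this patch switches to)
        from flash_attn import flash_attn_unpadded_kvpacked_func as _kvpacked_attn

    # The call site in MHA could then stay unchanged, e.g.:
    # output = _kvpacked_attn(
    #     total_q, total_kv, cu_seqlens, cu_seqlens, max_seqlen_q, max_seqlen_k, 0.0, None, True, False
    # ).to(x.dtype)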