diff --git a/internlm/model/multi_head_attention.py b/internlm/model/multi_head_attention.py index e4008e1..6017dbc 100644 --- a/internlm/model/multi_head_attention.py +++ b/internlm/model/multi_head_attention.py @@ -1,11 +1,29 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- +import math import warnings from typing import Optional import torch +import torch.nn.functional as F from einops import rearrange + +try: + from flash_attn.flash_attn_interface import flash_attn_unpadded_func +except ImportError: + try: + from flash_attn.flash_attn_interface import ( + flash_attn_unpadded_kvpacked_func as flash_attn_unpadded_func, + ) + except ImportError: + try: + from flash_attn.flash_attn_interface import ( + flash_attn_varlen_kvpacked_func as flash_attn_unpadded_func, + ) + except ImportError: + raise ImportError("Please check your flash_attn version >= 1.0.5.") + from flash_attn.modules.mha import ( CrossAttention, FlashCrossAttention, @@ -127,7 +145,7 @@ class MHA(nn.Module): else: return self._forward(x=x, seqlen=seqlen, inference_params=inference_params, **kwargs) - def _forward(self, x, seqlen=None, inference_params=None, **kwargs): + def _forward(self, x, seqlen=None, inference_params=None, **kwargs): # pylint: disable=W0613 """ Arguments: x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if seqlen=None. @@ -135,6 +153,7 @@ class MHA(nn.Module): split x during sequence parallel, we split the batch * seqlen dimension (in case batch is small). """ + bsz, _, _ = x.shape qkv = self.Wqkv(x) if seqlen is None: qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, d=self.head_dim) @@ -142,9 +161,8 @@ class MHA(nn.Module): qkv = rearrange(qkv, "(b s) (three h d) -> b s three h d", s=seqlen, three=3, d=self.head_dim) if inference_params is None: - if self.rotary_emb_dim > 0: - kwargs["inference_params"] = inference_params - qkv = self.rotary_emb(qkv, **kwargs) + kwargs["inference_params"] = inference_params + qkv = self.rotary_emb(qkv, **kwargs) if gpc.config.model.dtype is torch.float32 and gpc.config.model.use_flash_attn: with torch.cuda.amp.autocast(dtype=torch.bfloat16): if qkv.dtype not in [torch.float16, torch.bfloat16]: @@ -152,6 +170,7 @@ class MHA(nn.Module): context = self.inner_attn(qkv).to(x.dtype) else: context = self.inner_attn(qkv) + else: if self.use_dynamic_ntk_rope: q = qkv[:, :, 0] @@ -179,17 +198,131 @@ class MHA(nn.Module): q = qkv[:, :, 0] kv = qkv[:, :, 1:] else: - if self.rotary_emb_dim > 0: - kwargs["inference_params"] = inference_params - qkv = self.rotary_emb(qkv, **kwargs) - q = qkv[:, :, 0] assert self.layer_idx is not None, "Generation requires layer_idx in the constructor" - kv = _update_kv_cache(qkv[:, :, 1:], inference_params, self.layer_idx) + q, k, v = (x.squeeze(2) for x in qkv.chunk(chunks=3, dim=2)) + kv = torch.stack([k, v], dim=2) + assert self.rotary_emb_dim > 0, "You should use rotary_emb." - # If we're processing the prompt, causal=None (use self.causal). - # If we're decoding, then causal=False. - causal = None if inference_params.sequence_len_offset == 0 else False - context = self.inner_cross_attn(q, kv, causal=causal) + if hasattr(inference_params, "attention_mask") and inference_params.attention_mask is not None: + empties = inference_params.attention_mask[..., -1].sum(dim=-1) + if inference_params.sequence_len_offset == 0: + moved_q = q.clone() + moved_k = k.clone() + for i in range(len(empties)): + if empties[i] != 0: + moved_q[i][: -empties[i]] = q[i][empties[i] :] + moved_k[i][: -empties[i]] = k[i][empties[i] :] + moved_q = self.rotary_emb._single_eval_forward(moved_q, seqlen_offset=0) + moved_k = self.rotary_emb._single_eval_forward(moved_k, seqlen_offset=0) + for i in range(len(empties)): + if empties[i] != 0: + q[i][empties[i] :] = moved_q[i][: -empties[i]] + k[i][empties[i] :] = moved_k[i][: -empties[i]] + else: + q[i] = moved_q[i] + k[i] = moved_k[i] + elif not self.use_dynamic_ntk_rope: + if inference_params.sequence_len_offset > self.max_position_embeddings: + warnings.warn( + "Notice your prompt's length is longer than model's max_position_embeddings: " + f"{self.max_position_embeddings}, may cause deviations in dynamic ntk calculations." + ) + q = q.squeeze(1) + k = k.squeeze(1) + q = self.rotary_emb._single_forward( + q, + inference_params.sequence_len_offset + * torch.ones(q.size(0), dtype=torch.int, device=q.device) + - empties, + ).unsqueeze(1) + k = self.rotary_emb._single_forward( + k, + inference_params.sequence_len_offset + * torch.ones(k.size(0), dtype=torch.int, device=k.device) + - empties, + ).unsqueeze(1) + else: + q = q.squeeze(1) + q = self.rotary_emb._single_forward( + q, + inference_params.sequence_len_offset + * torch.ones(q.size(0), dtype=torch.int, device=q.device) + - empties, + ).unsqueeze(1) + moved_k = k.clone() + for i in range(len(empties)): + if empties[i] != 0: + moved_k[i][: -empties[i]] = k[i][empties[i] :] + moved_k = self.rotary_emb._single_eval_forward(moved_k, seqlen_offset=0) + for i in range(len(empties)): + if empties[i] != 0: + k[i][empties[i] :] = moved_k[i][: -empties[i]] + else: + k[i] = moved_k[i] + else: + q = self.rotary_emb._single_forward(q, inference_params.sequence_len_offset) + k = self.rotary_emb._single_forward(k, inference_params.sequence_len_offset) + + kv = torch.stack([k, v], dim=2) + kv = _update_kv_cache(kv, inference_params, self.layer_idx) + + if hasattr(inference_params, "attention_mask") and inference_params.attention_mask is not None: + if inference_params.sequence_len_offset == 0: # First entrance, attnmask (bs*seqlen*seqlen) + attn_mask = inference_params.attention_mask[:, None, ...] + attn_mask = torch.logical_or( + torch.ones_like(attn_mask, dtype=torch.bool).triu(diagonal=1), attn_mask + ) + attn_mask4flsh = ~attn_mask[:, :, -1, :].view(bsz, -1) + cu_seqlens = torch.concat( + [ + torch.tensor([0], dtype=torch.int32, device=attn_mask4flsh.device), + attn_mask4flsh.sum(dim=-1).to(dtype=torch.int32), + ], + dim=0, + ) + cu_seqlens = cu_seqlens.cumsum(dim=0, dtype=torch.int32) + max_seqlen_q = attn_mask4flsh.shape[-1] + max_seqlen_k = attn_mask4flsh.shape[-1] + total_q = q.masked_select(attn_mask4flsh.view(bsz, -1, 1, 1)).view(-1, q.shape[-2], q.shape[-1]) + total_kv = kv.masked_select(attn_mask4flsh.view(bsz, -1, 1, 1, 1)).view( + -1, kv.shape[-3], kv.shape[-2], kv.shape[-1] + ) + + if gpc.config.model.dtype is torch.float32 and gpc.config.model.use_flash_attn: + with torch.cuda.amp.autocast(dtype=torch.bfloat16): + if total_q.dtype not in [torch.float16, torch.bfloat16]: + total_q = total_q.to(torch.bfloat16) + if total_kv.dtype not in [torch.float16, torch.bfloat16]: + total_kv = total_kv.to(torch.bfloat16) + + output = flash_attn_unpadded_func( + total_q, total_kv, cu_seqlens, cu_seqlens, max_seqlen_q, max_seqlen_k, 0.0, None, True, False + ).to(x.dtype) + + context = torch.zeros_like(q) + context = context.masked_scatter_(attn_mask4flsh.view(bsz, -1, 1, 1), output) + + else: + attn_mask = inference_params.attention_mask[:, -1, :].view(bsz, 1, 1, -1) + + k, v = torch.chunk(kv, 2, dim=2) + k = k.squeeze(2) + v = v.squeeze(2) + sp = k.shape + scores = torch.einsum( + "blhd,bnhd->bhln", + q, + k.reshape(sp[0], sp[1], q.size(2), sp[3]), + ) / math.sqrt(q.size(-1)) + scores = scores.masked_fill(attn_mask, -65000.0) + scores = F.softmax(scores, dim=-1) # bsz x h x L x L + context = torch.einsum( + "bhmn,bnhd->bmhd", + scores, + v.reshape(sp[0], sp[1], q.size(2), sp[3]), + ) + else: + context = self.inner_cross_attn(q, kv, causal=True) if seqlen is None: context = rearrange(context, "b s h d -> b s (h d)")