From e5bf40e38f24ec2635829594bbc5c1edae74a07f Mon Sep 17 00:00:00 2001
From: Pryest <495945214@qq.com>
Date: Sat, 7 Oct 2023 20:13:30 +0800
Subject: [PATCH] Fix errant inference_forward.

---
 internlm/model/multi_head_attention.py | 124 +++++++++++++++++++++----
 1 file changed, 107 insertions(+), 17 deletions(-)

diff --git a/internlm/model/multi_head_attention.py b/internlm/model/multi_head_attention.py
index d634605..29f0f1f 100644
--- a/internlm/model/multi_head_attention.py
+++ b/internlm/model/multi_head_attention.py
@@ -1,10 +1,13 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-
 
+import math
 from typing import Optional
 
 import torch
+import torch.nn.functional as F
 from einops import rearrange
+from flash_attn import flash_attn_varlen_kvpacked_func
 from flash_attn.modules.mha import (
     CrossAttention,
     FlashCrossAttention,
@@ -113,7 +116,7 @@ class MHA(nn.Module):
         else:
             return self._forward(x=x, seqlen=seqlen, inference_params=inference_params, **kwargs)
 
-    def _forward(self, x, seqlen=None, inference_params=None, **kwargs):
+    def _forward(self, x, seqlen=None, inference_params=None, **kwargs):  # pylint: disable=W0613
         """
         Arguments:
             x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if seqlen=None.
@@ -121,32 +124,119 @@ class MHA(nn.Module):
                 split x during sequence parallel, we split the batch * seqlen dimension
                 (in case batch is small).
         """
+        bsz, _, _ = x.shape
         qkv = self.Wqkv(x)
         if seqlen is None:
             qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, d=self.head_dim)
         else:
             qkv = rearrange(qkv, "(b s) (three h d) -> b s three h d", s=seqlen, three=3, d=self.head_dim)
 
-        if self.rotary_emb_dim > 0:
-            kwargs["inference_params"] = inference_params
-            qkv = self.rotary_emb(qkv, **kwargs)
+        q, k, v = (x.squeeze(2) for x in qkv.chunk(chunks=3, dim=2))
 
         if inference_params is None:
-            if gpc.config.model.dtype is torch.float32 and gpc.config.model.use_flash_attn:
-                with torch.cuda.amp.autocast(dtype=torch.bfloat16):
-                    if qkv.dtype not in [torch.float16, torch.bfloat16]:
-                        qkv = qkv.to(torch.bfloat16)
-                    context = self.inner_attn(qkv).to(x.dtype)
-            else:
-                context = self.inner_attn(qkv)
+            if self.rotary_emb_dim > 0:
+                q = self.rotary_emb._single_eval_forward(q)
+                k = self.rotary_emb._single_eval_forward(k)
+            kv = torch.concat([k.unsqueeze(2), v.unsqueeze(2)], dim=2)
+            context = self.inner_cross_attn(q, kv)
+
         else:
-            q = qkv[:, :, 0]
+            assert self.rotary_emb_dim > 0
+            if hasattr(inference_params, "attention_mask") and inference_params.attention_mask is not None:
+                empties = inference_params.attention_mask[..., -1].sum(dim=-1)
+                moved_q = q.clone()
+                moved_k = k.clone()
+                if inference_params.sequence_len_offset == 0:
+                    for i in range(len(empties)):
+                        if empties[i] != 0:
+                            moved_q[i][: -empties[i]] = q[i][empties[i] :]
+                            moved_k[i][: -empties[i]] = k[i][empties[i] :]
+                    moved_q = self.rotary_emb._single_eval_forward(
+                        moved_q, seqlen_offset=inference_params.sequence_len_offset
+                    )
+                    moved_k = self.rotary_emb._single_eval_forward(
+                        moved_k, seqlen_offset=inference_params.sequence_len_offset
+                    )
+                    for i in range(len(empties)):
+                        if empties[i] != 0:
+                            q[i][empties[i] :] = moved_q[i][: -empties[i]]
+                            k[i][empties[i] :] = moved_k[i][: -empties[i]]
+                        else:
+                            q[i] = moved_q[i]
+                            k[i] = moved_k[i]
+                else:
+                    q = q.squeeze(1)
+                    k = k.squeeze(1)
+                    q = self.rotary_emb._single_forward(
+                        q,
+                        inference_params.sequence_len_offset * torch.ones(q.size(0), dtype=torch.int, device=q.device)
+                        - empties,
+                    ).unsqueeze(1)
+                    k = self.rotary_emb._single_forward(
+                        k,
+                        inference_params.sequence_len_offset * torch.ones(k.size(0), dtype=torch.int, device=k.device)
+                        - empties,
+                    ).unsqueeze(1)
+            else:
+                q = self.rotary_emb._single_forward(q, inference_params.sequence_len_offset)
+                k = self.rotary_emb._single_forward(k, inference_params.sequence_len_offset)
+
+            kv = torch.stack([k, v], dim=2)
+
             assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
-            kv = _update_kv_cache(qkv[:, :, 1:], inference_params, self.layer_idx)
-            # If we're processing the prompt, causal=None (use self.causal).
-            # If we're decoding, then causal=False.
-            causal = None if inference_params.sequence_len_offset == 0 else False
-            context = self.inner_cross_attn(q, kv, causal=causal)
+            kv = _update_kv_cache(kv, inference_params, self.layer_idx)
+
+            if hasattr(inference_params, "attention_mask") and inference_params.attention_mask is not None:
+                if inference_params.sequence_len_offset == 0:  # First entrance, attnmask (bs*seqlen*seqlen)
+                    attn_mask = inference_params.attention_mask[:, None, ...]
+                    attn_mask = torch.logical_or(
+                        torch.ones_like(attn_mask, dtype=torch.bool).triu(diagonal=1), attn_mask
+                    )
+                    attn_mask4flsh = ~attn_mask[:, :, -1, :].view(bsz, -1)
+                    cu_seqlens = torch.concat(
+                        [
+                            torch.tensor([0], dtype=torch.int32, device=attn_mask4flsh.device),
+                            attn_mask4flsh.sum(dim=-1).to(dtype=torch.int32),
+                        ],
+                        dim=0,
+                    )
+                    cu_seqlens = cu_seqlens.cumsum(dim=0, dtype=torch.int32)
+                    max_seqlen_q = attn_mask4flsh.shape[-1]
+                    max_seqlen_k = attn_mask4flsh.shape[-1]
+                    total_q = q.masked_select(attn_mask4flsh.view(bsz, -1, 1, 1)).view(-1, q.shape[-2], q.shape[-1])
+                    total_kv = kv.masked_select(attn_mask4flsh.view(bsz, -1, 1, 1, 1)).view(
+                        -1, kv.shape[-3], kv.shape[-2], kv.shape[-1]
+                    )
+
+                    output = flash_attn_varlen_kvpacked_func(
+                        total_q, total_kv, cu_seqlens, cu_seqlens, max_seqlen_q, max_seqlen_k, 0.0, None, True, False
+                    )
+
+                    context = torch.zeros_like(q)
+                    context = context.masked_scatter_(attn_mask4flsh.view(bsz, -1, 1, 1), output)
+
+                else:
+                    attn_mask = inference_params.attention_mask[:, -1, :].view(bsz, 1, 1, -1)
+
+                    k, v = torch.chunk(kv, 2, dim=2)
+                    k = k.squeeze(2)
+                    v = v.squeeze(2)
+                    sp = k.shape
+                    expansion = q.size(2) // k.size(2)
+                    scores = torch.einsum(
+                        "blhd,bnhd->bhln",
+                        q,
+                        k.unsqueeze(3).expand(-1, -1, -1, expansion, -1).reshape(sp[0], sp[1], q.size(2), sp[3]),
+                    ) / math.sqrt(q.size(-1))
+                    scores = scores.masked_fill(attn_mask, -65000.0)
+                    scores = F.softmax(scores, dim=-1)  # bsz x h x L x L
+                    context = torch.einsum(
+                        "bhmn,bnhd->bmhd",
+                        scores,
+                        v.unsqueeze(3).expand(-1, -1, -1, expansion, -1).reshape(sp[0], sp[1], q.size(2), sp[3]),
+                    )
+            else:
+                context = self.inner_cross_attn(q, kv, causal=True)
 
         if seqlen is None:
             context = rearrange(context, "b s h d -> b s (h d)")
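
Note (not part of the patch): the first-prompt branch above packs only non-padded tokens before calling flash_attn_varlen_kvpacked_func. Below is a minimal, self-contained sketch of that packing step, assuming a hypothetical pad_mask that is True on real tokens and False on left padding; shapes and names are illustrative, not taken from the patch. The kernel call itself is left commented out because it requires a CUDA device and fp16/bf16 tensors.

# Standalone sketch: build cu_seqlens and packed q/kv from a left-padding mask,
# mirroring the attn_mask4flsh / cu_seqlens construction used in the patch.
import torch

bsz, seqlen, nheads, headdim = 2, 6, 4, 8

# pad_mask: True where a position holds a real token, False where it is left padding.
pad_mask = torch.tensor([
    [False, False, True, True, True, True],  # 2 pad tokens, 4 real tokens
    [False, True, True, True, True, True],   # 1 pad token, 5 real tokens
])

q = torch.randn(bsz, seqlen, nheads, headdim)
kv = torch.randn(bsz, seqlen, 2, nheads, headdim)  # packed (k, v) along dim 2

# Per-sample valid lengths -> cumulative offsets [0, len0, len0+len1, ...],
# the cu_seqlens layout expected by the varlen flash-attention interface.
seqlens = pad_mask.sum(dim=-1).to(torch.int32)
cu_seqlens = torch.cat([torch.zeros(1, dtype=torch.int32), seqlens]).cumsum(0, dtype=torch.int32)
max_seqlen = int(seqlens.max())

# Keep only valid tokens, flattening the batch into (total_tokens, ...) tensors.
total_q = q.masked_select(pad_mask.view(bsz, seqlen, 1, 1)).view(-1, nheads, headdim)
total_kv = kv.masked_select(pad_mask.view(bsz, seqlen, 1, 1, 1)).view(-1, 2, nheads, headdim)

print(cu_seqlens.tolist())            # [0, 4, 9]
print(total_q.shape, total_kv.shape)  # torch.Size([9, 4, 8]) torch.Size([9, 2, 4, 8])

# On a CUDA device with fp16/bf16 tensors, the packed tensors would then be consumed as
# out = flash_attn_varlen_kvpacked_func(
#     total_q, total_kv, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen, 0.0, None, True, False
# )
# and the result scattered back into a (bsz, seqlen, nheads, headdim) buffer via masked_scatter_.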