mirror of https://github.com/InternLM/InternLM
				
				
				
			merge develop
						commit
						dd67ab948d
					
				| 
						 | 
				
			
			@ -1,6 +1,7 @@
 | 
			
		|||
#!/usr/bin/env python
 | 
			
		||||
# -*- encoding: utf-8 -*-
 | 
			
		||||
 | 
			
		||||
import math
 | 
			
		||||
import warnings
 | 
			
		||||
from typing import Any, Optional, Tuple
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -8,6 +9,22 @@ import torch
 | 
			
		|||
import torch.distributed as dist
 | 
			
		||||
import torch.nn.functional as F
 | 
			
		||||
from einops import rearrange
 | 
			
		||||
 | 
			
		||||
try:
 | 
			
		||||
    from flash_attn.flash_attn_interface import flash_attn_unpadded_func
 | 
			
		||||
except ImportError:
 | 
			
		||||
    try:
 | 
			
		||||
        from flash_attn.flash_attn_interface import (
 | 
			
		||||
            flash_attn_unpadded_kvpacked_func as flash_attn_unpadded_func,
 | 
			
		||||
        )
 | 
			
		||||
    except ImportError:
 | 
			
		||||
        try:
 | 
			
		||||
            from flash_attn.flash_attn_interface import (
 | 
			
		||||
                flash_attn_varlen_kvpacked_func as flash_attn_unpadded_func,
 | 
			
		||||
            )
 | 
			
		||||
        except ImportError:
 | 
			
		||||
            raise ImportError("Please check your flash_attn version >= 1.0.5.")
 | 
			
		||||
 | 
			
		||||
from flash_attn.modules.mha import (
 | 
			
		||||
    CrossAttention,
 | 
			
		||||
    FlashCrossAttention,
 | 
			
		||||
| 
						 | 
				
			
			@ -229,7 +246,7 @@ class MHA(nn.Module):
 | 
			
		|||
        else:
 | 
			
		||||
            return self._forward(x=x, seqlen=seqlen, inference_params=inference_params, **kwargs)
 | 
			
		||||
 | 
			
		||||
    def _forward(self, x, seqlen=None, inference_params=None, **kwargs):
 | 
			
		||||
    def _forward(self, x, seqlen=None, inference_params=None, **kwargs):  # pylint: disable=W0613
 | 
			
		||||
        """
 | 
			
		||||
        Arguments:
 | 
			
		||||
            x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if seqlen=None.
 | 
			
		||||
| 
						 | 
				
			
			@ -237,6 +254,7 @@ class MHA(nn.Module):
 | 
			
		|||
                split x during sequence parallel, we split the batch * seqlen dimension
 | 
			
		||||
                (in case batch is small).
 | 
			
		||||
        """
 | 
			
		||||
        bsz, _, _ = x.shape
 | 
			
		||||
        qkv = self.Wqkv(x)
 | 
			
		||||
        if seqlen is None:
 | 
			
		||||
            qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, d=self.head_dim)
 | 
			
		||||
| 
						 | 
				
			
			@ -244,9 +262,8 @@ class MHA(nn.Module):
 | 
			
		|||
            qkv = rearrange(qkv, "(b s) (three h d) -> b s three h d", s=seqlen, three=3, d=self.head_dim)
 | 
			
		||||
 | 
			
		||||
        if inference_params is None:
 | 
			
		||||
            if self.rotary_emb_dim > 0:
 | 
			
		||||
                kwargs["inference_params"] = inference_params
 | 
			
		||||
                qkv = self.rotary_emb(qkv, **kwargs)
 | 
			
		||||
            kwargs["inference_params"] = inference_params
 | 
			
		||||
            qkv = self.rotary_emb(qkv, **kwargs)
 | 
			
		||||
            if gpc.config.model.dtype is torch.float32 and gpc.config.model.use_flash_attn:
 | 
			
		||||
                with torch.cuda.amp.autocast(dtype=torch.bfloat16):
 | 
			
		||||
                    if qkv.dtype not in [torch.float16, torch.bfloat16]:
 | 
			
		||||
| 
						 | 
				
			
			@ -254,6 +271,7 @@ class MHA(nn.Module):
 | 
			
		|||
                    context = self.inner_attn(qkv).to(x.dtype)
 | 
			
		||||
            else:
 | 
			
		||||
                context = self.inner_attn(qkv)
 | 
			
		||||
 | 
			
		||||
        else:
 | 
			
		||||
            if self.use_dynamic_ntk_rope:
 | 
			
		||||
                q = qkv[:, :, 0]
 | 
			
		||||
| 
						 | 
				
			
			@ -281,17 +299,131 @@ class MHA(nn.Module):
 | 
			
		|||
                        q = qkv[:, :, 0]
 | 
			
		||||
                        kv = qkv[:, :, 1:]
 | 
			
		||||
            else:
 | 
			
		||||
                if self.rotary_emb_dim > 0:
 | 
			
		||||
                    kwargs["inference_params"] = inference_params
 | 
			
		||||
                    qkv = self.rotary_emb(qkv, **kwargs)
 | 
			
		||||
                q = qkv[:, :, 0]
 | 
			
		||||
                assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
 | 
			
		||||
                kv = _update_kv_cache(qkv[:, :, 1:], inference_params, self.layer_idx)
 | 
			
		||||
                q, k, v = (x.squeeze(2) for x in qkv.chunk(chunks=3, dim=2))
 | 
			
		||||
                kv = torch.stack([k, v], dim=2)
 | 
			
		||||
                assert self.rotary_emb_dim > 0, "You should use rotary_emb."
 | 
			
		||||
 | 
			
		||||
            # If we're processing the prompt, causal=None (use self.causal).
 | 
			
		||||
            # If we're decoding, then causal=False.
 | 
			
		||||
            causal = None if inference_params.sequence_len_offset == 0 else False
 | 
			
		||||
            context = self.inner_cross_attn(q, kv, causal=causal)
 | 
			
		||||
                if hasattr(inference_params, "attention_mask") and inference_params.attention_mask is not None:
 | 
			
		||||
                    empties = inference_params.attention_mask[..., -1].sum(dim=-1)
 | 
			
		||||
                    if inference_params.sequence_len_offset == 0:
 | 
			
		||||
                        moved_q = q.clone()
 | 
			
		||||
                        moved_k = k.clone()
 | 
			
		||||
                        for i in range(len(empties)):
 | 
			
		||||
                            if empties[i] != 0:
 | 
			
		||||
                                moved_q[i][: -empties[i]] = q[i][empties[i] :]
 | 
			
		||||
                                moved_k[i][: -empties[i]] = k[i][empties[i] :]
 | 
			
		||||
                        moved_q = self.rotary_emb._single_eval_forward(moved_q, seqlen_offset=0)
 | 
			
		||||
                        moved_k = self.rotary_emb._single_eval_forward(moved_k, seqlen_offset=0)
 | 
			
		||||
                        for i in range(len(empties)):
 | 
			
		||||
                            if empties[i] != 0:
 | 
			
		||||
                                q[i][empties[i] :] = moved_q[i][: -empties[i]]
 | 
			
		||||
                                k[i][empties[i] :] = moved_k[i][: -empties[i]]
 | 
			
		||||
                            else:
 | 
			
		||||
                                q[i] = moved_q[i]
 | 
			
		||||
                                k[i] = moved_k[i]
 | 
			
		||||
                    elif not self.use_dynamic_ntk_rope:
 | 
			
		||||
                        if inference_params.sequence_len_offset > self.max_position_embeddings:
 | 
			
		||||
                            warnings.warn(
 | 
			
		||||
                                "Notice your prompt's length is longer than model's max_position_embeddings: "
 | 
			
		||||
                                f"{self.max_position_embeddings}, may cause deviations in dynamic ntk calculations."
 | 
			
		||||
                            )
 | 
			
		||||
                        q = q.squeeze(1)
 | 
			
		||||
                        k = k.squeeze(1)
 | 
			
		||||
                        q = self.rotary_emb._single_forward(
 | 
			
		||||
                            q,
 | 
			
		||||
                            inference_params.sequence_len_offset
 | 
			
		||||
                            * torch.ones(q.size(0), dtype=torch.int, device=q.device)
 | 
			
		||||
                            - empties,
 | 
			
		||||
                        ).unsqueeze(1)
 | 
			
		||||
                        k = self.rotary_emb._single_forward(
 | 
			
		||||
                            k,
 | 
			
		||||
                            inference_params.sequence_len_offset
 | 
			
		||||
                            * torch.ones(k.size(0), dtype=torch.int, device=k.device)
 | 
			
		||||
                            - empties,
 | 
			
		||||
                        ).unsqueeze(1)
 | 
			
		||||
                    else:
 | 
			
		||||
                        q = q.squeeze(1)
 | 
			
		||||
                        q = self.rotary_emb._single_forward(
 | 
			
		||||
                            q,
 | 
			
		||||
                            inference_params.sequence_len_offset
 | 
			
		||||
                            * torch.ones(q.size(0), dtype=torch.int, device=q.device)
 | 
			
		||||
                            - empties,
 | 
			
		||||
                        ).unsqueeze(1)
 | 
			
		||||
                        moved_k = k.clone()
 | 
			
		||||
                        for i in range(len(empties)):
 | 
			
		||||
                            if empties[i] != 0:
 | 
			
		||||
                                moved_k[i][: -empties[i]] = k[i][empties[i] :]
 | 
			
		||||
                        moved_k = self.rotary_emb._single_eval_forward(moved_k, seqlen_offset=0)
 | 
			
		||||
                        for i in range(len(empties)):
 | 
			
		||||
                            if empties[i] != 0:
 | 
			
		||||
                                k[i][empties[i] :] = moved_k[i][: -empties[i]]
 | 
			
		||||
                            else:
 | 
			
		||||
                                k[i] = moved_k[i]
 | 
			
		||||
                else:
 | 
			
		||||
                    q = self.rotary_emb._single_forward(q, inference_params.sequence_len_offset)
 | 
			
		||||
                    k = self.rotary_emb._single_forward(k, inference_params.sequence_len_offset)
 | 
			
		||||
 | 
			
		||||
                kv = torch.stack([k, v], dim=2)
 | 
			
		||||
                kv = _update_kv_cache(kv, inference_params, self.layer_idx)
 | 
			
		||||
 | 
			
		||||
            if hasattr(inference_params, "attention_mask") and inference_params.attention_mask is not None:
 | 
			
		||||
                if inference_params.sequence_len_offset == 0:  # First entrance, attnmask (bs*seqlen*seqlen)
 | 
			
		||||
                    attn_mask = inference_params.attention_mask[:, None, ...]
 | 
			
		||||
                    attn_mask = torch.logical_or(
 | 
			
		||||
                        torch.ones_like(attn_mask, dtype=torch.bool).triu(diagonal=1), attn_mask
 | 
			
		||||
                    )
 | 
			
		||||
                    attn_mask4flsh = ~attn_mask[:, :, -1, :].view(bsz, -1)
 | 
			
		||||
                    cu_seqlens = torch.concat(
 | 
			
		||||
                        [
 | 
			
		||||
                            torch.tensor([0], dtype=torch.int32, device=attn_mask4flsh.device),
 | 
			
		||||
                            attn_mask4flsh.sum(dim=-1).to(dtype=torch.int32),
 | 
			
		||||
                        ],
 | 
			
		||||
                        dim=0,
 | 
			
		||||
                    )
 | 
			
		||||
                    cu_seqlens = cu_seqlens.cumsum(dim=0, dtype=torch.int32)
 | 
			
		||||
                    max_seqlen_q = attn_mask4flsh.shape[-1]
 | 
			
		||||
                    max_seqlen_k = attn_mask4flsh.shape[-1]
 | 
			
		||||
                    total_q = q.masked_select(attn_mask4flsh.view(bsz, -1, 1, 1)).view(-1, q.shape[-2], q.shape[-1])
 | 
			
		||||
                    total_kv = kv.masked_select(attn_mask4flsh.view(bsz, -1, 1, 1, 1)).view(
 | 
			
		||||
                        -1, kv.shape[-3], kv.shape[-2], kv.shape[-1]
 | 
			
		||||
                    )
 | 
			
		||||
 | 
			
		||||
                    if gpc.config.model.dtype is torch.float32 and gpc.config.model.use_flash_attn:
 | 
			
		||||
                        with torch.cuda.amp.autocast(dtype=torch.bfloat16):
 | 
			
		||||
                            if total_q.dtype not in [torch.float16, torch.bfloat16]:
 | 
			
		||||
                                total_q = total_q.to(torch.bfloat16)
 | 
			
		||||
                            if total_kv.dtype not in [torch.float16, torch.bfloat16]:
 | 
			
		||||
                                total_kv = total_kv.to(torch.bfloat16)
 | 
			
		||||
 | 
			
		||||
                    output = flash_attn_unpadded_func(
 | 
			
		||||
                        total_q, total_kv, cu_seqlens, cu_seqlens, max_seqlen_q, max_seqlen_k, 0.0, None, True, False
 | 
			
		||||
                    ).to(x.dtype)
 | 
			
		||||
 | 
			
		||||
                    context = torch.zeros_like(q)
 | 
			
		||||
                    context = context.masked_scatter_(attn_mask4flsh.view(bsz, -1, 1, 1), output)
 | 
			
		||||
 | 
			
		||||
                else:
 | 
			
		||||
                    attn_mask = inference_params.attention_mask[:, -1, :].view(bsz, 1, 1, -1)
 | 
			
		||||
 | 
			
		||||
                    k, v = torch.chunk(kv, 2, dim=2)
 | 
			
		||||
                    k = k.squeeze(2)
 | 
			
		||||
                    v = v.squeeze(2)
 | 
			
		||||
                    sp = k.shape
 | 
			
		||||
                    scores = torch.einsum(
 | 
			
		||||
                        "blhd,bnhd->bhln",
 | 
			
		||||
                        q,
 | 
			
		||||
                        k.reshape(sp[0], sp[1], q.size(2), sp[3]),
 | 
			
		||||
                    ) / math.sqrt(q.size(-1))
 | 
			
		||||
                    scores = scores.masked_fill(attn_mask, -65000.0)
 | 
			
		||||
                    scores = F.softmax(scores, dim=-1)  # bsz x h x L x L
 | 
			
		||||
                    context = torch.einsum(
 | 
			
		||||
                        "bhmn,bnhd->bmhd",
 | 
			
		||||
                        scores,
 | 
			
		||||
                        v.reshape(sp[0], sp[1], q.size(2), sp[3]),
 | 
			
		||||
                    )
 | 
			
		||||
            else:
 | 
			
		||||
                context = self.inner_cross_attn(q, kv, causal=True)
 | 
			
		||||
 | 
			
		||||
        if seqlen is None:
 | 
			
		||||
            context = rearrange(context, "b s h d -> b s (h d)")
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue