mirror of https://github.com/InternLM/InternLM
fix(model): fix errant inference_forward (#396)
* Fix errant inference_forward.
* Recover use_dynamic_ntk_rope.
* Fix bugs.
* Fit to flash attention 1.0.
* Fit to flash attention 1.0.5.
parent a075153adf
commit b3645b0244
@@ -1,11 +1,29 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-
 
+import math
 import warnings
 from typing import Optional
 
 import torch
+import torch.nn.functional as F
 from einops import rearrange
+
+try:
+    from flash_attn.flash_attn_interface import flash_attn_unpadded_func
+except ImportError:
+    try:
+        from flash_attn.flash_attn_interface import (
+            flash_attn_unpadded_kvpacked_func as flash_attn_unpadded_func,
+        )
+    except ImportError:
+        try:
+            from flash_attn.flash_attn_interface import (
+                flash_attn_varlen_kvpacked_func as flash_attn_unpadded_func,
+            )
+        except ImportError:
+            raise ImportError("Please check your flash_attn version >= 1.0.5.")
+
 from flash_attn.modules.mha import (
     CrossAttention,
     FlashCrossAttention,
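The nested try/except above aliases whichever kv-packed variable-length kernel the installed flash-attn release exposes to the single name flash_attn_unpadded_func, since the kernel's name differs across flash-attn releases. As a standalone illustration of that fallback pattern (not part of this commit; the helper name is made up), the same idea can be written as a lookup over candidate attribute names:

import importlib


def resolve_first(module_name, candidates):
    """Return the first attribute from `candidates` that `module_name` exposes."""
    module = importlib.import_module(module_name)
    for name in candidates:
        if hasattr(module, name):
            return getattr(module, name)
    raise ImportError(f"None of {candidates} found in {module_name}; "
                      "please check your flash_attn version >= 1.0.5.")


# Usage sketch (requires flash-attn to be installed):
# flash_attn_unpadded_func = resolve_first(
#     "flash_attn.flash_attn_interface",
#     [
#         "flash_attn_unpadded_func",
#         "flash_attn_unpadded_kvpacked_func",
#         "flash_attn_varlen_kvpacked_func",
#     ],
# )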
@@ -127,7 +145,7 @@ class MHA(nn.Module):
         else:
             return self._forward(x=x, seqlen=seqlen, inference_params=inference_params, **kwargs)
 
-    def _forward(self, x, seqlen=None, inference_params=None, **kwargs):
+    def _forward(self, x, seqlen=None, inference_params=None, **kwargs):  # pylint: disable=W0613
         """
         Arguments:
             x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if seqlen=None.
@@ -135,6 +153,7 @@ class MHA(nn.Module):
                 split x during sequence parallel, we split the batch * seqlen dimension
                 (in case batch is small).
         """
+        bsz, _, _ = x.shape
         qkv = self.Wqkv(x)
         if seqlen is None:
             qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, d=self.head_dim)
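The newly captured bsz is used later by the padded-inference path to reshape masks and scatter results back into the batch. For the packed-QKV layout itself, a self-contained shape check of the two rearranges above (illustrative sizes; not part of the commit) looks like this:

import torch
from einops import rearrange

batch, seqlen, num_heads, head_dim = 2, 8, 4, 16
qkv_flat = torch.randn(batch, seqlen, 3 * num_heads * head_dim)

# seqlen kept as an explicit dimension (the `seqlen is None` branch):
qkv = rearrange(qkv_flat, "b s (three h d) -> b s three h d", three=3, d=head_dim)
assert qkv.shape == (batch, seqlen, 3, num_heads, head_dim)

# batch and seqlen flattened together (the sequence-parallel branch):
qkv_packed = qkv_flat.reshape(batch * seqlen, 3 * num_heads * head_dim)
qkv2 = rearrange(qkv_packed, "(b s) (three h d) -> b s three h d", s=seqlen, three=3, d=head_dim)
assert torch.equal(qkv, qkv2)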
@@ -142,9 +161,8 @@ class MHA(nn.Module):
             qkv = rearrange(qkv, "(b s) (three h d) -> b s three h d", s=seqlen, three=3, d=self.head_dim)
 
         if inference_params is None:
-            if self.rotary_emb_dim > 0:
-                kwargs["inference_params"] = inference_params
-                qkv = self.rotary_emb(qkv, **kwargs)
+            kwargs["inference_params"] = inference_params
+            qkv = self.rotary_emb(qkv, **kwargs)
             if gpc.config.model.dtype is torch.float32 and gpc.config.model.use_flash_attn:
                 with torch.cuda.amp.autocast(dtype=torch.bfloat16):
                     if qkv.dtype not in [torch.float16, torch.bfloat16]:
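The dtype guard above exists because flash-attention kernels only accept fp16/bf16 inputs, so an fp32-configured model casts activations down inside an autocast region before calling them. A minimal standalone sketch of that guard (illustrative shapes, CUDA-only, not part of the commit):

import torch


def to_flash_dtype(t: torch.Tensor) -> torch.Tensor:
    """Cast to bfloat16 unless the tensor is already half precision."""
    if t.dtype not in (torch.float16, torch.bfloat16):
        return t.to(torch.bfloat16)
    return t


if torch.cuda.is_available():
    qkv = torch.randn(2, 8, 3, 4, 16, device="cuda")  # fp32 activations
    with torch.cuda.amp.autocast(dtype=torch.bfloat16):
        qkv = to_flash_dtype(qkv)
    assert qkv.dtype is torch.bfloat16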
@@ -152,6 +170,7 @@ class MHA(nn.Module):
                     context = self.inner_attn(qkv).to(x.dtype)
             else:
                 context = self.inner_attn(qkv)
+
         else:
             if self.use_dynamic_ntk_rope:
                 q = qkv[:, :, 0]
@@ -179,17 +198,131 @@ class MHA(nn.Module):
                 q = qkv[:, :, 0]
                 kv = qkv[:, :, 1:]
             else:
-                if self.rotary_emb_dim > 0:
-                    kwargs["inference_params"] = inference_params
-                    qkv = self.rotary_emb(qkv, **kwargs)
-                q = qkv[:, :, 0]
                 assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
-                kv = _update_kv_cache(qkv[:, :, 1:], inference_params, self.layer_idx)
-                # If we're processing the prompt, causal=None (use self.causal).
-                # If we're decoding, then causal=False.
-                causal = None if inference_params.sequence_len_offset == 0 else False
-                context = self.inner_cross_attn(q, kv, causal=causal)
+                q, k, v = (x.squeeze(2) for x in qkv.chunk(chunks=3, dim=2))
+                kv = torch.stack([k, v], dim=2)
+                assert self.rotary_emb_dim > 0, "You should use rotary_emb."
+
+                if hasattr(inference_params, "attention_mask") and inference_params.attention_mask is not None:
+                    empties = inference_params.attention_mask[..., -1].sum(dim=-1)
+                    if inference_params.sequence_len_offset == 0:
+                        moved_q = q.clone()
+                        moved_k = k.clone()
+                        for i in range(len(empties)):
+                            if empties[i] != 0:
+                                moved_q[i][: -empties[i]] = q[i][empties[i] :]
+                                moved_k[i][: -empties[i]] = k[i][empties[i] :]
+                        moved_q = self.rotary_emb._single_eval_forward(moved_q, seqlen_offset=0)
+                        moved_k = self.rotary_emb._single_eval_forward(moved_k, seqlen_offset=0)
+                        for i in range(len(empties)):
+                            if empties[i] != 0:
+                                q[i][empties[i] :] = moved_q[i][: -empties[i]]
+                                k[i][empties[i] :] = moved_k[i][: -empties[i]]
+                            else:
+                                q[i] = moved_q[i]
+                                k[i] = moved_k[i]
+                    elif not self.use_dynamic_ntk_rope:
+                        if inference_params.sequence_len_offset > self.max_position_embeddings:
+                            warnings.warn(
+                                "Notice your prompt's length is longer than model's max_position_embeddings: "
+                                f"{self.max_position_embeddings}, may cause deviations in dynamic ntk calculations."
+                            )
+                        q = q.squeeze(1)
+                        k = k.squeeze(1)
+                        q = self.rotary_emb._single_forward(
+                            q,
+                            inference_params.sequence_len_offset
+                            * torch.ones(q.size(0), dtype=torch.int, device=q.device)
+                            - empties,
+                        ).unsqueeze(1)
+                        k = self.rotary_emb._single_forward(
+                            k,
+                            inference_params.sequence_len_offset
+                            * torch.ones(k.size(0), dtype=torch.int, device=k.device)
+                            - empties,
+                        ).unsqueeze(1)
+                    else:
+                        q = q.squeeze(1)
+                        q = self.rotary_emb._single_forward(
+                            q,
+                            inference_params.sequence_len_offset
+                            * torch.ones(q.size(0), dtype=torch.int, device=q.device)
+                            - empties,
+                        ).unsqueeze(1)
+                        moved_k = k.clone()
+                        for i in range(len(empties)):
+                            if empties[i] != 0:
+                                moved_k[i][: -empties[i]] = k[i][empties[i] :]
+                        moved_k = self.rotary_emb._single_eval_forward(moved_k, seqlen_offset=0)
+                        for i in range(len(empties)):
+                            if empties[i] != 0:
+                                k[i][empties[i] :] = moved_k[i][: -empties[i]]
+                            else:
+                                k[i] = moved_k[i]
+                else:
+                    q = self.rotary_emb._single_forward(q, inference_params.sequence_len_offset)
+                    k = self.rotary_emb._single_forward(k, inference_params.sequence_len_offset)
+
+                kv = torch.stack([k, v], dim=2)
+                kv = _update_kv_cache(kv, inference_params, self.layer_idx)
+
+                if hasattr(inference_params, "attention_mask") and inference_params.attention_mask is not None:
+                    if inference_params.sequence_len_offset == 0:  # First entrance, attnmask (bs*seqlen*seqlen)
+                        attn_mask = inference_params.attention_mask[:, None, ...]
+                        attn_mask = torch.logical_or(
+                            torch.ones_like(attn_mask, dtype=torch.bool).triu(diagonal=1), attn_mask
+                        )
+                        attn_mask4flsh = ~attn_mask[:, :, -1, :].view(bsz, -1)
+                        cu_seqlens = torch.concat(
+                            [
+                                torch.tensor([0], dtype=torch.int32, device=attn_mask4flsh.device),
+                                attn_mask4flsh.sum(dim=-1).to(dtype=torch.int32),
+                            ],
+                            dim=0,
+                        )
+                        cu_seqlens = cu_seqlens.cumsum(dim=0, dtype=torch.int32)
+                        max_seqlen_q = attn_mask4flsh.shape[-1]
+                        max_seqlen_k = attn_mask4flsh.shape[-1]
+                        total_q = q.masked_select(attn_mask4flsh.view(bsz, -1, 1, 1)).view(-1, q.shape[-2], q.shape[-1])
+                        total_kv = kv.masked_select(attn_mask4flsh.view(bsz, -1, 1, 1, 1)).view(
+                            -1, kv.shape[-3], kv.shape[-2], kv.shape[-1]
+                        )
+
+                        if gpc.config.model.dtype is torch.float32 and gpc.config.model.use_flash_attn:
+                            with torch.cuda.amp.autocast(dtype=torch.bfloat16):
+                                if total_q.dtype not in [torch.float16, torch.bfloat16]:
+                                    total_q = total_q.to(torch.bfloat16)
+                                if total_kv.dtype not in [torch.float16, torch.bfloat16]:
+                                    total_kv = total_kv.to(torch.bfloat16)
+
+                        output = flash_attn_unpadded_func(
+                            total_q, total_kv, cu_seqlens, cu_seqlens, max_seqlen_q, max_seqlen_k, 0.0, None, True, False
+                        ).to(x.dtype)
+
+                        context = torch.zeros_like(q)
+                        context = context.masked_scatter_(attn_mask4flsh.view(bsz, -1, 1, 1), output)
+
+                    else:
+                        attn_mask = inference_params.attention_mask[:, -1, :].view(bsz, 1, 1, -1)
+
+                        k, v = torch.chunk(kv, 2, dim=2)
+                        k = k.squeeze(2)
+                        v = v.squeeze(2)
+                        sp = k.shape
+                        scores = torch.einsum(
+                            "blhd,bnhd->bhln",
+                            q,
+                            k.reshape(sp[0], sp[1], q.size(2), sp[3]),
+                        ) / math.sqrt(q.size(-1))
+                        scores = scores.masked_fill(attn_mask, -65000.0)
+                        scores = F.softmax(scores, dim=-1)  # bsz x h x L x L
+                        context = torch.einsum(
+                            "bhmn,bnhd->bmhd",
+                            scores,
+                            v.reshape(sp[0], sp[1], q.size(2), sp[3]),
+                        )
+                else:
+                    context = self.inner_cross_attn(q, kv, causal=True)
 
         if seqlen is None:
             context = rearrange(context, "b s h d -> b s (h d)")
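In the prompt pass above, the left-padded boolean attention mask is turned into cumulative sequence lengths and only the real tokens are gathered before the variable-length flash-attention call, with the result scattered back into the padded layout afterwards. A standalone sketch of that bookkeeping (illustrative shapes, no flash-attn call; not part of the commit):

import torch

bsz, seqlen, n_heads, head_dim = 2, 6, 4, 8
# Left-padded batch: sequence 0 has 2 pad tokens, sequence 1 has none.
pad_mask = torch.zeros(bsz, seqlen, dtype=torch.bool)
pad_mask[0, :2] = True
keep = ~pad_mask                                   # True for real tokens

q = torch.randn(bsz, seqlen, n_heads, head_dim)

# Cumulative sequence lengths, in the form varlen attention kernels expect:
cu_seqlens = torch.concat(
    [torch.tensor([0], dtype=torch.int32), keep.sum(dim=-1).to(torch.int32)]
).cumsum(dim=0, dtype=torch.int32)                 # tensor([0, 4, 10])
max_seqlen = seqlen

# Gather only the real tokens into a flat (total_tokens, n_heads, head_dim) tensor:
total_q = q.masked_select(keep.view(bsz, -1, 1, 1)).view(-1, n_heads, head_dim)
assert total_q.shape[0] == cu_seqlens[-1]

# After attention, results are scattered back into the padded layout:
output = total_q                                   # stand-in for the kernel output
context = torch.zeros_like(q).masked_scatter_(keep.view(bsz, -1, 1, 1), output)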
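For subsequent decode steps (sequence_len_offset > 0), the fallback path above attends a single new query token against the cached keys and values with plain einsum attention, masking padded cache positions with a large negative value before the softmax. A runnable sketch with random stand-in tensors (not part of the commit):

import math
import torch
import torch.nn.functional as F

bsz, n_heads, head_dim, cache_len = 2, 4, 8, 6
q = torch.randn(bsz, 1, n_heads, head_dim)            # one new token per sequence
k = torch.randn(bsz, cache_len, n_heads, head_dim)    # cached keys
v = torch.randn(bsz, cache_len, n_heads, head_dim)    # cached values

# True where a cached position is padding and must be ignored:
pad_mask = torch.zeros(bsz, 1, 1, cache_len, dtype=torch.bool)
pad_mask[0, ..., :2] = True

scores = torch.einsum("blhd,bnhd->bhln", q, k) / math.sqrt(head_dim)
scores = scores.masked_fill(pad_mask, -65000.0)       # large negative fill, as in the diff
probs = F.softmax(scores, dim=-1)                     # (bsz, n_heads, 1, cache_len)
context = torch.einsum("bhmn,bnhd->bmhd", probs, v)   # (bsz, 1, n_heads, head_dim)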
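The core of the fix is how rotary positions are chosen when sequences are left-padded: at decode time each row's position is sequence_len_offset minus its number of pad tokens (empties), which is what the _single_forward calls above receive. A minimal sketch of offset-aware rotary embedding in that spirit (this is not the repository's RotaryEmbedding class; the pairwise rotation below is one common RoPE formulation):

import torch


def apply_rope(x, positions, base=10000.0):
    """x: (bsz, 1, n_heads, head_dim); positions: (bsz,) absolute position per row."""
    head_dim = x.shape[-1]
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
    angles = positions.float()[:, None] * inv_freq[None, :]   # (bsz, head_dim // 2)
    cos = angles.cos()[:, None, None, :]                      # broadcast over seqlen and heads
    sin = angles.sin()[:, None, None, :]
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = torch.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out


bsz, n_heads, head_dim = 2, 4, 8
sequence_len_offset = 10
empties = torch.tensor([2, 0])             # left-pad tokens per sequence
q = torch.randn(bsz, 1, n_heads, head_dim)
positions = sequence_len_offset * torch.ones(bsz, dtype=torch.int) - empties
q = apply_rope(q, positions)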