mirror of https://github.com/InternLM/InternLM
Recover use_dynamic_ntk_rope.
parent 4a714966fc
commit 787e0e0940
@@ -145,32 +145,36 @@ class MHA(nn.Module):
         else:
             qkv = rearrange(qkv, "(b s) (three h d) -> b s three h d", s=seqlen, three=3, d=self.head_dim)
 
-        q, k, v = (x.squeeze(2) for x in qkv.chunk(chunks=3, dim=2))
-
         if inference_params is None:
-            if self.rotary_emb_dim > 0:
-                q = self.rotary_emb._single_eval_forward(q)
-                k = self.rotary_emb._single_eval_forward(k)
-            kv = torch.concat([k.unsqueeze(2), v.unsqueeze(2)], dim=2)
-            context = self.inner_cross_attn(q, kv)
+            kwargs["inference_params"] = inference_params
+            qkv = self.rotary_emb(qkv, **kwargs)
+            if gpc.config.model.dtype is torch.float32 and gpc.config.model.use_flash_attn:
+                with torch.cuda.amp.autocast(dtype=torch.bfloat16):
+                    if qkv.dtype not in [torch.float16, torch.bfloat16]:
+                        qkv = qkv.to(torch.bfloat16)
+                    context = self.inner_attn(qkv).to(x.dtype)
+            else:
+                context = self.inner_attn(qkv)
 
         else:
-            assert self.rotary_emb_dim > 0
+            assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
+            q, k, v = (x.squeeze(2) for x in qkv.chunk(chunks=3, dim=2))
+            assert self.rotary_emb_dim > 0, "You should use rotary_emb."
+            if self.use_dynamic_ntk_rope:
+                kv = torch.stack([k, v], dim=2)
+                kv = _update_kv_cache(kv, inference_params, self.layer_idx)
+
             if hasattr(inference_params, "attention_mask") and inference_params.attention_mask is not None:
-                empties = inference_params.attention_mask[..., -1].sum(dim=-1)
-                moved_q = q.clone()
-                moved_k = k.clone()
                 if inference_params.sequence_len_offset == 0:
+                    empties = inference_params.attention_mask[..., -1].sum(dim=-1)
+                    moved_q = q.clone()
+                    moved_k = k.clone()
                     for i in range(len(empties)):
                         if empties[i] != 0:
                             moved_q[i][: -empties[i]] = q[i][empties[i] :]
                             moved_k[i][: -empties[i]] = k[i][empties[i] :]
-                    moved_q = self.rotary_emb._single_eval_forward(
-                        moved_q, seqlen_offset=inference_params.sequence_len_offset
-                    )
-                    moved_k = self.rotary_emb._single_eval_forward(
-                        moved_k, seqlen_offset=inference_params.sequence_len_offset
-                    )
+                    moved_q = self.rotary_emb._single_eval_forward(moved_q, seqlen_offset=0)
+                    moved_k = self.rotary_emb._single_eval_forward(moved_k, seqlen_offset=0)
                     for i in range(len(empties)):
                         if empties[i] != 0:
                             q[i][empties[i] :] = moved_q[i][: -empties[i]]
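The moved_q / moved_k handling recovered above exists because prompts in a left-padded batch should get rotary positions counted from their first real token, not from the padding. A minimal sketch of that shift, rotate, shift-back idea, assuming a generic apply_rope(t, seqlen_offset=...) helper (a hypothetical name standing in for rotary_emb._single_eval_forward):

    import torch

    def rope_with_left_padding(q, pad_lens, apply_rope):
        # q: (batch, seqlen, heads, head_dim); pad_lens[i] = number of left-padding positions in sample i
        moved = q.clone()
        for i, pad in enumerate(pad_lens.tolist()):
            if pad != 0:
                # shift the real tokens to the front so position 0 falls on the first real token
                moved[i][:-pad] = q[i][pad:]
        moved = apply_rope(moved, seqlen_offset=0)
        out = q.clone()
        for i, pad in enumerate(pad_lens.tolist()):
            if pad != 0:
                # shift the rotated tokens back to their right-aligned slots
                out[i][pad:] = moved[i][:-pad]
            else:
                out[i] = moved[i]
        return out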
@@ -178,7 +182,12 @@ class MHA(nn.Module):
                         else:
                             q[i] = moved_q[i]
                             k[i] = moved_k[i]
-                else:
+                elif not self.use_dynamic_ntk_rope:
+                    if inference_params.sequence_len_offset > self.max_position_embeddings:
+                        warnings.warn(
+                            "Notice your prompt's length is longer than model's max_position_embeddings: "
+                            f"{self.max_position_embeddings}, which will cause deviations in dynamic ntk calculations."
+                        )
                     q = q.squeeze(1)
                     k = k.squeeze(1)
                     q = self.rotary_emb._single_forward(
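The warning added in this hunk refers to the fact that dynamic NTK scaling stretches the rotary base once decoding runs past max_position_embeddings. As a rough illustration only (the commonly used recipe, not necessarily this repository's exact formula), the inverse frequencies can be recomputed like this:

    import torch

    def dynamic_ntk_inv_freq(head_dim, seq_len, max_pos, base=10000.0, scaling_factor=1.0):
        # enlarge the base once the current length exceeds the trained context window
        if seq_len > max_pos:
            base = base * ((scaling_factor * seq_len / max_pos) - (scaling_factor - 1)) ** (head_dim / (head_dim - 2))
        return 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))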
@@ -191,14 +200,31 @@ class MHA(nn.Module):
                         inference_params.sequence_len_offset * torch.ones(k.size(0), dtype=torch.int, device=k.device)
                         - empties,
                     ).unsqueeze(1)
+                else:
+                    q = q.squeeze(1)
+                    q = self.rotary_emb._single_forward(
+                        q,
+                        inference_params.sequence_len_offset * torch.ones(q.size(0), dtype=torch.int, device=q.device)
+                        - empties,
+                    ).unsqueeze(1)
+                    empties = inference_params.attention_mask[..., -1].sum(dim=-1)
+                    moved_k = k.clone()
+                    for i in range(len(empties)):
+                        if empties[i] != 0:
+                            moved_k[i][: -empties[i]] = k[i][empties[i] :]
+                    moved_k = self.rotary_emb._single_eval_forward(moved_k, seqlen_offset=0)
+                    for i in range(len(empties)):
+                        if empties[i] != 0:
+                            k[i][empties[i] :] = moved_k[i][: -empties[i]]
+                        else:
+                            k[i] = moved_k[i]
             else:
                 q = self.rotary_emb._single_forward(q, inference_params.sequence_len_offset)
                 k = self.rotary_emb._single_forward(k, inference_params.sequence_len_offset)
 
-            kv = torch.stack([k, v], dim=2)
-
-            assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
-            kv = _update_kv_cache(kv, inference_params, self.layer_idx)
+            if not self.use_dynamic_ntk_rope:
+                kv = torch.stack([k, v], dim=2)
+                kv = _update_kv_cache(kv, inference_params, self.layer_idx)
 
             if hasattr(inference_params, "attention_mask") and inference_params.attention_mask is not None:
                 if inference_params.sequence_len_offset == 0:  # First entrance, attnmask (bs*seqlen*seqlen)
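For the _update_kv_cache call that this hunk gates behind the non-dynamic-NTK path: the usual decoding pattern is to keep a per-layer cache and write the new key/value block at the current sequence offset, then return everything cached so far. A rough sketch of that pattern under those assumptions (a hypothetical helper, not the repository's implementation):

    import torch

    def update_kv_cache(kv, cache, seq_offset):
        # kv: (batch, new_len, 2, heads, head_dim); cache: (batch, max_len, 2, heads, head_dim)
        batch, new_len = kv.shape[0], kv.shape[1]
        cache[:batch, seq_offset : seq_offset + new_len] = kv
        # return everything cached so far, including the block just written
        return cache[:batch, : seq_offset + new_len]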
@@ -222,9 +248,16 @@ class MHA(nn.Module):
                         -1, kv.shape[-3], kv.shape[-2], kv.shape[-1]
                     )
 
+                    if gpc.config.model.dtype is torch.float32 and gpc.config.model.use_flash_attn:
+                        with torch.cuda.amp.autocast(dtype=torch.bfloat16):
+                            if total_q.dtype not in [torch.float16, torch.bfloat16]:
+                                total_q = total_q.to(torch.bfloat16)
+                            if total_kv.dtype not in [torch.float16, torch.bfloat16]:
+                                total_kv = total_kv.to(torch.bfloat16)
+
                     output = flash_attn_varlen_kvpacked_func(
                         total_q, total_kv, cu_seqlens, cu_seqlens, max_seqlen_q, max_seqlen_k, 0.0, None, True, False
-                    )
+                    ).to(x.dtype)
 
                     context = torch.zeros_like(q)
                     context = context.masked_scatter_(attn_mask4flsh.view(bsz, -1, 1, 1), output)
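The dtype guard added in this last hunk reflects that the flash-attention kernels only accept fp16/bf16 tensors, so a float32 model casts the packed inputs down and casts the result back to x.dtype. A minimal sketch of that guard, assuming flash-attn 2.x and the same kvpacked varlen call that appears in the diff:

    import torch
    from flash_attn import flash_attn_varlen_kvpacked_func

    def fp32_safe_varlen_attn(total_q, total_kv, cu_seqlens, max_seqlen_q, max_seqlen_k, out_dtype):
        # flash-attention kernels require half-precision inputs
        if total_q.dtype not in (torch.float16, torch.bfloat16):
            total_q = total_q.to(torch.bfloat16)
        if total_kv.dtype not in (torch.float16, torch.bfloat16):
            total_kv = total_kv.to(torch.bfloat16)
        out = flash_attn_varlen_kvpacked_func(
            total_q, total_kv, cu_seqlens, cu_seqlens, max_seqlen_q, max_seqlen_k, 0.0, None, True, False
        )
        # cast back to the residual-stream dtype (e.g. torch.float32)
        return out.to(out_dtype)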