mirror of https://github.com/InternLM/InternLM
Fix bugs.
parent 787e0e0940
commit 78353e12cf
@@ -157,72 +157,97 @@ class MHA(nn.Module):
            context = self.inner_attn(qkv)
        else:
            assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
            q, k, v = (x.squeeze(2) for x in qkv.chunk(chunks=3, dim=2))
            assert self.rotary_emb_dim > 0, "You should use rotary_emb."
            if self.use_dynamic_ntk_rope:
                kv = torch.stack([k, v], dim=2)
                kv = _update_kv_cache(kv, inference_params, self.layer_idx)
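
                # Left-padded batches: "empties" counts the padding tokens per sample; the real tokens
                # are shifted to the front, rotated from position 0, then shifted back into place.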
                if hasattr(inference_params, "attention_mask") and inference_params.attention_mask is not None:
                    if inference_params.sequence_len_offset == 0:
                        empties = inference_params.attention_mask[..., -1].sum(dim=-1)
                        moved_q = q.clone()
                        moved_k = k.clone()
                        for i in range(len(empties)):
                            if empties[i] != 0:
                                moved_q[i][: -empties[i]] = q[i][empties[i] :]
                                moved_k[i][: -empties[i]] = k[i][empties[i] :]
                        moved_q = self.rotary_emb._single_eval_forward(moved_q, seqlen_offset=0)
                        moved_k = self.rotary_emb._single_eval_forward(moved_k, seqlen_offset=0)
                        for i in range(len(empties)):
                            if empties[i] != 0:
                                q[i][empties[i] :] = moved_q[i][: -empties[i]]
                                k[i][empties[i] :] = moved_k[i][: -empties[i]]
                            else:
                                q[i] = moved_q[i]
                                k[i] = moved_k[i]
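            # Without dynamic NTK RoPE: update the kv cache first, then (for decode steps) pad q to
            # the cached length and rotate the packed qkv in a single call.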
            elif not self.use_dynamic_ntk_rope:
                q = qkv[:, :, 0]
                assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
                kv = _update_kv_cache(qkv[:, :, 1:], inference_params, self.layer_idx)
                if inference_params.sequence_len_offset != 0:
                    # q shape: [bsz, 1, nheads, head_dim]
                    # kv shape: [bsz, seqlen, 2, nheads, head_dim]
                    bsz, seq_len, _, nheads, head_dim = kv.shape
                    q = torch.cat([q.new_zeros(size=(bsz, seq_len - 1, nheads, head_dim)), q], dim=1).unsqueeze(2)
                    qkv = torch.cat([q, kv], dim=2)
                    if self.rotary_emb_dim > 0:
                        qkv = self.rotary_emb(qkv)
                    q = qkv[:, [-1], 0]
                    kv = qkv[:, :, 1:]
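                # Otherwise rotate q and k with per-sample offsets (sequence_len_offset minus the
                # number of left-padding tokens).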
                else:
                    if inference_params.sequence_len_offset > self.max_position_embeddings:
                        warnings.warn(
                            "Notice your prompt's length is longer than model's max_position_embeddings: "
                            f"{self.max_position_embeddings}, which will cause deviations in dynamic ntk calculations."
                        )
                    q = q.squeeze(1)
                    k = k.squeeze(1)
                    q = self.rotary_emb._single_forward(
                        q,
                        inference_params.sequence_len_offset * torch.ones(q.size(0), dtype=torch.int, device=q.device)
                        - empties,
                    ).unsqueeze(1)
                    k = self.rotary_emb._single_forward(
                        k,
                        inference_params.sequence_len_offset * torch.ones(k.size(0), dtype=torch.int, device=k.device)
                        - empties,
                    ).unsqueeze(1)
            else:
                q = q.squeeze(1)
                q = self.rotary_emb._single_forward(
                    q,
                    inference_params.sequence_len_offset * torch.ones(q.size(0), dtype=torch.int, device=q.device)
                    - empties,
                ).unsqueeze(1)
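                # Keys only: strip the left padding, rotate from position 0, then move the rotated
                # values back behind the padding.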
                empties = inference_params.attention_mask[..., -1].sum(dim=-1)
                moved_k = k.clone()
                for i in range(len(empties)):
                    if empties[i] != 0:
                        moved_k[i][: -empties[i]] = k[i][empties[i] :]
                moved_k = self.rotary_emb._single_eval_forward(moved_k, seqlen_offset=0)
                for i in range(len(empties)):
                    if empties[i] != 0:
                        k[i][empties[i] :] = moved_k[i][: -empties[i]]
                    else:
                        k[i] = moved_k[i]
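                # Either rotate the packed qkv (forwarding inference_params through kwargs) or rotate
                # q and k individually at the current sequence offset.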
                if self.rotary_emb_dim > 0:
                    kwargs["inference_params"] = inference_params
                    qkv = self.rotary_emb(qkv, **kwargs)
                    q = qkv[:, :, 0]
                    kv = qkv[:, :, 1:]
                else:
                    q = self.rotary_emb._single_forward(q, inference_params.sequence_len_offset)
                    k = self.rotary_emb._single_forward(k, inference_params.sequence_len_offset)
            assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
            q, k, v = (x.squeeze(2) for x in qkv.chunk(chunks=3, dim=2))
            kv = torch.stack([k, v], dim=2)
            assert self.rotary_emb_dim > 0, "You should use rotary_emb."
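
            # When an attention mask is present, "empties" gives the left-padding length per sample
            # and determines how the RoPE offsets are applied below.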
            if hasattr(inference_params, "attention_mask") and inference_params.attention_mask is not None:
                empties = inference_params.attention_mask[..., -1].sum(dim=-1)
                if inference_params.sequence_len_offset == 0:
                    moved_q = q.clone()
                    moved_k = k.clone()
                    for i in range(len(empties)):
                        if empties[i] != 0:
                            moved_q[i][: -empties[i]] = q[i][empties[i] :]
                            moved_k[i][: -empties[i]] = k[i][empties[i] :]
                    moved_q = self.rotary_emb._single_eval_forward(moved_q, seqlen_offset=0)
                    moved_k = self.rotary_emb._single_eval_forward(moved_k, seqlen_offset=0)
                    for i in range(len(empties)):
                        if empties[i] != 0:
                            q[i][empties[i] :] = moved_q[i][: -empties[i]]
                            k[i][empties[i] :] = moved_k[i][: -empties[i]]
                        else:
                            q[i] = moved_q[i]
                            k[i] = moved_k[i]
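                # Decode step without dynamic NTK: each sample's RoPE position is its sequence offset
                # minus its padding length.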
                elif not self.use_dynamic_ntk_rope:
                    if inference_params.sequence_len_offset > self.max_position_embeddings:
                        warnings.warn(
                            "Notice your prompt's length is longer than model's max_position_embeddings: "
                            f"{self.max_position_embeddings}, may cause deviations in dynamic ntk calculations."
                        )
                    q = q.squeeze(1)
                    k = k.squeeze(1)
                    q = self.rotary_emb._single_forward(
                        q,
                        inference_params.sequence_len_offset
                        * torch.ones(q.size(0), dtype=torch.int, device=q.device)
                        - empties,
                    ).unsqueeze(1)
                    k = self.rotary_emb._single_forward(
                        k,
                        inference_params.sequence_len_offset
                        * torch.ones(k.size(0), dtype=torch.int, device=k.device)
                        - empties,
                    ).unsqueeze(1)
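                # Dynamic NTK decode: q is rotated at offset positions, while k is re-rotated from
                # position 0 after stripping the left padding.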
                else:
                    q = q.squeeze(1)
                    q = self.rotary_emb._single_forward(
                        q,
                        inference_params.sequence_len_offset
                        * torch.ones(q.size(0), dtype=torch.int, device=q.device)
                        - empties,
                    ).unsqueeze(1)
                    moved_k = k.clone()
                    for i in range(len(empties)):
                        if empties[i] != 0:
                            moved_k[i][: -empties[i]] = k[i][empties[i] :]
                    moved_k = self.rotary_emb._single_eval_forward(moved_k, seqlen_offset=0)
                    for i in range(len(empties)):
                        if empties[i] != 0:
                            k[i][empties[i] :] = moved_k[i][: -empties[i]]
                        else:
                            k[i] = moved_k[i]
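            # No attention mask: rotate q and k directly at the current sequence offset.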
            else:
                q = self.rotary_emb._single_forward(q, inference_params.sequence_len_offset)
                k = self.rotary_emb._single_forward(k, inference_params.sequence_len_offset)
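
            # Without dynamic NTK, the rotated k can be stacked with v and written to the kv cache here.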
            if not self.use_dynamic_ntk_rope:
                kv = torch.stack([k, v], dim=2)
                kv = _update_kv_cache(kv, inference_params, self.layer_idx)
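
The padding-shift pattern used throughout this hunk can be exercised in isolation. The sketch below is a minimal illustration, not InternLM's implementation: apply_rope is a plain reference RoPE, the [batch, seqlen, nheads, headdim] layout mirrors the q/k tensors above, and rope_with_left_padding reproduces the moved_q/moved_k idea of shifting real tokens to the front, rotating from position 0, and shifting back so left padding never consumes rotary positions.

import torch


def apply_rope(x: torch.Tensor, offsets: torch.Tensor, base: float = 10000.0) -> torch.Tensor:
    """Reference RoPE. x: [bsz, seqlen, nheads, headdim]; offsets: [bsz] starting position per sample."""
    bsz, seqlen, _, headdim = x.shape
    half = headdim // 2
    inv_freq = 1.0 / (base ** (torch.arange(0, half, dtype=torch.float32) / half))
    pos = offsets.view(bsz, 1).float() + torch.arange(seqlen, dtype=torch.float32)  # [bsz, seqlen]
    angles = pos[..., None] * inv_freq  # [bsz, seqlen, half]
    cos, sin = angles.cos()[:, :, None, :], angles.sin()[:, :, None, :]
    x1, x2 = x[..., :half], x[..., half:]
    return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)


def rope_with_left_padding(q: torch.Tensor, empties: torch.Tensor) -> torch.Tensor:
    """Shift real tokens to the front, rotate from position 0, shift back (same idea as moved_q/moved_k)."""
    moved = q.clone()
    for i in range(len(empties)):
        if empties[i] != 0:
            moved[i][: -empties[i]] = q[i][empties[i] :]
    moved = apply_rope(moved, torch.zeros(q.size(0), dtype=torch.long))
    out = q.clone()
    for i in range(len(empties)):
        if empties[i] != 0:
            out[i][empties[i] :] = moved[i][: -empties[i]]
        else:
            out[i] = moved[i]
    return out


if __name__ == "__main__":
    q = torch.randn(2, 5, 4, 8)
    empties = torch.tensor([2, 0])  # sample 0 carries 2 left-padding tokens
    rotated = rope_with_left_padding(q, empties)
    # The first real token of sample 0 (index 2) is rotated as if it sat at position 0.
    expected = apply_rope(q[:1, 2:3], torch.zeros(1, dtype=torch.long))[0, 0]
    assert torch.allclose(rotated[0, 2], expected, atol=1e-5)

_update_kv_cache itself is not shown in this hunk. As a rough idea only (an assumption about its role, not InternLM's actual implementation), such a helper typically writes the new [bsz, seqlen, 2, nheads, headdim] block into a per-layer buffer at the current offset and returns the valid prefix:

import torch


def update_kv_cache_sketch(kv: torch.Tensor, cache: dict, seq_offset: int, layer_idx: int,
                           max_seqlen: int = 4096) -> torch.Tensor:
    """Hypothetical KV-cache update; kv is [bsz, seqlen, 2, nheads, headdim]."""
    bsz, seqlen, _, nheads, headdim = kv.shape
    if layer_idx not in cache:
        # Preallocate space for the longest sequence this sketch supports.
        cache[layer_idx] = kv.new_zeros(bsz, max_seqlen, 2, nheads, headdim)
    cache[layer_idx][:, seq_offset : seq_offset + seqlen] = kv
    # Return everything cached so far; this is the kv the attention call would consume.
    return cache[layer_idx][:, : seq_offset + seqlen]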