mirror of https://github.com/InternLM/InternLM
Fix bugs.
parent
787e0e0940
commit
78353e12cf
|
@ -156,17 +156,41 @@ class MHA(nn.Module):
|
||||||
else:
|
else:
|
||||||
context = self.inner_attn(qkv)
|
context = self.inner_attn(qkv)
|
||||||
|
|
||||||
|
else:
|
||||||
|
if self.use_dynamic_ntk_rope:
|
||||||
|
q = qkv[:, :, 0]
|
||||||
|
assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
|
||||||
|
kv = _update_kv_cache(qkv[:, :, 1:], inference_params, self.layer_idx)
|
||||||
|
if inference_params.sequence_len_offset != 0:
|
||||||
|
# q shape: [bsz, 1, nheads, head_dim]
|
||||||
|
# kv shape: [bsz, seqlen, 2, nheads, head_dim]
|
||||||
|
bsz, seq_len, _, nheads, head_dim = kv.shape
|
||||||
|
q = torch.cat([q.new_zeros(size=(bsz, seq_len - 1, nheads, head_dim)), q], dim=1).unsqueeze(2)
|
||||||
|
qkv = torch.cat([q, kv], dim=2)
|
||||||
|
if self.rotary_emb_dim > 0:
|
||||||
|
qkv = self.rotary_emb(qkv)
|
||||||
|
q = qkv[:, [-1], 0]
|
||||||
|
kv = qkv[:, :, 1:]
|
||||||
|
else:
|
||||||
|
if inference_params.sequence_len_offset > self.max_position_embeddings:
|
||||||
|
warnings.warn(
|
||||||
|
"Notice your prompt's length is longer than model's max_position_embeddings: "
|
||||||
|
f"{self.max_position_embeddings}, which will cause deviations in dynamic ntk calculations."
|
||||||
|
)
|
||||||
|
if self.rotary_emb_dim > 0:
|
||||||
|
kwargs["inference_params"] = inference_params
|
||||||
|
qkv = self.rotary_emb(qkv, **kwargs)
|
||||||
|
q = qkv[:, :, 0]
|
||||||
|
kv = qkv[:, :, 1:]
|
||||||
else:
|
else:
|
||||||
assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
|
assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
|
||||||
q, k, v = (x.squeeze(2) for x in qkv.chunk(chunks=3, dim=2))
|
q, k, v = (x.squeeze(2) for x in qkv.chunk(chunks=3, dim=2))
|
||||||
assert self.rotary_emb_dim > 0, "You should use rotary_emb."
|
|
||||||
if self.use_dynamic_ntk_rope:
|
|
||||||
kv = torch.stack([k, v], dim=2)
|
kv = torch.stack([k, v], dim=2)
|
||||||
kv = _update_kv_cache(kv, inference_params, self.layer_idx)
|
assert self.rotary_emb_dim > 0, "You should use rotary_emb."
|
||||||
|
|
||||||
if hasattr(inference_params, "attention_mask") and inference_params.attention_mask is not None:
|
if hasattr(inference_params, "attention_mask") and inference_params.attention_mask is not None:
|
||||||
if inference_params.sequence_len_offset == 0:
|
|
||||||
empties = inference_params.attention_mask[..., -1].sum(dim=-1)
|
empties = inference_params.attention_mask[..., -1].sum(dim=-1)
|
||||||
|
if inference_params.sequence_len_offset == 0:
|
||||||
moved_q = q.clone()
|
moved_q = q.clone()
|
||||||
moved_k = k.clone()
|
moved_k = k.clone()
|
||||||
for i in range(len(empties)):
|
for i in range(len(empties)):
|
||||||
|
@ -186,28 +210,30 @@ class MHA(nn.Module):
|
||||||
if inference_params.sequence_len_offset > self.max_position_embeddings:
|
if inference_params.sequence_len_offset > self.max_position_embeddings:
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"Notice your prompt's length is longer than model's max_position_embeddings: "
|
"Notice your prompt's length is longer than model's max_position_embeddings: "
|
||||||
f"{self.max_position_embeddings}, which will cause deviations in dynamic ntk calculations."
|
f"{self.max_position_embeddings}, may cause deviations in dynamic ntk calculations."
|
||||||
)
|
)
|
||||||
q = q.squeeze(1)
|
q = q.squeeze(1)
|
||||||
k = k.squeeze(1)
|
k = k.squeeze(1)
|
||||||
q = self.rotary_emb._single_forward(
|
q = self.rotary_emb._single_forward(
|
||||||
q,
|
q,
|
||||||
inference_params.sequence_len_offset * torch.ones(q.size(0), dtype=torch.int, device=q.device)
|
inference_params.sequence_len_offset
|
||||||
|
* torch.ones(q.size(0), dtype=torch.int, device=q.device)
|
||||||
- empties,
|
- empties,
|
||||||
).unsqueeze(1)
|
).unsqueeze(1)
|
||||||
k = self.rotary_emb._single_forward(
|
k = self.rotary_emb._single_forward(
|
||||||
k,
|
k,
|
||||||
inference_params.sequence_len_offset * torch.ones(k.size(0), dtype=torch.int, device=k.device)
|
inference_params.sequence_len_offset
|
||||||
|
* torch.ones(k.size(0), dtype=torch.int, device=k.device)
|
||||||
- empties,
|
- empties,
|
||||||
).unsqueeze(1)
|
).unsqueeze(1)
|
||||||
else:
|
else:
|
||||||
q = q.squeeze(1)
|
q = q.squeeze(1)
|
||||||
q = self.rotary_emb._single_forward(
|
q = self.rotary_emb._single_forward(
|
||||||
q,
|
q,
|
||||||
inference_params.sequence_len_offset * torch.ones(q.size(0), dtype=torch.int, device=q.device)
|
inference_params.sequence_len_offset
|
||||||
|
* torch.ones(q.size(0), dtype=torch.int, device=q.device)
|
||||||
- empties,
|
- empties,
|
||||||
).unsqueeze(1)
|
).unsqueeze(1)
|
||||||
empties = inference_params.attention_mask[..., -1].sum(dim=-1)
|
|
||||||
moved_k = k.clone()
|
moved_k = k.clone()
|
||||||
for i in range(len(empties)):
|
for i in range(len(empties)):
|
||||||
if empties[i] != 0:
|
if empties[i] != 0:
|
||||||
|
@ -222,7 +248,6 @@ class MHA(nn.Module):
|
||||||
q = self.rotary_emb._single_forward(q, inference_params.sequence_len_offset)
|
q = self.rotary_emb._single_forward(q, inference_params.sequence_len_offset)
|
||||||
k = self.rotary_emb._single_forward(k, inference_params.sequence_len_offset)
|
k = self.rotary_emb._single_forward(k, inference_params.sequence_len_offset)
|
||||||
|
|
||||||
if not self.use_dynamic_ntk_rope:
|
|
||||||
kv = torch.stack([k, v], dim=2)
|
kv = torch.stack([k, v], dim=2)
|
||||||
kv = _update_kv_cache(kv, inference_params, self.layer_idx)
|
kv = _update_kv_cache(kv, inference_params, self.layer_idx)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue