Fix bugs.

pull/396/head
Pryest 2023-10-09 20:27:03 +08:00
parent 787e0e0940
commit 78353e12cf
1 changed file with 83 additions and 58 deletions

@@ -156,17 +156,41 @@ class MHA(nn.Module):
             else:
                 context = self.inner_attn(qkv)
         else:
+            if self.use_dynamic_ntk_rope:
+                q = qkv[:, :, 0]
+                assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
+                kv = _update_kv_cache(qkv[:, :, 1:], inference_params, self.layer_idx)
+                if inference_params.sequence_len_offset != 0:
+                    # q shape: [bsz, 1, nheads, head_dim]
+                    # kv shape: [bsz, seqlen, 2, nheads, head_dim]
+                    bsz, seq_len, _, nheads, head_dim = kv.shape
+                    q = torch.cat([q.new_zeros(size=(bsz, seq_len - 1, nheads, head_dim)), q], dim=1).unsqueeze(2)
+                    qkv = torch.cat([q, kv], dim=2)
+                    if self.rotary_emb_dim > 0:
+                        qkv = self.rotary_emb(qkv)
+                    q = qkv[:, [-1], 0]
+                    kv = qkv[:, :, 1:]
+                else:
+                    if inference_params.sequence_len_offset > self.max_position_embeddings:
+                        warnings.warn(
+                            "Notice your prompt's length is longer than model's max_position_embeddings: "
+                            f"{self.max_position_embeddings}, which will cause deviations in dynamic ntk calculations."
+                        )
+                    if self.rotary_emb_dim > 0:
+                        kwargs["inference_params"] = inference_params
+                        qkv = self.rotary_emb(qkv, **kwargs)
+                    q = qkv[:, :, 0]
+                    kv = qkv[:, :, 1:]
             else:
                 assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
                 q, k, v = (x.squeeze(2) for x in qkv.chunk(chunks=3, dim=2))
-                assert self.rotary_emb_dim > 0, "You should use rotary_emb."
-                if self.use_dynamic_ntk_rope:
                 kv = torch.stack([k, v], dim=2)
-                kv = _update_kv_cache(kv, inference_params, self.layer_idx)
+                assert self.rotary_emb_dim > 0, "You should use rotary_emb."
                 if hasattr(inference_params, "attention_mask") and inference_params.attention_mask is not None:
-                    if inference_params.sequence_len_offset == 0:
                     empties = inference_params.attention_mask[..., -1].sum(dim=-1)
+                    if inference_params.sequence_len_offset == 0:
                         moved_q = q.clone()
                         moved_k = k.clone()
                         for i in range(len(empties)):
@@ -186,28 +210,30 @@ class MHA(nn.Module):
                         if inference_params.sequence_len_offset > self.max_position_embeddings:
                             warnings.warn(
                                 "Notice your prompt's length is longer than model's max_position_embeddings: "
-                                f"{self.max_position_embeddings}, which will cause deviations in dynamic ntk calculations."
+                                f"{self.max_position_embeddings}, may cause deviations in dynamic ntk calculations."
                             )
                         q = q.squeeze(1)
                         k = k.squeeze(1)
                         q = self.rotary_emb._single_forward(
                             q,
-                            inference_params.sequence_len_offset * torch.ones(q.size(0), dtype=torch.int, device=q.device)
+                            inference_params.sequence_len_offset
+                            * torch.ones(q.size(0), dtype=torch.int, device=q.device)
                             - empties,
                         ).unsqueeze(1)
                         k = self.rotary_emb._single_forward(
                             k,
-                            inference_params.sequence_len_offset * torch.ones(k.size(0), dtype=torch.int, device=k.device)
+                            inference_params.sequence_len_offset
+                            * torch.ones(k.size(0), dtype=torch.int, device=k.device)
                             - empties,
                         ).unsqueeze(1)
                     else:
                         q = q.squeeze(1)
                         q = self.rotary_emb._single_forward(
                             q,
-                            inference_params.sequence_len_offset * torch.ones(q.size(0), dtype=torch.int, device=q.device)
+                            inference_params.sequence_len_offset
+                            * torch.ones(q.size(0), dtype=torch.int, device=q.device)
                             - empties,
                         ).unsqueeze(1)
-                        empties = inference_params.attention_mask[..., -1].sum(dim=-1)
                         moved_k = k.clone()
                         for i in range(len(empties)):
                             if empties[i] != 0:
@@ -222,7 +248,6 @@ class MHA(nn.Module):
                     q = self.rotary_emb._single_forward(q, inference_params.sequence_len_offset)
                     k = self.rotary_emb._single_forward(k, inference_params.sequence_len_offset)
-                if not self.use_dynamic_ntk_rope:
                 kv = torch.stack([k, v], dim=2)
                 kv = _update_kv_cache(kv, inference_params, self.layer_idx)
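
The decode branch added in the first hunk rotates a single-token q by zero-padding it up to the cached sequence length, concatenating it with the cached kv, running rotary over the whole tensor once, and keeping only the last q position. The sketch below illustrates why that padding trick works; it uses a toy apply_rotary helper and made-up shapes (it is not the repo's RotaryEmbedding / self.rotary_emb API), so treat it as an illustration under those assumptions rather than the actual implementation.

# Minimal sketch of the decode-step trick in the first hunk, assuming a toy
# `apply_rotary` helper (not the repo's rotary module) and made-up shapes:
# pad the one-step q with zeros up to the cached length, rotate q and the
# cached kv together, then keep only the last (real) q position.
import torch


def apply_rotary(x: torch.Tensor, offset: int = 0) -> torch.Tensor:
    """Toy rotary embedding: rotate channel pairs by an angle that grows with
    the position index along dim 1, starting from `offset`."""
    seqlen, head_dim = x.shape[1], x.shape[-1]
    mid_dims = x.shape[2:-1]  # e.g. (nheads,) or (3, nheads)
    half = head_dim // 2
    pos = torch.arange(offset, offset + seqlen, dtype=torch.float32)
    pos = pos.view(1, seqlen, *([1] * len(mid_dims)), 1)
    inv_freq = 1.0 / (10000 ** (torch.arange(half, dtype=torch.float32) / half))
    angle = pos * inv_freq
    x1, x2 = x[..., :half], x[..., half:]
    return torch.cat(
        [x1 * angle.cos() - x2 * angle.sin(), x1 * angle.sin() + x2 * angle.cos()], dim=-1
    )


bsz, seqlen, nheads, head_dim = 2, 5, 4, 8
q = torch.randn(bsz, 1, nheads, head_dim)            # query of the current decode step only
kv = torch.randn(bsz, seqlen, 2, nheads, head_dim)   # cached k/v, current step included

# Pad q with zeros so it lines up with the cache, as the patched branch does.
q_pad = torch.cat([q.new_zeros(size=(bsz, seqlen - 1, nheads, head_dim)), q], dim=1).unsqueeze(2)
qkv = torch.cat([q_pad, kv], dim=2)                  # [bsz, seqlen, 3, nheads, head_dim]
qkv = apply_rotary(qkv)
q_rot, kv_rot = qkv[:, [-1], 0], qkv[:, :, 1:]       # mirrors q = qkv[:, [-1], 0]; kv = qkv[:, :, 1:]

# The last slot of the padded q equals q rotated directly at its true position offset,
# since rotary acts on each position independently and the zero rows only occupy slots.
assert torch.allclose(q_rot, apply_rotary(q, offset=seqlen - 1), atol=1e-5)

Lining the positions up this way appears to be what lets the patched branch rotate q and the cached kv with a single self.rotary_emb(qkv) call instead of rotating q separately with an explicit offset.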