From 78353e12cfd9a9889fb8051a8d52564b34c6a576 Mon Sep 17 00:00:00 2001
From: Pryest <495945214@qq.com>
Date: Mon, 9 Oct 2023 20:27:03 +0800
Subject: [PATCH] Fix bugs.

---
 internlm/model/multi_head_attention.py | 141 +++++++++++++++----------
 1 file changed, 83 insertions(+), 58 deletions(-)

diff --git a/internlm/model/multi_head_attention.py b/internlm/model/multi_head_attention.py
index 9836a1c..6c611d8 100644
--- a/internlm/model/multi_head_attention.py
+++ b/internlm/model/multi_head_attention.py
@@ -157,72 +157,97 @@ class MHA(nn.Module):
             context = self.inner_attn(qkv)
         else:
-            assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
-            q, k, v = (x.squeeze(2) for x in qkv.chunk(chunks=3, dim=2))
-            assert self.rotary_emb_dim > 0, "You should use rotary_emb."
             if self.use_dynamic_ntk_rope:
-                kv = torch.stack([k, v], dim=2)
-                kv = _update_kv_cache(kv, inference_params, self.layer_idx)
-
-            if hasattr(inference_params, "attention_mask") and inference_params.attention_mask is not None:
-                if inference_params.sequence_len_offset == 0:
-                    empties = inference_params.attention_mask[..., -1].sum(dim=-1)
-                    moved_q = q.clone()
-                    moved_k = k.clone()
-                    for i in range(len(empties)):
-                        if empties[i] != 0:
-                            moved_q[i][: -empties[i]] = q[i][empties[i] :]
-                            moved_k[i][: -empties[i]] = k[i][empties[i] :]
-                    moved_q = self.rotary_emb._single_eval_forward(moved_q, seqlen_offset=0)
-                    moved_k = self.rotary_emb._single_eval_forward(moved_k, seqlen_offset=0)
-                    for i in range(len(empties)):
-                        if empties[i] != 0:
-                            q[i][empties[i] :] = moved_q[i][: -empties[i]]
-                            k[i][empties[i] :] = moved_k[i][: -empties[i]]
-                        else:
-                            q[i] = moved_q[i]
-                            k[i] = moved_k[i]
-                elif not self.use_dynamic_ntk_rope:
+                q = qkv[:, :, 0]
+                assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
+                kv = _update_kv_cache(qkv[:, :, 1:], inference_params, self.layer_idx)
+                if inference_params.sequence_len_offset != 0:
+                    # q shape: [bsz, 1, nheads, head_dim]
+                    # kv shape: [bsz, seqlen, 2, nheads, head_dim]
+                    bsz, seq_len, _, nheads, head_dim = kv.shape
+                    q = torch.cat([q.new_zeros(size=(bsz, seq_len - 1, nheads, head_dim)), q], dim=1).unsqueeze(2)
+                    qkv = torch.cat([q, kv], dim=2)
+                    if self.rotary_emb_dim > 0:
+                        qkv = self.rotary_emb(qkv)
+                    q = qkv[:, [-1], 0]
+                    kv = qkv[:, :, 1:]
+                else:
                     if inference_params.sequence_len_offset > self.max_position_embeddings:
                         warnings.warn(
                             "Notice your prompt's length is longer than model's max_position_embeddings: "
                             f"{self.max_position_embeddings}, which will cause deviations in dynamic ntk calculations."
                         )
-                    q = q.squeeze(1)
-                    k = k.squeeze(1)
-                    q = self.rotary_emb._single_forward(
-                        q,
-                        inference_params.sequence_len_offset * torch.ones(q.size(0), dtype=torch.int, device=q.device)
-                        - empties,
-                    ).unsqueeze(1)
-                    k = self.rotary_emb._single_forward(
-                        k,
-                        inference_params.sequence_len_offset * torch.ones(k.size(0), dtype=torch.int, device=k.device)
-                        - empties,
-                    ).unsqueeze(1)
-                else:
-                    q = q.squeeze(1)
-                    q = self.rotary_emb._single_forward(
-                        q,
-                        inference_params.sequence_len_offset * torch.ones(q.size(0), dtype=torch.int, device=q.device)
-                        - empties,
-                    ).unsqueeze(1)
-                    empties = inference_params.attention_mask[..., -1].sum(dim=-1)
-                    moved_k = k.clone()
-                    for i in range(len(empties)):
-                        if empties[i] != 0:
-                            moved_k[i][: -empties[i]] = k[i][empties[i] :]
-                    moved_k = self.rotary_emb._single_eval_forward(moved_k, seqlen_offset=0)
-                    for i in range(len(empties)):
-                        if empties[i] != 0:
-                            k[i][empties[i] :] = moved_k[i][: -empties[i]]
-                        else:
-                            k[i] = moved_k[i]
+                    if self.rotary_emb_dim > 0:
+                        kwargs["inference_params"] = inference_params
+                        qkv = self.rotary_emb(qkv, **kwargs)
+                    q = qkv[:, :, 0]
+                    kv = qkv[:, :, 1:]
             else:
-                q = self.rotary_emb._single_forward(q, inference_params.sequence_len_offset)
-                k = self.rotary_emb._single_forward(k, inference_params.sequence_len_offset)
+                assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
+                q, k, v = (x.squeeze(2) for x in qkv.chunk(chunks=3, dim=2))
+                kv = torch.stack([k, v], dim=2)
+                assert self.rotary_emb_dim > 0, "You should use rotary_emb."
+
+                if hasattr(inference_params, "attention_mask") and inference_params.attention_mask is not None:
+                    empties = inference_params.attention_mask[..., -1].sum(dim=-1)
+                    if inference_params.sequence_len_offset == 0:
+                        moved_q = q.clone()
+                        moved_k = k.clone()
+                        for i in range(len(empties)):
+                            if empties[i] != 0:
+                                moved_q[i][: -empties[i]] = q[i][empties[i] :]
+                                moved_k[i][: -empties[i]] = k[i][empties[i] :]
+                        moved_q = self.rotary_emb._single_eval_forward(moved_q, seqlen_offset=0)
+                        moved_k = self.rotary_emb._single_eval_forward(moved_k, seqlen_offset=0)
+                        for i in range(len(empties)):
+                            if empties[i] != 0:
+                                q[i][empties[i] :] = moved_q[i][: -empties[i]]
+                                k[i][empties[i] :] = moved_k[i][: -empties[i]]
+                            else:
+                                q[i] = moved_q[i]
+                                k[i] = moved_k[i]
+                    elif not self.use_dynamic_ntk_rope:
+                        if inference_params.sequence_len_offset > self.max_position_embeddings:
+                            warnings.warn(
+                                "Notice your prompt's length is longer than model's max_position_embeddings: "
+                                f"{self.max_position_embeddings}, may cause deviations in dynamic ntk calculations."
+                            )
+                        q = q.squeeze(1)
+                        k = k.squeeze(1)
+                        q = self.rotary_emb._single_forward(
+                            q,
+                            inference_params.sequence_len_offset
+                            * torch.ones(q.size(0), dtype=torch.int, device=q.device)
+                            - empties,
+                        ).unsqueeze(1)
+                        k = self.rotary_emb._single_forward(
+                            k,
+                            inference_params.sequence_len_offset
+                            * torch.ones(k.size(0), dtype=torch.int, device=k.device)
+                            - empties,
+                        ).unsqueeze(1)
+                    else:
+                        q = q.squeeze(1)
+                        q = self.rotary_emb._single_forward(
+                            q,
+                            inference_params.sequence_len_offset
+                            * torch.ones(q.size(0), dtype=torch.int, device=q.device)
+                            - empties,
+                        ).unsqueeze(1)
+                        moved_k = k.clone()
+                        for i in range(len(empties)):
+                            if empties[i] != 0:
+                                moved_k[i][: -empties[i]] = k[i][empties[i] :]
+                        moved_k = self.rotary_emb._single_eval_forward(moved_k, seqlen_offset=0)
+                        for i in range(len(empties)):
+                            if empties[i] != 0:
+                                k[i][empties[i] :] = moved_k[i][: -empties[i]]
+                            else:
+                                k[i] = moved_k[i]
+                else:
+                    q = self.rotary_emb._single_forward(q, inference_params.sequence_len_offset)
+                    k = self.rotary_emb._single_forward(k, inference_params.sequence_len_offset)

-            if not self.use_dynamic_ntk_rope:
                 kv = torch.stack([k, v], dim=2)
                 kv = _update_kv_cache(kv, inference_params, self.layer_idx)
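For reference, a minimal standalone sketch of the decode-path alignment the dynamic NTK branch now performs: the single-token query is left-padded with zeros so it lines up position-wise with the cached keys/values, rotary embedding is applied over the full sequence, and only the last query position is kept. The helper name decode_step_rotary, the apply_rotary callable, and the demo shapes are illustrative stand-ins (apply_rotary plays the role of self.rotary_emb(qkv) in the patch); the tensor layout follows the shape comments in the patch.

import torch


def decode_step_rotary(q, kv, apply_rotary):
    # q:  [bsz, 1, nheads, head_dim]         query for the newest token only
    # kv: [bsz, seqlen, 2, nheads, head_dim] cached keys/values, newest token included
    bsz, seq_len, _, nheads, head_dim = kv.shape
    # Left-pad q with zeros so it aligns with the cache, then rebuild a full qkv tensor.
    q = torch.cat([q.new_zeros(bsz, seq_len - 1, nheads, head_dim), q], dim=1).unsqueeze(2)
    qkv = torch.cat([q, kv], dim=2)  # [bsz, seqlen, 3, nheads, head_dim]
    qkv = apply_rotary(qkv)  # rotary embedding sees the true positions 0..seqlen-1
    # Only the last query position is real; every cached key/value position is kept.
    return qkv[:, [-1], 0], qkv[:, :, 1:]


q = torch.randn(2, 1, 4, 8)
kv = torch.randn(2, 5, 2, 4, 8)
new_q, new_kv = decode_step_rotary(q, kv, apply_rotary=lambda x: x)  # identity rotary for the demo
print(new_q.shape, new_kv.shape)  # torch.Size([2, 1, 4, 8]) torch.Size([2, 5, 2, 4, 8])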