diff --git a/internlm/model/multi_head_attention.py b/internlm/model/multi_head_attention.py
index 6b92656..9836a1c 100644
--- a/internlm/model/multi_head_attention.py
+++ b/internlm/model/multi_head_attention.py
@@ -145,32 +145,36 @@ class MHA(nn.Module):
         else:
             qkv = rearrange(qkv, "(b s) (three h d) -> b s three h d", s=seqlen, three=3, d=self.head_dim)
 
-        q, k, v = (x.squeeze(2) for x in qkv.chunk(chunks=3, dim=2))
-
         if inference_params is None:
-            if self.rotary_emb_dim > 0:
-                q = self.rotary_emb._single_eval_forward(q)
-                k = self.rotary_emb._single_eval_forward(k)
-            kv = torch.concat([k.unsqueeze(2), v.unsqueeze(2)], dim=2)
-            context = self.inner_cross_attn(q, kv)
+            kwargs["inference_params"] = inference_params
+            qkv = self.rotary_emb(qkv, **kwargs)
+            if gpc.config.model.dtype is torch.float32 and gpc.config.model.use_flash_attn:
+                with torch.cuda.amp.autocast(dtype=torch.bfloat16):
+                    if qkv.dtype not in [torch.float16, torch.bfloat16]:
+                        qkv = qkv.to(torch.bfloat16)
+                    context = self.inner_attn(qkv).to(x.dtype)
+            else:
+                context = self.inner_attn(qkv)
 
         else:
-            assert self.rotary_emb_dim > 0
+            assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
+            q, k, v = (x.squeeze(2) for x in qkv.chunk(chunks=3, dim=2))
+            assert self.rotary_emb_dim > 0, "You should use rotary_emb."
+            if self.use_dynamic_ntk_rope:
+                kv = torch.stack([k, v], dim=2)
+                kv = _update_kv_cache(kv, inference_params, self.layer_idx)
+
             if hasattr(inference_params, "attention_mask") and inference_params.attention_mask is not None:
-                empties = inference_params.attention_mask[..., -1].sum(dim=-1)
-                moved_q = q.clone()
-                moved_k = k.clone()
                 if inference_params.sequence_len_offset == 0:
+                    empties = inference_params.attention_mask[..., -1].sum(dim=-1)
+                    moved_q = q.clone()
+                    moved_k = k.clone()
                     for i in range(len(empties)):
                         if empties[i] != 0:
                             moved_q[i][: -empties[i]] = q[i][empties[i] :]
                             moved_k[i][: -empties[i]] = k[i][empties[i] :]
-                    moved_q = self.rotary_emb._single_eval_forward(
-                        moved_q, seqlen_offset=inference_params.sequence_len_offset
-                    )
-                    moved_k = self.rotary_emb._single_eval_forward(
-                        moved_k, seqlen_offset=inference_params.sequence_len_offset
-                    )
+                    moved_q = self.rotary_emb._single_eval_forward(moved_q, seqlen_offset=0)
+                    moved_k = self.rotary_emb._single_eval_forward(moved_k, seqlen_offset=0)
                     for i in range(len(empties)):
                         if empties[i] != 0:
                             q[i][empties[i] :] = moved_q[i][: -empties[i]]
@@ -178,7 +182,12 @@ class MHA(nn.Module):
                         else:
                             q[i] = moved_q[i]
                             k[i] = moved_k[i]
-                else:
+                elif not self.use_dynamic_ntk_rope:
+                    if inference_params.sequence_len_offset > self.max_position_embeddings:
+                        warnings.warn(
+                            "Notice your prompt's length is longer than model's max_position_embeddings: "
+                            f"{self.max_position_embeddings}, which will cause deviations in dynamic ntk calculations."
+                        )
                     q = q.squeeze(1)
                     k = k.squeeze(1)
                     q = self.rotary_emb._single_forward(
@@ -191,14 +200,31 @@ class MHA(nn.Module):
                         inference_params.sequence_len_offset * torch.ones(k.size(0), dtype=torch.int, device=k.device)
                         - empties,
                     ).unsqueeze(1)
+                else:
+                    q = q.squeeze(1)
+                    q = self.rotary_emb._single_forward(
+                        q,
+                        inference_params.sequence_len_offset * torch.ones(q.size(0), dtype=torch.int, device=q.device)
+                        - empties,
+                    ).unsqueeze(1)
+                    empties = inference_params.attention_mask[..., -1].sum(dim=-1)
+                    moved_k = k.clone()
+                    for i in range(len(empties)):
+                        if empties[i] != 0:
+                            moved_k[i][: -empties[i]] = k[i][empties[i] :]
+                    moved_k = self.rotary_emb._single_eval_forward(moved_k, seqlen_offset=0)
+                    for i in range(len(empties)):
+                        if empties[i] != 0:
+                            k[i][empties[i] :] = moved_k[i][: -empties[i]]
+                        else:
+                            k[i] = moved_k[i]
             else:
                 q = self.rotary_emb._single_forward(q, inference_params.sequence_len_offset)
                 k = self.rotary_emb._single_forward(k, inference_params.sequence_len_offset)
 
-            kv = torch.stack([k, v], dim=2)
-
-            assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
-            kv = _update_kv_cache(kv, inference_params, self.layer_idx)
+            if not self.use_dynamic_ntk_rope:
+                kv = torch.stack([k, v], dim=2)
+                kv = _update_kv_cache(kv, inference_params, self.layer_idx)
 
             if hasattr(inference_params, "attention_mask") and inference_params.attention_mask is not None:
                 if inference_params.sequence_len_offset == 0:  # First entrance, attnmask (bs*seqlen*seqlen)
@@ -222,9 +248,16 @@ class MHA(nn.Module):
                         -1, kv.shape[-3], kv.shape[-2], kv.shape[-1]
                     )
 
+                    if gpc.config.model.dtype is torch.float32 and gpc.config.model.use_flash_attn:
+                        with torch.cuda.amp.autocast(dtype=torch.bfloat16):
+                            if total_q.dtype not in [torch.float16, torch.bfloat16]:
+                                total_q = total_q.to(torch.bfloat16)
+                            if total_kv.dtype not in [torch.float16, torch.bfloat16]:
+                                total_kv = total_kv.to(torch.bfloat16)
+
                     output = flash_attn_varlen_kvpacked_func(
                         total_q, total_kv, cu_seqlens, cu_seqlens, max_seqlen_q, max_seqlen_k, 0.0, None, True, False
-                    )
+                    ).to(x.dtype)
 
                     context = torch.zeros_like(q)
                     context = context.masked_scatter_(attn_mask4flsh.view(bsz, -1, 1, 1), output)
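
Note on the recurring dtype guard: both new "gpc.config.model.dtype is torch.float32 and gpc.config.model.use_flash_attn" branches apply the same pattern, because flash-attention kernels accept only fp16/bf16 inputs. Below is a minimal standalone sketch of that pattern, not part of the patch; attention_with_bf16_cast and attn_fn are illustrative names standing in for self.inner_attn / flash_attn_varlen_kvpacked_func.

    import torch

    def attention_with_bf16_cast(qkv: torch.Tensor, attn_fn) -> torch.Tensor:
        """Cast fp32 activations to bf16 for a flash-attention-style kernel, then restore the caller's dtype."""
        if qkv.dtype not in (torch.float16, torch.bfloat16):
            # mirrors the patch: run the kernel under bf16 autocast and cast the result back (.to(x.dtype))
            with torch.cuda.amp.autocast(dtype=torch.bfloat16):
                out = attn_fn(qkv.to(torch.bfloat16))
            return out.to(qkv.dtype)
        return attn_fn(qkv)

    # usage with a dummy attn_fn for illustration
    qkv = torch.randn(2, 16, 3, 8, 64)  # (batch, seqlen, three, heads, head_dim), fp32
    context = attention_with_bf16_cast(qkv, lambda t: t.sum(dim=2))
    assert context.dtype == torch.float32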