Recover use_dynamic_ntk_rope.

pull/396/head
Pryest 2023-10-07 20:59:47 +08:00
parent 4a714966fc
commit 787e0e0940
1 changed file with 56 additions and 23 deletions
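For context: dynamic NTK RoPE rescales the rotary base on the fly once the sequence outgrows the training context, so positions keep falling inside the trained frequency range. That is also why the inference path below re-applies rotary embeddings from offset 0 over whole tensors rather than only rotating the newest token: keys rotated under an earlier base become stale when the base changes. A minimal sketch of the rescaling, using the common dynamic-NTK formula with scaling factor 1 (an illustrative helper, not InternLM's `rotary_emb` API):

import torch

def dynamic_ntk_inv_freq(dim: int, base: float, max_position_embeddings: int, seq_len: int) -> torch.Tensor:
    # Common dynamic-NTK rule (an assumption here, not taken from this repo):
    # once seq_len exceeds the training context, enlarge the RoPE base so the
    # rotary frequencies stretch to cover the longer sequence.
    if seq_len > max_position_embeddings:
        base = base * (seq_len / max_position_embeddings) ** (dim / (dim - 2))
    # Standard RoPE inverse frequencies for the (possibly rescaled) base.
    return 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))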


@@ -145,32 +145,36 @@ class MHA(nn.Module):
        else:
            qkv = rearrange(qkv, "(b s) (three h d) -> b s three h d", s=seqlen, three=3, d=self.head_dim)
        q, k, v = (x.squeeze(2) for x in qkv.chunk(chunks=3, dim=2))
        if inference_params is None:
            if self.rotary_emb_dim > 0:
                q = self.rotary_emb._single_eval_forward(q)
                k = self.rotary_emb._single_eval_forward(k)
            kv = torch.concat([k.unsqueeze(2), v.unsqueeze(2)], dim=2)
            context = self.inner_cross_attn(q, kv)
            kwargs["inference_params"] = inference_params
            qkv = self.rotary_emb(qkv, **kwargs)
            if gpc.config.model.dtype is torch.float32 and gpc.config.model.use_flash_attn:
                with torch.cuda.amp.autocast(dtype=torch.bfloat16):
                    if qkv.dtype not in [torch.float16, torch.bfloat16]:
                        qkv = qkv.to(torch.bfloat16)
                    context = self.inner_attn(qkv).to(x.dtype)
            else:
                context = self.inner_attn(qkv)
        else:
            assert self.rotary_emb_dim > 0
            assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
            q, k, v = (x.squeeze(2) for x in qkv.chunk(chunks=3, dim=2))
            assert self.rotary_emb_dim > 0, "You should use rotary_emb."
            if self.use_dynamic_ntk_rope:
                kv = torch.stack([k, v], dim=2)
                kv = _update_kv_cache(kv, inference_params, self.layer_idx)
                if hasattr(inference_params, "attention_mask") and inference_params.attention_mask is not None:
                    if inference_params.sequence_len_offset == 0:
                        empties = inference_params.attention_mask[..., -1].sum(dim=-1)
                        moved_q = q.clone()
                        moved_k = k.clone()
                    if inference_params.sequence_len_offset == 0:
                        for i in range(len(empties)):
                            if empties[i] != 0:
                                moved_q[i][: -empties[i]] = q[i][empties[i] :]
                                moved_k[i][: -empties[i]] = k[i][empties[i] :]
                        moved_q = self.rotary_emb._single_eval_forward(
                            moved_q, seqlen_offset=inference_params.sequence_len_offset
                        )
                        moved_k = self.rotary_emb._single_eval_forward(
                            moved_k, seqlen_offset=inference_params.sequence_len_offset
                        )
                        moved_q = self.rotary_emb._single_eval_forward(moved_q, seqlen_offset=0)
                        moved_k = self.rotary_emb._single_eval_forward(moved_k, seqlen_offset=0)
                        for i in range(len(empties)):
                            if empties[i] != 0:
                                q[i][empties[i] :] = moved_q[i][: -empties[i]]
@@ -178,7 +182,12 @@ class MHA(nn.Module):
                            else:
                                q[i] = moved_q[i]
                                k[i] = moved_k[i]
            else:
            elif not self.use_dynamic_ntk_rope:
                if inference_params.sequence_len_offset > self.max_position_embeddings:
                    warnings.warn(
                        "Notice your prompt's length is longer than model's max_position_embeddings: "
                        f"{self.max_position_embeddings}, which will cause deviations in dynamic ntk calculations."
                    )
                q = q.squeeze(1)
                k = k.squeeze(1)
                q = self.rotary_emb._single_forward(
@@ -191,13 +200,30 @@ class MHA(nn.Module):
                    inference_params.sequence_len_offset * torch.ones(k.size(0), dtype=torch.int, device=k.device)
                    - empties,
                ).unsqueeze(1)
                else:
                    q = q.squeeze(1)
                    q = self.rotary_emb._single_forward(
                        q,
                        inference_params.sequence_len_offset * torch.ones(q.size(0), dtype=torch.int, device=q.device)
                        - empties,
                    ).unsqueeze(1)
                empties = inference_params.attention_mask[..., -1].sum(dim=-1)
                moved_k = k.clone()
                for i in range(len(empties)):
                    if empties[i] != 0:
                        moved_k[i][: -empties[i]] = k[i][empties[i] :]
                moved_k = self.rotary_emb._single_eval_forward(moved_k, seqlen_offset=0)
                for i in range(len(empties)):
                    if empties[i] != 0:
                        k[i][empties[i] :] = moved_k[i][: -empties[i]]
                    else:
                        k[i] = moved_k[i]
                else:
                    q = self.rotary_emb._single_forward(q, inference_params.sequence_len_offset)
                    k = self.rotary_emb._single_forward(k, inference_params.sequence_len_offset)
            if not self.use_dynamic_ntk_rope:
                kv = torch.stack([k, v], dim=2)
                assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
                kv = _update_kv_cache(kv, inference_params, self.layer_idx)
            if hasattr(inference_params, "attention_mask") and inference_params.attention_mask is not None:
@@ -222,9 +248,16 @@ class MHA(nn.Module):
                    -1, kv.shape[-3], kv.shape[-2], kv.shape[-1]
                )
                if gpc.config.model.dtype is torch.float32 and gpc.config.model.use_flash_attn:
                    with torch.cuda.amp.autocast(dtype=torch.bfloat16):
                        if total_q.dtype not in [torch.float16, torch.bfloat16]:
                            total_q = total_q.to(torch.bfloat16)
                        if total_kv.dtype not in [torch.float16, torch.bfloat16]:
                            total_kv = total_kv.to(torch.bfloat16)
                        output = flash_attn_varlen_kvpacked_func(
                            total_q, total_kv, cu_seqlens, cu_seqlens, max_seqlen_q, max_seqlen_k, 0.0, None, True, False
                        )
                        ).to(x.dtype)
                context = torch.zeros_like(q)
                context = context.masked_scatter_(attn_mask4flsh.view(bsz, -1, 1, 1), output)
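
The moved_q/moved_k loops in the first and third hunks implement a left-padding shuffle: real tokens are slid to the front of each sequence so that RoPE position 0 lands on the first real token, the shifted tensor is rotated with seqlen_offset=0, and the rotated values are slid back under the padding. A standalone sketch of that shuffle (the shapes and the apply_rope callable are assumptions for illustration; in the diff the rotation is rotary_emb._single_eval_forward):

import torch

def rope_left_padded(x: torch.Tensor, pad_lens: torch.Tensor, apply_rope) -> torch.Tensor:
    # x: [batch, seqlen, heads, head_dim]; pad_lens: [batch] left-pad counts.
    moved = x.clone()
    for i in range(len(pad_lens)):
        p = int(pad_lens[i])
        if p != 0:
            # Compact: drop the pads so the first real token sits at index 0.
            moved[i][:-p] = x[i][p:]
    moved = apply_rope(moved)  # rotate from position 0, as the diff does with seqlen_offset=0
    out = x.clone()
    for i in range(len(pad_lens)):
        p = int(pad_lens[i])
        if p != 0:
            # Un-compact: put the rotated tokens back behind the padding.
            out[i][p:] = moved[i][:-p]
        else:
            out[i] = moved[i]
    return out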
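The last hunk guards the varlen FlashAttention call for fp32 models: FlashAttention kernels only accept fp16/bf16 inputs, so the packed buffers are down-cast inside a bf16 autocast region and the result is cast back to the activation dtype (the trailing .to(x.dtype)). The guard in isolation (flash_attn_varlen_kvpacked_func is the flash-attn package API echoed from the diff; the surrounding names are assumed from its context):

import torch

def to_flash_dtype(t: torch.Tensor) -> torch.Tensor:
    # Leave half-precision tensors untouched; down-cast anything else (e.g. fp32) to bf16.
    return t if t.dtype in (torch.float16, torch.bfloat16) else t.to(torch.bfloat16)

# Usage mirroring the diff (total_q, total_kv, cu_seqlens, etc. come from the elided setup):
# with torch.cuda.amp.autocast(dtype=torch.bfloat16):
#     output = flash_attn_varlen_kvpacked_func(
#         to_flash_dtype(total_q), to_flash_dtype(total_kv),
#         cu_seqlens, cu_seqlens, max_seqlen_q, max_seqlen_k, 0.0, None, True, False,
#     ).to(x.dtype)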