mirror of https://github.com/InternLM/InternLM
feat(model): add DynamicNTKScalingRotaryEmbedding (#339)
* add dynamic ntk rope
* update dynamic ntk rope
* fix lint check
* fix lint check
* add more desc

Co-authored-by: YWMditto <862779238@qq.com>
parent 67eda4cbe1
commit 8464425a7b
@@ -137,6 +137,8 @@ class RotaryEmbedding(torch.nn.Module):
        """ """
        super().__init__()
        # Generate and save the inverse frequency buffer (non trainable)
        self.dim = dim
        self.base = base
        self.inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim))
        self.scale_base = scale_base
        self.scale = (

@@ -230,3 +232,58 @@ class RotaryEmbedding(torch.nn.Module):
        assert self.scale is None
        self._update_cos_sin_cache(x, seqlen_offset + x.shape[1])
        return legacy_apply_rotary_embed(x, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:])


class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding):
    """RotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla.

    Reference implementation:
        https://github.com/huggingface/transformers/blob/eb8489971ac1415f67b0abdd1584fde8 \
            b659ced9/src/transformers/models/llama/modeling_llama.py#L147
    """

    def __init__(self, dim: int, base=10000, scale_base=0, device=None, max_position_embeddings=2048, scaling_factor=1.0):
        super().__init__(dim=dim, base=base, scale_base=scale_base, device=device)
        self.max_position_embeddings = max_position_embeddings
        self.scaling_factor = scaling_factor

    def _update(self, seqlen, x):
        self._seq_len_cached = seqlen
        if seqlen > self.max_position_embeddings:
            base = self.base * (
                (self.scaling_factor * seqlen / self.max_position_embeddings) - (self.scaling_factor - 1)
            ) ** (self.dim / (self.dim - 2))
            inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(x.device) / self.dim))
        else:
            inv_freq = self.inv_freq

        t = torch.arange(seqlen, device=x.device, dtype=inv_freq.dtype)
        freqs = torch.outer(t, inv_freq.to(device=t.device))
        if self.scale is None:
            self._cos_cached = torch.cos(freqs).to(x.dtype)
            self._sin_cached = torch.sin(freqs).to(x.dtype)
        else:
            power = (
                torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device) - seqlen // 2
            ) / self.scale_base
            scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
            # We want the multiplication by scale to happen in fp32
            self._cos_cached = (torch.cos(freqs) * scale).to(x.dtype)
            self._sin_cached = (torch.sin(freqs) * scale).to(x.dtype)
            self._cos_k_cached = (torch.cos(freqs) / scale).to(x.dtype)
            self._sin_k_cached = (torch.sin(freqs) / scale).to(x.dtype)

    def _update_cos_sin_cache(self, x, indexes):
        """x: (batch, seqlen, nheads, headdim) or (batch, seqlen, 3, nheads, headdim)"""
        if not isinstance(indexes, int):
            seqlen = indexes.max().item() + 1
        else:
            seqlen = indexes + 1  # eval_forward
        if seqlen <= self.max_position_embeddings:
            # Reset the tables if the sequence length has changed,
            # or if we're on a new device (possibly due to tracing for instance)
            if self._seq_len_cached > self.max_position_embeddings or seqlen > self._seq_len_cached \
                    or self._cos_cached.device != x.device or self._cos_cached.dtype != x.dtype:
                self._update(seqlen, x)
        else:
            self._update(seqlen, x)

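As a quick illustration of the rescaling in `_update` above: once the requested sequence length exceeds `max_position_embeddings`, the effective rotary base grows with the sequence length. A minimal sketch, not part of the diff, with assumed values dim=128, base=10000, max_position_embeddings=2048 and the fixed scaling_factor=1.0:

import torch

# Stand-alone version of the base rescaling in DynamicNTKScalingRotaryEmbedding._update,
# with assumed values: dim=128, base=10000, max_position_embeddings=2048, scaling_factor=1.0.
dim, base, max_pos, scaling_factor = 128, 10000.0, 2048, 1.0
seqlen = 4096  # twice the trained context length

# With scaling_factor=1.0 this reduces to base * (seqlen / max_pos) ** (dim / (dim - 2)).
scaled_base = base * (
    (scaling_factor * seqlen / max_pos) - (scaling_factor - 1)
) ** (dim / (dim - 2))
inv_freq = 1.0 / (scaled_base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))

print(round(scaled_base))  # 20221: the base roughly doubles for a 2x longer sequence
print(inv_freq.shape)      # torch.Size([64]); lower frequencies than with the unscaled base
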
@@ -59,10 +59,12 @@ class PackedFlashBaseLayer1D(nn.Module):
        mlp_ratio: int = 4,
        attn_drop_rate: float = 0,
        drop_rate: float = 0.0,
        max_position_embeddings: int = 2048,
        dtype: torch.dtype = torch.float,
        layer_norm_epsilon: float = 1e-6,
        checkpoint: bool = False,
        layer_idx: int = 0,
        use_dynamic_ntk_rope: bool = False,
        residual_in_fp32: bool = False,
        device: Optional[torch.device] = None,
        norm_type: str = "rmsnorm",

@@ -84,9 +86,11 @@ class PackedFlashBaseLayer1D(nn.Module):
            num_heads=num_attention_heads,
            process_group=gpc.get_group(ParallelMode.TENSOR),
            dropout=attn_drop_rate,
            max_position_embeddings=max_position_embeddings,
            softmax_scale=1 / math.sqrt(head_dim),
            causal=True,
            layer_idx=layer_idx,
            use_dynamic_ntk_rope=use_dynamic_ntk_rope,
            rotary_emb_dim=head_dim,
            rotary_emb_scale_base=0,
            use_flash_attn=use_flash_attn,

@@ -262,6 +266,7 @@ class PackedFlashInternLm1D(nn.Module):
        mlp_ratio: int = 4.0,
        attn_drop_rate: float = 0.0,
        drop_rate: float = 0.0,
        max_position_embeddings: int = 2048,
        dtype: torch.dtype = torch.float,
        checkpoint: float = 0.0,
        layer_norm_epsilon: float = 1e-5,

@@ -271,6 +276,7 @@ class PackedFlashInternLm1D(nn.Module):
        embed_grad_scale: float = 0.1,
        parallel_output: bool = True,
        start_layer_idx: int = 0,
        use_dynamic_ntk_rope: bool = False,
        device: Optional[torch.device] = None,
        residual_in_fp32: bool = False,
        norm_type: str = "rmsnorm",

@@ -315,10 +321,12 @@ class PackedFlashInternLm1D(nn.Module):
                mlp_ratio=mlp_ratio,
                attn_drop_rate=attn_drop_rate,
                drop_rate=drop_rate,
                max_position_embeddings=max_position_embeddings,
                dtype=dtype,
                layer_norm_epsilon=layer_norm_epsilon,
                checkpoint=lid < checkpoint_layer_num,
                layer_idx=lid + start_layer_idx,  # This parameter is used for caching during generation
                use_dynamic_ntk_rope=use_dynamic_ntk_rope,
                residual_in_fp32=residual_in_fp32,
                device=device,
                norm_type=norm_type,

@@ -443,8 +451,10 @@ def build_model_with_cfg(
    embed_grad_scale=1,
    parallel_output=True,
    num_attention_heads=32,
    max_position_embeddings=2048,
    mlp_ratio=4.0,
    residual_in_fp32=False,
    use_dynamic_ntk_rope=False,
    norm_type="rmsnorm",
    drop_rate=0,
    attn_drop_rate=0,

@@ -499,6 +509,8 @@ def build_model_with_cfg(
        parallel_output=parallel_output,
        mlp_ratio=mlp_ratio,
        residual_in_fp32=residual_in_fp32,
        max_position_embeddings=max_position_embeddings,
        use_dynamic_ntk_rope=use_dynamic_ntk_rope,
        norm_type=norm_type,
        drop_rate=drop_rate,
        attn_drop_rate=attn_drop_rate,

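The two new arguments threaded through build_model_with_cfg above would be driven from the model section of the config (the same `gpc.config.model` that the attention code reads elsewhere in this diff). A hedged sketch; every key shown comes from the build_model_with_cfg signature above, but the surrounding `model = dict(...)` pattern and the chosen values are illustrative only:

# Hypothetical config snippet (not part of this diff).
model = dict(
    num_attention_heads=32,
    mlp_ratio=4.0,
    norm_type="rmsnorm",
    drop_rate=0,
    attn_drop_rate=0,
    max_position_embeddings=2048,   # the context length the model was trained with
    use_dynamic_ntk_rope=True,      # switch MHA over to DynamicNTKScalingRotaryEmbedding
)
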
@@ -1,6 +1,7 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import warnings
from typing import Optional

import torch

@@ -16,7 +17,7 @@ from torch import nn

from internlm.core.context import IS_TENSOR_PARALLEL, ParallelMode
from internlm.core.context import global_context as gpc
from internlm.model.embedding import RotaryEmbedding
from internlm.model.embedding import DynamicNTKScalingRotaryEmbedding, RotaryEmbedding
from internlm.model.linear import ColumnParallelLinearTorch, RowParallelLinearTorch

@@ -52,10 +53,12 @@ class MHA(nn.Module):
        embed_dim: int,
        num_heads: int,
        process_group: Optional[torch.distributed.ProcessGroup],
        max_position_embeddings: int = 2048,
        dropout: float = 0.0,
        softmax_scale: float = None,
        causal: bool = False,
        layer_idx: int = None,
        use_dynamic_ntk_rope: bool = False,
        rotary_emb_dim: int = 0,
        rotary_emb_scale_base: int = 0,
        use_flash_attn: bool = True,

@@ -67,6 +70,8 @@ class MHA(nn.Module):
        self.embed_dim = embed_dim
        self.causal = causal
        self.layer_idx = layer_idx
        self.max_position_embeddings = max_position_embeddings
        self.use_dynamic_ntk_rope = use_dynamic_ntk_rope
        self.rotary_emb_dim = rotary_emb_dim
        self.use_flash_attn = use_flash_attn
        self.num_heads = num_heads

@@ -74,7 +79,16 @@ class MHA(nn.Module):
        self.head_dim = self.embed_dim // num_heads

        if self.rotary_emb_dim > 0:
            self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, scale_base=rotary_emb_scale_base, device=device)
            if self.use_dynamic_ntk_rope:
                self.rotary_emb = DynamicNTKScalingRotaryEmbedding(
                    self.rotary_emb_dim,
                    scale_base=rotary_emb_scale_base,
                    device=device,
                    max_position_embeddings=max_position_embeddings,
                    scaling_factor=1.0,  # Currently do not support dynamic scaling.
                )
            else:
                self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, scale_base=rotary_emb_scale_base, device=device)

        # notice here should change bias=True
        self.Wqkv = ColumnParallelLinearTorch(

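Constructing the new embedding outside of MHA mirrors the branch above. A small usage sketch assuming the constructor defaults shown in the diff (head_dim here is a placeholder value, not taken from the diff):

from internlm.model.embedding import DynamicNTKScalingRotaryEmbedding

head_dim = 128  # placeholder; in MHA this is embed_dim // num_heads
rotary_emb = DynamicNTKScalingRotaryEmbedding(
    head_dim,
    scale_base=0,
    device=None,
    max_position_embeddings=2048,
    scaling_factor=1.0,  # fixed, as in the MHA constructor above
)
print(rotary_emb.inv_freq.shape)  # torch.Size([64]): one frequency per pair of rotary dimensions
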
@@ -127,11 +141,10 @@ class MHA(nn.Module):
        else:
            qkv = rearrange(qkv, "(b s) (three h d) -> b s three h d", s=seqlen, three=3, d=self.head_dim)

        if self.rotary_emb_dim > 0:
            kwargs["inference_params"] = inference_params
            qkv = self.rotary_emb(qkv, **kwargs)

        if inference_params is None:
            if self.rotary_emb_dim > 0:
                kwargs["inference_params"] = inference_params
                qkv = self.rotary_emb(qkv, **kwargs)
            if gpc.config.model.dtype is torch.float32 and gpc.config.model.use_flash_attn:
                with torch.cuda.amp.autocast(dtype=torch.bfloat16):
                    if qkv.dtype not in [torch.float16, torch.bfloat16]:

@@ -140,9 +153,39 @@ class MHA(nn.Module):
            else:
                context = self.inner_attn(qkv)
        else:
            q = qkv[:, :, 0]
            assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
            kv = _update_kv_cache(qkv[:, :, 1:], inference_params, self.layer_idx)
            if self.use_dynamic_ntk_rope:
                q = qkv[:, :, 0]
                assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
                kv = _update_kv_cache(qkv[:, :, 1:], inference_params, self.layer_idx)
                if inference_params.sequence_len_offset != 0:
                    # q shape: [bsz, 1, nheads, head_dim]
                    # kv shape: [bsz, seqlen, 2, nheads, head_dim]
                    bsz, seq_len, _, nheads, head_dim = kv.shape
                    q = torch.cat([q.new_zeros(size=(bsz, seq_len - 1, nheads, head_dim)), q], dim=1).unsqueeze(2)
                    qkv = torch.cat([q, kv], dim=2)
                    if self.rotary_emb_dim > 0:
                        qkv = self.rotary_emb(qkv)
                    q = qkv[:, [-1], 0]
                    kv = qkv[:, :, 1:]
                else:
                    if inference_params.sequence_len_offset > self.max_position_embeddings:
                        warnings.warn(
                            "Notice your prompt's length is longer than model's max_position_embeddings: "
                            f"{self.max_position_embeddings}, which will cause deviations in dynamic ntk calculations."
                        )
                    if self.rotary_emb_dim > 0:
                        kwargs["inference_params"] = inference_params
                        qkv = self.rotary_emb(qkv, **kwargs)
                    q = qkv[:, :, 0]
                    kv = qkv[:, :, 1:]
            else:
                if self.rotary_emb_dim > 0:
                    kwargs["inference_params"] = inference_params
                    qkv = self.rotary_emb(qkv, **kwargs)
                q = qkv[:, :, 0]
                assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
                kv = _update_kv_cache(qkv[:, :, 1:], inference_params, self.layer_idx)

            # If we're processing the prompt, causal=None (use self.causal).
            # If we're decoding, then causal=False.
            causal = None if inference_params.sequence_len_offset == 0 else False

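The zero-padding step in the dynamic-NTK decode branch above is easiest to follow with concrete shapes. A minimal shape-check sketch, not part of the diff; the batch, head, and cache sizes are illustrative only:

import torch

# 16 cached tokens plus 1 newly generated token; sizes are illustrative.
bsz, seq_len, nheads, head_dim = 2, 17, 4, 8
q = torch.randn(bsz, 1, nheads, head_dim)             # query of the new token only
kv = torch.randn(bsz, seq_len, 2, nheads, head_dim)   # full key/value cache after _update_kv_cache

# Left-pad q with zeros so it shares the time axis with kv, then stack into a qkv tensor
# so the rotary embedding can be re-applied over the whole sequence with the rescaled base.
q = torch.cat([q.new_zeros(size=(bsz, seq_len - 1, nheads, head_dim)), q], dim=1).unsqueeze(2)
qkv = torch.cat([q, kv], dim=2)
print(qkv.shape)              # torch.Size([2, 17, 3, 4, 8])

# Only the last position's query and the full key/value slice are kept afterwards.
print(qkv[:, [-1], 0].shape)  # torch.Size([2, 1, 4, 8])
print(qkv[:, :, 1:].shape)    # torch.Size([2, 17, 2, 4, 8])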