feat(model): add DynamicNTKScalingRotaryEmbedding (#339)

* add dynamic ntk rope

* update dynamic ntk rope

* fix lint check

* fix lint check

* add more desc

---------

Co-authored-by: YWMditto <862779238@qq.com>
YWMditto 2023-09-20 23:31:47 +08:00 committed by GitHub
parent 67eda4cbe1
commit 8464425a7b
3 changed files with 121 additions and 9 deletions

View File

@@ -137,6 +137,8 @@ class RotaryEmbedding(torch.nn.Module):
        """ """
        super().__init__()
        # Generate and save the inverse frequency buffer (non trainable)
        self.dim = dim
        self.base = base
        self.inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim))
        self.scale_base = scale_base
        self.scale = (
@@ -230,3 +232,58 @@ class RotaryEmbedding(torch.nn.Module):
        assert self.scale is None
        self._update_cos_sin_cache(x, seqlen_offset + x.shape[1])
        return legacy_apply_rotary_embed(x, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:])


class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding):
    """RotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla.

    Reference implementation:
        https://github.com/huggingface/transformers/blob/eb8489971ac1415f67b0abdd1584fde8 \
        b659ced9/src/transformers/models/llama/modeling_llama.py#L147
    """

    def __init__(self, dim: int, base=10000, scale_base=0, device=None, max_position_embeddings=2048, scaling_factor=1.0):
        super().__init__(dim=dim, base=base, scale_base=scale_base, device=device)
        self.max_position_embeddings = max_position_embeddings
        self.scaling_factor = scaling_factor

    def _update(self, seqlen, x):
        self._seq_len_cached = seqlen
        if seqlen > self.max_position_embeddings:
            base = self.base * (
                (self.scaling_factor * seqlen / self.max_position_embeddings) - (self.scaling_factor - 1)
            ) ** (self.dim / (self.dim - 2))
            inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(x.device) / self.dim))
        else:
            inv_freq = self.inv_freq

        t = torch.arange(seqlen, device=x.device, dtype=inv_freq.dtype)
        freqs = torch.outer(t, inv_freq.to(device=t.device))
        if self.scale is None:
            self._cos_cached = torch.cos(freqs).to(x.dtype)
            self._sin_cached = torch.sin(freqs).to(x.dtype)
        else:
            power = (
                torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device) - seqlen // 2
            ) / self.scale_base
            scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
            # We want the multiplication by scale to happen in fp32
            self._cos_cached = (torch.cos(freqs) * scale).to(x.dtype)
            self._sin_cached = (torch.sin(freqs) * scale).to(x.dtype)
            self._cos_k_cached = (torch.cos(freqs) / scale).to(x.dtype)
            self._sin_k_cached = (torch.sin(freqs) / scale).to(x.dtype)

    def _update_cos_sin_cache(self, x, indexes):
        """x: (batch, seqlen, nheads, headdim) or (batch, seqlen, 3, nheads, headdim)"""
        if not isinstance(indexes, int):
            seqlen = indexes.max().item() + 1
        else:
            seqlen = indexes + 1  # eval_forward
        if seqlen <= self.max_position_embeddings:
            # Reset the tables if the sequence length has changed,
            # or if we're on a new device (possibly due to tracing for instance)
            if self._seq_len_cached > self.max_position_embeddings or seqlen > self._seq_len_cached \
                    or self._cos_cached.device != x.device or self._cos_cached.dtype != x.dtype:
                self._update(seqlen, x)
        else:
            self._update(seqlen, x)
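
For intuition, here is a minimal standalone sketch of the base rescaling that _update applies once the sequence length exceeds max_position_embeddings (names mirror the method above; the sample numbers are illustrative only, not taken from the diff):

import torch

def ntk_scaled_inv_freq(dim, base, seqlen, max_position_embeddings, scaling_factor=1.0):
    # Same formula as DynamicNTKScalingRotaryEmbedding._update: grow the rotary base
    # with the actual sequence length so earlier positions keep their resolution.
    if seqlen > max_position_embeddings:
        base = base * (
            (scaling_factor * seqlen / max_position_embeddings) - (scaling_factor - 1)
        ) ** (dim / (dim - 2))
    return 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))

# Illustrative numbers: head_dim=128, a 4096-token prompt against a 2048-token training window.
inv_freq = ntk_scaled_inv_freq(dim=128, base=10000, seqlen=4096, max_position_embeddings=2048)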

View File

@@ -59,10 +59,12 @@ class PackedFlashBaseLayer1D(nn.Module):
        mlp_ratio: int = 4,
        attn_drop_rate: float = 0,
        drop_rate: float = 0.0,
        max_position_embeddings: int = 2048,
        dtype: torch.dtype = torch.float,
        layer_norm_epsilon: float = 1e-6,
        checkpoint: bool = False,
        layer_idx: int = 0,
        use_dynamic_ntk_rope: bool = False,
        residual_in_fp32: bool = False,
        device: Optional[torch.device] = None,
        norm_type: str = "rmsnorm",
@@ -84,9 +86,11 @@ class PackedFlashBaseLayer1D(nn.Module):
            num_heads=num_attention_heads,
            process_group=gpc.get_group(ParallelMode.TENSOR),
            dropout=attn_drop_rate,
            max_position_embeddings=max_position_embeddings,
            softmax_scale=1 / math.sqrt(head_dim),
            causal=True,
            layer_idx=layer_idx,
            use_dynamic_ntk_rope=use_dynamic_ntk_rope,
            rotary_emb_dim=head_dim,
            rotary_emb_scale_base=0,
            use_flash_attn=use_flash_attn,
@@ -262,6 +266,7 @@ class PackedFlashInternLm1D(nn.Module):
        mlp_ratio: int = 4.0,
        attn_drop_rate: float = 0.0,
        drop_rate: float = 0.0,
        max_position_embeddings: int = 2048,
        dtype: torch.dtype = torch.float,
        checkpoint: float = 0.0,
        layer_norm_epsilon: float = 1e-5,
@@ -271,6 +276,7 @@ class PackedFlashInternLm1D(nn.Module):
        embed_grad_scale: float = 0.1,
        parallel_output: bool = True,
        start_layer_idx: int = 0,
        use_dynamic_ntk_rope: bool = False,
        device: Optional[torch.device] = None,
        residual_in_fp32: bool = False,
        norm_type: str = "rmsnorm",
@@ -315,10 +321,12 @@ class PackedFlashInternLm1D(nn.Module):
                    mlp_ratio=mlp_ratio,
                    attn_drop_rate=attn_drop_rate,
                    drop_rate=drop_rate,
                    max_position_embeddings=max_position_embeddings,
                    dtype=dtype,
                    layer_norm_epsilon=layer_norm_epsilon,
                    checkpoint=lid < checkpoint_layer_num,
                    layer_idx=lid + start_layer_idx,  # This parameter is used for caching during generation
                    use_dynamic_ntk_rope=use_dynamic_ntk_rope,
                    residual_in_fp32=residual_in_fp32,
                    device=device,
                    norm_type=norm_type,
@@ -443,8 +451,10 @@ def build_model_with_cfg(
    embed_grad_scale=1,
    parallel_output=True,
    num_attention_heads=32,
    max_position_embeddings=2048,
    mlp_ratio=4.0,
    residual_in_fp32=False,
    use_dynamic_ntk_rope=False,
    norm_type="rmsnorm",
    drop_rate=0,
    attn_drop_rate=0,
@@ -499,6 +509,8 @@ def build_model_with_cfg(
        parallel_output=parallel_output,
        mlp_ratio=mlp_ratio,
        residual_in_fp32=residual_in_fp32,
        max_position_embeddings=max_position_embeddings,
        use_dynamic_ntk_rope=use_dynamic_ntk_rope,
        norm_type=norm_type,
        drop_rate=drop_rate,
        attn_drop_rate=attn_drop_rate,
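
As a usage note, the two new keys are plain keyword arguments of build_model_with_cfg, so they can be set from the model section of a training config. A hedged sketch (the surrounding keys and their values are illustrative, not taken from this diff):

model = dict(
    num_attention_heads=32,
    mlp_ratio=4.0,
    norm_type="rmsnorm",
    max_position_embeddings=2048,  # training context window; longer prompts trigger NTK rescaling
    use_dynamic_ntk_rope=True,     # select DynamicNTKScalingRotaryEmbedding inside MHA
)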

View File

@@ -1,6 +1,7 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import warnings
from typing import Optional

import torch
@@ -16,7 +17,7 @@ from torch import nn
from internlm.core.context import IS_TENSOR_PARALLEL, ParallelMode
from internlm.core.context import global_context as gpc
from internlm.model.embedding import DynamicNTKScalingRotaryEmbedding, RotaryEmbedding
from internlm.model.linear import ColumnParallelLinearTorch, RowParallelLinearTorch
@@ -52,10 +53,12 @@ class MHA(nn.Module):
        embed_dim: int,
        num_heads: int,
        process_group: Optional[torch.distributed.ProcessGroup],
        max_position_embeddings: int = 2048,
        dropout: float = 0.0,
        softmax_scale: float = None,
        causal: bool = False,
        layer_idx: int = None,
        use_dynamic_ntk_rope: bool = False,
        rotary_emb_dim: int = 0,
        rotary_emb_scale_base: int = 0,
        use_flash_attn: bool = True,
@@ -67,6 +70,8 @@ class MHA(nn.Module):
        self.embed_dim = embed_dim
        self.causal = causal
        self.layer_idx = layer_idx
        self.max_position_embeddings = max_position_embeddings
        self.use_dynamic_ntk_rope = use_dynamic_ntk_rope
        self.rotary_emb_dim = rotary_emb_dim
        self.use_flash_attn = use_flash_attn
        self.num_heads = num_heads
@@ -74,7 +79,16 @@ class MHA(nn.Module):
        self.head_dim = self.embed_dim // num_heads

        if self.rotary_emb_dim > 0:
            if self.use_dynamic_ntk_rope:
                self.rotary_emb = DynamicNTKScalingRotaryEmbedding(
                    self.rotary_emb_dim,
                    scale_base=rotary_emb_scale_base,
                    device=device,
                    max_position_embeddings=max_position_embeddings,
                    scaling_factor=1.0,  # Currently do not support dynamic scaling.
                )
            else:
                self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, scale_base=rotary_emb_scale_base, device=device)

        # notice here should change bias=True
        self.Wqkv = ColumnParallelLinearTorch(
@@ -127,11 +141,10 @@ class MHA(nn.Module):
        else:
            qkv = rearrange(qkv, "(b s) (three h d) -> b s three h d", s=seqlen, three=3, d=self.head_dim)

        if inference_params is None:
            if self.rotary_emb_dim > 0:
                kwargs["inference_params"] = inference_params
                qkv = self.rotary_emb(qkv, **kwargs)
            if gpc.config.model.dtype is torch.float32 and gpc.config.model.use_flash_attn:
                with torch.cuda.amp.autocast(dtype=torch.bfloat16):
                    if qkv.dtype not in [torch.float16, torch.bfloat16]:
@@ -140,9 +153,39 @@ class MHA(nn.Module):
            else:
                context = self.inner_attn(qkv)
        else:
            if self.use_dynamic_ntk_rope:
                q = qkv[:, :, 0]
                assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
                kv = _update_kv_cache(qkv[:, :, 1:], inference_params, self.layer_idx)
                if inference_params.sequence_len_offset != 0:
                    # q shape: [bsz, 1, nheads, head_dim]
                    # kv shape: [bsz, seqlen, 2, nheads, head_dim]
                    bsz, seq_len, _, nheads, head_dim = kv.shape
                    q = torch.cat([q.new_zeros(size=(bsz, seq_len - 1, nheads, head_dim)), q], dim=1).unsqueeze(2)
                    qkv = torch.cat([q, kv], dim=2)
                    if self.rotary_emb_dim > 0:
                        qkv = self.rotary_emb(qkv)
                    q = qkv[:, [-1], 0]
                    kv = qkv[:, :, 1:]
                else:
                    if inference_params.sequence_len_offset > self.max_position_embeddings:
                        warnings.warn(
                            "Notice your prompt's length is longer than model's max_position_embeddings: "
                            f"{self.max_position_embeddings}, which will cause deviations in dynamic ntk calculations."
                        )
                    if self.rotary_emb_dim > 0:
                        kwargs["inference_params"] = inference_params
                        qkv = self.rotary_emb(qkv, **kwargs)
                    q = qkv[:, :, 0]
                    kv = qkv[:, :, 1:]
            else:
                if self.rotary_emb_dim > 0:
                    kwargs["inference_params"] = inference_params
                    qkv = self.rotary_emb(qkv, **kwargs)
                q = qkv[:, :, 0]
                assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
                kv = _update_kv_cache(qkv[:, :, 1:], inference_params, self.layer_idx)

            # If we're processing the prompt, causal=None (use self.causal).
            # If we're decoding, then causal=False.
            causal = None if inference_params.sequence_len_offset == 0 else False
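
The shape handling in the dynamic-NTK decode branch above (zero-padding q to the cached length so the rotary embedding sees every position, then slicing the current step back out) can be hard to follow; here is a minimal sketch with made-up shapes, independent of the MHA class:

import torch

bsz, seqlen, nheads, head_dim = 2, 10, 4, 16        # illustrative shapes only
q = torch.randn(bsz, 1, nheads, head_dim)           # query for the current decode step
kv = torch.randn(bsz, seqlen, 2, nheads, head_dim)  # cached keys/values incl. current step

# Pad q with zeros so q, k and v all cover positions 0..seqlen-1, as in the branch above.
q_padded = torch.cat([q.new_zeros(bsz, seqlen - 1, nheads, head_dim), q], dim=1).unsqueeze(2)
qkv = torch.cat([q_padded, kv], dim=2)              # (bsz, seqlen, 3, nheads, head_dim)

# After rotary_emb(qkv), only the last query position and the full kv are kept:
q_out = qkv[:, [-1], 0]                             # (bsz, 1, nheads, head_dim)
kv_out = qkv[:, :, 1:]                              # (bsz, seqlen, 2, nheads, head_dim)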