diff --git a/internlm/model/embedding.py b/internlm/model/embedding.py
index d4ae9b5..01a9d56 100644
--- a/internlm/model/embedding.py
+++ b/internlm/model/embedding.py
@@ -137,6 +137,8 @@ class RotaryEmbedding(torch.nn.Module):
         """ """
         super().__init__()
         # Generate and save the inverse frequency buffer (non trainable)
+        self.dim = dim
+        self.base = base
         self.inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim))
         self.scale_base = scale_base
         self.scale = (
@@ -230,3 +232,58 @@ class RotaryEmbedding(torch.nn.Module):
         assert self.scale is None
         self._update_cos_sin_cache(x, seqlen_offset + x.shape[1])
         return legacy_apply_rotary_embed(x, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:])
+
+
+class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding):
+    """RotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla.
+
+    Reference implementation:
+    https://github.com/huggingface/transformers/blob/eb8489971ac1415f67b0abdd1584fde8 \
+    b659ced9/src/transformers/models/llama/modeling_llama.py#L147
+    """
+    def __init__(self, dim: int, base=10000, scale_base=0, device=None, max_position_embeddings=2048, scaling_factor=1.0):
+        super().__init__(dim=dim, base=base, scale_base=scale_base, device=device)
+        self.max_position_embeddings = max_position_embeddings
+        self.scaling_factor = scaling_factor
+
+    def _update(self, seqlen, x):
+        self._seq_len_cached = seqlen
+        if seqlen > self.max_position_embeddings:
+            base = self.base * (
+                (self.scaling_factor * seqlen / self.max_position_embeddings) - (self.scaling_factor - 1)
+            ) ** (self.dim / (self.dim - 2))
+            inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(x.device) / self.dim))
+        else:
+            inv_freq = self.inv_freq
+
+        t = torch.arange(seqlen, device=x.device, dtype=inv_freq.dtype)
+        freqs = torch.outer(t, inv_freq.to(device=t.device))
+        if self.scale is None:
+            self._cos_cached = torch.cos(freqs).to(x.dtype)
+            self._sin_cached = torch.sin(freqs).to(x.dtype)
+        else:
+            power = (
+                torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device) - seqlen // 2
+            ) / self.scale_base
+            scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
+            # We want the multiplication by scale to happen in fp32
+            self._cos_cached = (torch.cos(freqs) * scale).to(x.dtype)
+            self._sin_cached = (torch.sin(freqs) * scale).to(x.dtype)
+            self._cos_k_cached = (torch.cos(freqs) / scale).to(x.dtype)
+            self._sin_k_cached = (torch.sin(freqs) / scale).to(x.dtype)
+
+    def _update_cos_sin_cache(self, x, indexes):
+        """x: (batch, seqlen, nheads, headdim) or (batch, seqlen, 3, nheads, headdim)"""
+        if not isinstance(indexes, int):
+            seqlen = indexes.max().item() + 1
+        else:
+            seqlen = indexes + 1  # eval_forward
+        if seqlen <= self.max_position_embeddings:
+            # Reset the tables if the sequence length has changed,
+            # or if we're on a new device (possibly due to tracing for instance)
+            if self._seq_len_cached > self.max_position_embeddings or seqlen > self._seq_len_cached \
+                    or self._cos_cached.device != x.device or self._cos_cached.dtype != x.dtype:
+                self._update(seqlen, x)
+        else:
+            self._update(seqlen, x)
+
\ No newline at end of file
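For reference, the following standalone sketch (not part of the patch; dim and base below are illustrative defaults, not values taken from a config) reproduces the NTK-aware rescaling that DynamicNTKScalingRotaryEmbedding._update applies: once seqlen exceeds max_position_embeddings, the rotary base is inflated by ((scaling_factor * seqlen / max_position_embeddings) - (scaling_factor - 1)) ** (dim / (dim - 2)), which stretches the low frequencies so positions beyond the pretraining window remain distinguishable.

import torch

# Illustrative values only; dim would normally be the per-head dimension.
dim = 128
base = 10000.0
max_position_embeddings = 2048
scaling_factor = 1.0

for seqlen in (2048, 4096, 8192):
    if seqlen > max_position_embeddings:
        # Same base-rescaling formula as DynamicNTKScalingRotaryEmbedding._update above.
        scaled_base = base * (
            (scaling_factor * seqlen / max_position_embeddings) - (scaling_factor - 1)
        ) ** (dim / (dim - 2))
    else:
        scaled_base = base
    inv_freq = 1.0 / (scaled_base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    print(f"seqlen={seqlen}  base={scaled_base:.1f}  lowest inv_freq={inv_freq[-1].item():.3e}")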
diff --git a/internlm/model/modeling_internlm.py b/internlm/model/modeling_internlm.py
index 651a629..2856a78 100644
--- a/internlm/model/modeling_internlm.py
+++ b/internlm/model/modeling_internlm.py
@@ -59,10 +59,12 @@ class PackedFlashBaseLayer1D(nn.Module):
         mlp_ratio: int = 4,
         attn_drop_rate: float = 0,
         drop_rate: float = 0.0,
+        max_position_embeddings: int = 2048,
         dtype: torch.dtype = torch.float,
         layer_norm_epsilon: float = 1e-6,
         checkpoint: bool = False,
         layer_idx: int = 0,
+        use_dynamic_ntk_rope: bool = False,
         residual_in_fp32: bool = False,
         device: Optional[torch.device] = None,
         norm_type: str = "rmsnorm",
@@ -84,9 +86,11 @@ class PackedFlashBaseLayer1D(nn.Module):
             num_heads=num_attention_heads,
             process_group=gpc.get_group(ParallelMode.TENSOR),
             dropout=attn_drop_rate,
+            max_position_embeddings=max_position_embeddings,
             softmax_scale=1 / math.sqrt(head_dim),
             causal=True,
             layer_idx=layer_idx,
+            use_dynamic_ntk_rope=use_dynamic_ntk_rope,
             rotary_emb_dim=head_dim,
             rotary_emb_scale_base=0,
             use_flash_attn=use_flash_attn,
@@ -262,6 +266,7 @@ class PackedFlashInternLm1D(nn.Module):
         mlp_ratio: int = 4.0,
         attn_drop_rate: float = 0.0,
         drop_rate: float = 0.0,
+        max_position_embeddings: int = 2048,
         dtype: torch.dtype = torch.float,
         checkpoint: float = 0.0,
         layer_norm_epsilon: float = 1e-5,
@@ -271,6 +276,7 @@ class PackedFlashInternLm1D(nn.Module):
         embed_grad_scale: float = 0.1,
         parallel_output: bool = True,
         start_layer_idx: int = 0,
+        use_dynamic_ntk_rope: bool = False,
         device: Optional[torch.device] = None,
         residual_in_fp32: bool = False,
         norm_type: str = "rmsnorm",
@@ -315,10 +321,12 @@ class PackedFlashInternLm1D(nn.Module):
                     mlp_ratio=mlp_ratio,
                     attn_drop_rate=attn_drop_rate,
                     drop_rate=drop_rate,
+                    max_position_embeddings=max_position_embeddings,
                     dtype=dtype,
                     layer_norm_epsilon=layer_norm_epsilon,
                     checkpoint=lid < checkpoint_layer_num,
                     layer_idx=lid + start_layer_idx,  # This parameter is used for caching during generation
+                    use_dynamic_ntk_rope=use_dynamic_ntk_rope,
                     residual_in_fp32=residual_in_fp32,
                     device=device,
                     norm_type=norm_type,
@@ -443,8 +451,10 @@ def build_model_with_cfg(
     embed_grad_scale=1,
     parallel_output=True,
     num_attention_heads=32,
+    max_position_embeddings=2048,
     mlp_ratio=4.0,
     residual_in_fp32=False,
+    use_dynamic_ntk_rope=False,
     norm_type="rmsnorm",
     drop_rate=0,
     attn_drop_rate=0,
@@ -499,6 +509,8 @@ def build_model_with_cfg(
         parallel_output=parallel_output,
         mlp_ratio=mlp_ratio,
         residual_in_fp32=residual_in_fp32,
+        max_position_embeddings=max_position_embeddings,
+        use_dynamic_ntk_rope=use_dynamic_ntk_rope,
         norm_type=norm_type,
         drop_rate=drop_rate,
         attn_drop_rate=attn_drop_rate,
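The two new arguments are threaded from build_model_with_cfg through PackedFlashInternLm1D and PackedFlashBaseLayer1D down to each attention block, so they can be switched on from the model section of a training config. A hypothetical excerpt is shown below; apart from max_position_embeddings and use_dynamic_ntk_rope, the field names and values are placeholders in the style of the existing InternLM model dict, not values taken from this patch.

model = dict(
    num_layers=32,
    hidden_size=4096,
    num_attention_heads=32,
    mlp_ratio=8 / 3,
    use_flash_attn=True,
    # Options introduced by this change:
    max_position_embeddings=2048,  # pretraining context window used as the NTK pivot
    use_dynamic_ntk_rope=True,     # selects DynamicNTKScalingRotaryEmbedding inside MHA
)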
diff --git a/internlm/model/multi_head_attention.py b/internlm/model/multi_head_attention.py
index d634605..e4008e1 100644
--- a/internlm/model/multi_head_attention.py
+++ b/internlm/model/multi_head_attention.py
@@ -1,6 +1,7 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-
 
+import warnings
 from typing import Optional
 
 import torch
@@ -16,7 +17,7 @@ from torch import nn
 
 from internlm.core.context import IS_TENSOR_PARALLEL, ParallelMode
 from internlm.core.context import global_context as gpc
-from internlm.model.embedding import RotaryEmbedding
+from internlm.model.embedding import DynamicNTKScalingRotaryEmbedding, RotaryEmbedding
 from internlm.model.linear import ColumnParallelLinearTorch, RowParallelLinearTorch
 
 
@@ -52,10 +53,12 @@ class MHA(nn.Module):
         embed_dim: int,
         num_heads: int,
         process_group: Optional[torch.distributed.ProcessGroup],
+        max_position_embeddings: int = 2048,
         dropout: float = 0.0,
         softmax_scale: float = None,
         causal: bool = False,
         layer_idx: int = None,
+        use_dynamic_ntk_rope: bool = False,
         rotary_emb_dim: int = 0,
         rotary_emb_scale_base: int = 0,
         use_flash_attn: bool = True,
@@ -67,6 +70,8 @@ class MHA(nn.Module):
         self.embed_dim = embed_dim
         self.causal = causal
         self.layer_idx = layer_idx
+        self.max_position_embeddings = max_position_embeddings
+        self.use_dynamic_ntk_rope = use_dynamic_ntk_rope
         self.rotary_emb_dim = rotary_emb_dim
         self.use_flash_attn = use_flash_attn
         self.num_heads = num_heads
@@ -74,7 +79,16 @@ class MHA(nn.Module):
         self.head_dim = self.embed_dim // num_heads
 
         if self.rotary_emb_dim > 0:
-            self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, scale_base=rotary_emb_scale_base, device=device)
+            if self.use_dynamic_ntk_rope:
+                self.rotary_emb = DynamicNTKScalingRotaryEmbedding(
+                    self.rotary_emb_dim,
+                    scale_base=rotary_emb_scale_base,
+                    device=device,
+                    max_position_embeddings=max_position_embeddings,
+                    scaling_factor=1.0,  # Dynamic scaling is not supported yet.
+                )
+            else:
+                self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, scale_base=rotary_emb_scale_base, device=device)
 
         # notice here should change bias=True
         self.Wqkv = ColumnParallelLinearTorch(
@@ -127,11 +141,10 @@ class MHA(nn.Module):
         else:
             qkv = rearrange(qkv, "(b s) (three h d) -> b s three h d", s=seqlen, three=3, d=self.head_dim)
 
-        if self.rotary_emb_dim > 0:
-            kwargs["inference_params"] = inference_params
-            qkv = self.rotary_emb(qkv, **kwargs)
-
         if inference_params is None:
+            if self.rotary_emb_dim > 0:
+                kwargs["inference_params"] = inference_params
+                qkv = self.rotary_emb(qkv, **kwargs)
             if gpc.config.model.dtype is torch.float32 and gpc.config.model.use_flash_attn:
                 with torch.cuda.amp.autocast(dtype=torch.bfloat16):
                     if qkv.dtype not in [torch.float16, torch.bfloat16]:
@@ -140,9 +153,39 @@ class MHA(nn.Module):
             else:
                 context = self.inner_attn(qkv)
         else:
-            q = qkv[:, :, 0]
-            assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
-            kv = _update_kv_cache(qkv[:, :, 1:], inference_params, self.layer_idx)
+            if self.use_dynamic_ntk_rope:
+                q = qkv[:, :, 0]
+                assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
+                kv = _update_kv_cache(qkv[:, :, 1:], inference_params, self.layer_idx)
+                if inference_params.sequence_len_offset != 0:
+                    # q shape: [bsz, 1, nheads, head_dim]
+                    # kv shape: [bsz, seqlen, 2, nheads, head_dim]
+                    bsz, seq_len, _, nheads, head_dim = kv.shape
+                    q = torch.cat([q.new_zeros(size=(bsz, seq_len - 1, nheads, head_dim)), q], dim=1).unsqueeze(2)
+                    qkv = torch.cat([q, kv], dim=2)
+                    if self.rotary_emb_dim > 0:
+                        qkv = self.rotary_emb(qkv)
+                    q = qkv[:, [-1], 0]
+                    kv = qkv[:, :, 1:]
+                else:
+                    if inference_params.sequence_len_offset > self.max_position_embeddings:
+                        warnings.warn(
+                            "Notice that your prompt length exceeds the model's max_position_embeddings: "
+                            f"{self.max_position_embeddings}, which will cause deviations in the dynamic NTK calculation."
+                        )
+                    if self.rotary_emb_dim > 0:
+                        kwargs["inference_params"] = inference_params
+                        qkv = self.rotary_emb(qkv, **kwargs)
+                    q = qkv[:, :, 0]
+                    kv = qkv[:, :, 1:]
+            else:
+                if self.rotary_emb_dim > 0:
+                    kwargs["inference_params"] = inference_params
+                    qkv = self.rotary_emb(qkv, **kwargs)
+                q = qkv[:, :, 0]
+                assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
+                kv = _update_kv_cache(qkv[:, :, 1:], inference_params, self.layer_idx)
+
             # If we're processing the prompt, causal=None (use self.causal).
             # If we're decoding, then causal=False.
             causal = None if inference_params.sequence_len_offset == 0 else False
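To make the decode-time branch easier to follow, here is a minimal, self-contained shape sketch (toy sizes; the rotary application itself is elided) of what the use_dynamic_ntk_rope path does when sequence_len_offset != 0: the single-step query is zero-padded to the cached length so rotary embeddings can be recomputed over every position with the possibly rescaled frequencies, and only the last position is kept as the real query.

import torch

bsz, seq_len, nheads, head_dim = 2, 10, 4, 8         # seq_len = prompt + tokens generated so far
q = torch.randn(bsz, 1, nheads, head_dim)            # query for the current decoding step
kv = torch.randn(bsz, seq_len, 2, nheads, head_dim)  # cached keys/values, including the current step

# Pad q to the full cache length, stack with kv; the real code then re-applies rotary embedding to qkv.
q_padded = torch.cat([q.new_zeros(size=(bsz, seq_len - 1, nheads, head_dim)), q], dim=1).unsqueeze(2)
qkv = torch.cat([q_padded, kv], dim=2)               # (bsz, seq_len, 3, nheads, head_dim)
q_rot = qkv[:, [-1], 0]                              # only the last position is a real query
kv_rot = qkv[:, :, 1:]
print(q_rot.shape, kv_rot.shape)                     # torch.Size([2, 1, 4, 8]) torch.Size([2, 10, 2, 4, 8])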