From 0d3811c02995513ca49c0c5d4b45e59eed406ff3 Mon Sep 17 00:00:00 2001
From: Shuo Zhang
Date: Thu, 23 Nov 2023 16:30:14 +0800
Subject: [PATCH] feat(model): add rope_base interface (#512)

---
 internlm/model/modeling_internlm.py    | 9 +++++++++
 internlm/model/multi_head_attention.py | 7 ++++++-
 2 files changed, 15 insertions(+), 1 deletion(-)

diff --git a/internlm/model/modeling_internlm.py b/internlm/model/modeling_internlm.py
index 38fb9ca..204f71f 100644
--- a/internlm/model/modeling_internlm.py
+++ b/internlm/model/modeling_internlm.py
@@ -53,6 +53,7 @@ class PackedFlashBaseLayer1D(nn.Module):
         device (Optional[Union[str, torch.device]]): The device will be used.
         norm_type (str): Use RMS norm or layernorm."rmsnorm" by default.
         use_flash_attn (bool): Whether use flash-attn. True by default.
+        rope_base (int): The value of `base` for rotary position embeddings. 10000 by default.
     """
 
     def __init__(
@@ -75,6 +76,7 @@ class PackedFlashBaseLayer1D(nn.Module):
         use_scaled_init: bool = True,
         use_swiglu: bool = True,
         use_flash_attn: bool = True,
+        rope_base: int = 10000,
     ):
         super().__init__()
         self.checkpoint = checkpoint
@@ -98,6 +100,7 @@ class PackedFlashBaseLayer1D(nn.Module):
             rotary_emb_dim=head_dim,
             rotary_emb_scale_base=0,
             use_flash_attn=use_flash_attn,
+            rope_base=rope_base,
             device=device,
             dtype=dtype,
         )
@@ -264,6 +267,7 @@ class PackedFlashInternLm1D(nn.Module):
         residual_in_fp32 (bool): Whether to use residual in fp32. False by default.
         norm_type (str): Normalization type. Use RMSNorm or LayerNorm. "rmsnorm" by default.
         use_flash_attn (bool): Whether to use flash-attn. True by default.
+        rope_base (int): The value of `base` for rotary position embeddings. 10000 by default.
 
     """
 
@@ -295,6 +299,7 @@ class PackedFlashInternLm1D(nn.Module):
         use_scaled_init: bool = True,
         use_swiglu: bool = True,
         use_flash_attn: bool = True,
+        rope_base: int = 10000,
     ):
         super().__init__()
 
@@ -344,6 +349,7 @@ class PackedFlashInternLm1D(nn.Module):
                     use_scaled_init=use_scaled_init,
                     use_swiglu=use_swiglu,
                     use_flash_attn=use_flash_attn,
+                    rope_base=rope_base,
                 )
                 for lid in range(num_layers)
             ]
@@ -490,6 +496,7 @@ def build_model_with_cfg(
     use_scaled_init: bool = True,
    use_swiglu: bool = True,
    use_flash_attn: bool = True,
+    rope_base: int = 10000,
 ):
     """
     Build model with config.
@@ -520,6 +527,7 @@
         use_scaled_init (bool): Whether to use scaled init. True by default.
         use_swiglu (bool): Whether to use swiglu. True by default.
         use_flash_attn (bool): Whether to use flash-attn. True by default.
+        rope_base (int): The value of `base` for rotary position embeddings. 10000 by default.
 
     """
 
@@ -545,6 +553,7 @@
         use_scaled_init=use_scaled_init,
         use_swiglu=use_swiglu,
         use_flash_attn=use_flash_attn,
+        rope_base=rope_base,
     )
 
     return _build_generic_model_1d(num_layers=num_layers, num_chunks=num_chunks, **cfg)
diff --git a/internlm/model/multi_head_attention.py b/internlm/model/multi_head_attention.py
index 6017dbc..e28db6a 100644
--- a/internlm/model/multi_head_attention.py
+++ b/internlm/model/multi_head_attention.py
@@ -63,6 +63,7 @@ class MHA(nn.Module):
         device (Optional[Union[str, torch.device]]): The device will be used.
         dtype (Optional[torch.dtype]): The type of data.
         use_flash_attn (bool): Whether to use flash-attn. True by default.
+        rope_base (int): The value of `base` for rotary position embeddings. 10000 by default.
 
     """
 
@@ -80,6 +81,7 @@ class MHA(nn.Module):
         rotary_emb_dim: int = 0,
         rotary_emb_scale_base: int = 0,
         use_flash_attn: bool = True,
+        rope_base: int = 10000,
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
     ) -> None:
@@ -100,13 +102,16 @@ class MHA(nn.Module):
             if self.use_dynamic_ntk_rope:
                 self.rotary_emb = DynamicNTKScalingRotaryEmbedding(
                     self.rotary_emb_dim,
+                    base=rope_base,
                     scale_base=rotary_emb_scale_base,
                     device=device,
                     max_position_embeddings=max_position_embeddings,
                     scaling_factor=1.0,  # Currently do not support dynamic scaling.
                 )
             else:
-                self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, scale_base=rotary_emb_scale_base, device=device)
+                self.rotary_emb = RotaryEmbedding(
+                    self.rotary_emb_dim, base=rope_base, scale_base=rotary_emb_scale_base, device=device
+                )
 
         # notice here should change bias=True
         self.Wqkv = ColumnParallelLinearTorch(