mirror of https://github.com/InternLM/InternLM
feat(model): add rope_base interface (#512)
parent
7776693373
commit
0d3811c029
|
@ -53,6 +53,7 @@ class PackedFlashBaseLayer1D(nn.Module):
|
|||
device (Optional[Union[str, torch.device]]): The device will be used.
|
||||
norm_type (str): Use RMS norm or layernorm."rmsnorm" by default.
|
||||
use_flash_attn (bool): Whether use flash-attn. True by default.
|
||||
rope_base (int): The value of `base` for rotary position embeddings. 10000 by default.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
@ -75,6 +76,7 @@ class PackedFlashBaseLayer1D(nn.Module):
|
|||
use_scaled_init: bool = True,
|
||||
use_swiglu: bool = True,
|
||||
use_flash_attn: bool = True,
|
||||
rope_base: int = 10000,
|
||||
):
|
||||
super().__init__()
|
||||
self.checkpoint = checkpoint
|
||||
|
@ -98,6 +100,7 @@ class PackedFlashBaseLayer1D(nn.Module):
|
|||
rotary_emb_dim=head_dim,
|
||||
rotary_emb_scale_base=0,
|
||||
use_flash_attn=use_flash_attn,
|
||||
rope_base=rope_base,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
@ -264,6 +267,7 @@ class PackedFlashInternLm1D(nn.Module):
|
|||
residual_in_fp32 (bool): Whether to use residual in fp32. False by default.
|
||||
norm_type (str): Normalization type. Use RMSNorm or LayerNorm. "rmsnorm" by default.
|
||||
use_flash_attn (bool): Whether to use flash-attn. True by default.
|
||||
rope_base (int): The value of `base` for rotary position embeddings. 10000 by default.
|
||||
|
||||
"""
|
||||
|
||||
|
@ -295,6 +299,7 @@ class PackedFlashInternLm1D(nn.Module):
|
|||
use_scaled_init: bool = True,
|
||||
use_swiglu: bool = True,
|
||||
use_flash_attn: bool = True,
|
||||
rope_base: int = 10000,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
|
@ -344,6 +349,7 @@ class PackedFlashInternLm1D(nn.Module):
|
|||
use_scaled_init=use_scaled_init,
|
||||
use_swiglu=use_swiglu,
|
||||
use_flash_attn=use_flash_attn,
|
||||
rope_base=rope_base,
|
||||
)
|
||||
for lid in range(num_layers)
|
||||
]
|
||||
|
@ -490,6 +496,7 @@ def build_model_with_cfg(
|
|||
use_scaled_init: bool = True,
|
||||
use_swiglu: bool = True,
|
||||
use_flash_attn: bool = True,
|
||||
rope_base: int = 10000,
|
||||
):
|
||||
"""
|
||||
Build model with config.
|
||||
|
@ -520,6 +527,7 @@ def build_model_with_cfg(
|
|||
use_scaled_init (bool): Whether to use scaled init. True by default.
|
||||
use_swiglu (bool): Whether to use swiglu. True by default.
|
||||
use_flash_attn (bool): Whether to use flash-attn. True by default.
|
||||
rope_base (int): The value of `base` for rotary position embeddings. 10000 by default.
|
||||
|
||||
"""
|
||||
|
||||
|
@ -545,6 +553,7 @@ def build_model_with_cfg(
|
|||
use_scaled_init=use_scaled_init,
|
||||
use_swiglu=use_swiglu,
|
||||
use_flash_attn=use_flash_attn,
|
||||
rope_base=rope_base,
|
||||
)
|
||||
|
||||
return _build_generic_model_1d(num_layers=num_layers, num_chunks=num_chunks, **cfg)
|
||||
|
|
|
@ -63,6 +63,7 @@ class MHA(nn.Module):
|
|||
device (Optional[Union[str, torch.device]]): The device will be used.
|
||||
dtype (Optional[torch.dtype]): The type of data.
|
||||
use_flash_attn (bool): Whether to use flash-attn. True by default.
|
||||
rope_base (int): The value of `base` for rotary position embeddings. 10000 by default.
|
||||
|
||||
"""
|
||||
|
||||
|
@ -80,6 +81,7 @@ class MHA(nn.Module):
|
|||
rotary_emb_dim: int = 0,
|
||||
rotary_emb_scale_base: int = 0,
|
||||
use_flash_attn: bool = True,
|
||||
rope_base: int = 10000,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
) -> None:
|
||||
|
@ -100,13 +102,16 @@ class MHA(nn.Module):
|
|||
if self.use_dynamic_ntk_rope:
|
||||
self.rotary_emb = DynamicNTKScalingRotaryEmbedding(
|
||||
self.rotary_emb_dim,
|
||||
base=rope_base,
|
||||
scale_base=rotary_emb_scale_base,
|
||||
device=device,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
scaling_factor=1.0, # Currently do not support dynamic scaling.
|
||||
)
|
||||
else:
|
||||
self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, scale_base=rotary_emb_scale_base, device=device)
|
||||
self.rotary_emb = RotaryEmbedding(
|
||||
self.rotary_emb_dim, base=rope_base, scale_base=rotary_emb_scale_base, device=device
|
||||
)
|
||||
|
||||
# notice here should change bias=True
|
||||
self.Wqkv = ColumnParallelLinearTorch(
|
||||
|
|
Loading…
Reference in New Issue