feat(model): add rope_base interface (#512)

pull/514/head
Shuo Zhang 2023-11-23 16:30:14 +08:00 committed by GitHub
parent 7776693373
commit 0d3811c029
2 changed files with 15 additions and 1 deletion

View File

@@ -53,6 +53,7 @@ class PackedFlashBaseLayer1D(nn.Module):
device (Optional[Union[str, torch.device]]): The device will be used.
norm_type (str): Use RMS norm or layernorm. "rmsnorm" by default.
use_flash_attn (bool): Whether use flash-attn. True by default.
+rope_base (int): The value of `base` for rotary position embeddings. 10000 by default.
"""
def __init__(
@@ -75,6 +76,7 @@ class PackedFlashBaseLayer1D(nn.Module):
use_scaled_init: bool = True,
use_swiglu: bool = True,
use_flash_attn: bool = True,
+rope_base: int = 10000,
):
super().__init__()
self.checkpoint = checkpoint
@@ -98,6 +100,7 @@ class PackedFlashBaseLayer1D(nn.Module):
rotary_emb_dim=head_dim,
rotary_emb_scale_base=0,
use_flash_attn=use_flash_attn,
+rope_base=rope_base,
device=device,
dtype=dtype,
)
@@ -264,6 +267,7 @@ class PackedFlashInternLm1D(nn.Module):
residual_in_fp32 (bool): Whether to use residual in fp32. False by default.
norm_type (str): Normalization type. Use RMSNorm or LayerNorm. "rmsnorm" by default.
use_flash_attn (bool): Whether to use flash-attn. True by default.
+rope_base (int): The value of `base` for rotary position embeddings. 10000 by default.
"""
@@ -295,6 +299,7 @@ class PackedFlashInternLm1D(nn.Module):
use_scaled_init: bool = True,
use_swiglu: bool = True,
use_flash_attn: bool = True,
+rope_base: int = 10000,
):
super().__init__()
@@ -344,6 +349,7 @@ class PackedFlashInternLm1D(nn.Module):
use_scaled_init=use_scaled_init,
use_swiglu=use_swiglu,
use_flash_attn=use_flash_attn,
+rope_base=rope_base,
)
for lid in range(num_layers)
]
@@ -490,6 +496,7 @@ def build_model_with_cfg(
use_scaled_init: bool = True,
use_swiglu: bool = True,
use_flash_attn: bool = True,
+rope_base: int = 10000,
):
"""
Build model with config.
@@ -520,6 +527,7 @@ def build_model_with_cfg(
use_scaled_init (bool): Whether to use scaled init. True by default.
use_swiglu (bool): Whether to use swiglu. True by default.
use_flash_attn (bool): Whether to use flash-attn. True by default.
+rope_base (int): The value of `base` for rotary position embeddings. 10000 by default.
"""
@@ -545,6 +553,7 @@ def build_model_with_cfg(
use_scaled_init=use_scaled_init,
use_swiglu=use_swiglu,
use_flash_attn=use_flash_attn,
+rope_base=rope_base,
)
return _build_generic_model_1d(num_layers=num_layers, num_chunks=num_chunks, **cfg)
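
For context, a minimal usage sketch (not part of this diff) of how the new argument reaches the model: rope_base is just one more keyword forwarded from the model config through build_model_with_cfg into every PackedFlashBaseLayer1D and on to MHA. The keyword names below appear in the hunks above; the values are illustrative placeholders, and the direct call is hypothetical (in practice the config is forwarded by the model-building entry point).

model_cfg = dict(
    num_layers=32,            # illustrative value
    num_chunks=1,             # illustrative value
    use_flash_attn=True,
    rope_base=10000,          # new knob from this PR; 10000 keeps the previous behavior
)
# model = build_model_with_cfg(**model_cfg)  # hypothetical call site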

View File

@@ -63,6 +63,7 @@ class MHA(nn.Module):
device (Optional[Union[str, torch.device]]): The device will be used.
dtype (Optional[torch.dtype]): The type of data.
use_flash_attn (bool): Whether to use flash-attn. True by default.
+rope_base (int): The value of `base` for rotary position embeddings. 10000 by default.
"""
@@ -80,6 +81,7 @@ class MHA(nn.Module):
rotary_emb_dim: int = 0,
rotary_emb_scale_base: int = 0,
use_flash_attn: bool = True,
+rope_base: int = 10000,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
) -> None:
@@ -100,13 +102,16 @@ class MHA(nn.Module):
if self.use_dynamic_ntk_rope:
self.rotary_emb = DynamicNTKScalingRotaryEmbedding(
self.rotary_emb_dim,
+base=rope_base,
scale_base=rotary_emb_scale_base,
device=device,
max_position_embeddings=max_position_embeddings,
scaling_factor=1.0, # Currently do not support dynamic scaling.
)
else:
-self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, scale_base=rotary_emb_scale_base, device=device)
+self.rotary_emb = RotaryEmbedding(
+self.rotary_emb_dim, base=rope_base, scale_base=rotary_emb_scale_base, device=device
+)
# notice here should change bias=True
self.Wqkv = ColumnParallelLinearTorch(
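
For reference, a sketch of where the base value enters the computation, assuming the standard RoPE formulation; the actual RotaryEmbedding and DynamicNTKScalingRotaryEmbedding implementations may differ in detail.

import torch

def rope_inv_freq(dim: int, base: float = 10000.0) -> torch.Tensor:
    # Standard RoPE inverse frequencies: theta_i = base ** (-2 * i / dim) for i = 0 .. dim/2 - 1.
    # A larger base lowers the rotation frequencies and stretches the positional
    # wavelengths, which is the usual knob for long-context variants.
    return 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))

print(rope_inv_freq(128, base=10000.0)[-1])    # slowest frequency, ~1.2e-4
print(rope_inv_freq(128, base=1000000.0)[-1])  # ~1.2e-6: larger base, slower rotation

With the default of 10000 the numerics are unchanged; the PR only threads the value through so it can be set from the model config.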