mirror of https://github.com/InternLM/InternLM

feat(model): add rope_base interface (#512)

commit 0d3811c029
parent 7776693373
@@ -53,6 +53,7 @@ class PackedFlashBaseLayer1D(nn.Module):
         device (Optional[Union[str, torch.device]]): The device to be used.
         norm_type (str): Use RMSNorm or LayerNorm. "rmsnorm" by default.
         use_flash_attn (bool): Whether to use flash-attn. True by default.
+        rope_base (int): The value of `base` for rotary position embeddings. 10000 by default.
     """

     def __init__(
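Note on the new knob: `base` is the constant in the standard RoPE inverse-frequency rule, where dimension pair i of a head of size d rotates with frequency base^(-2i/d); a larger base slows the low-frequency pairs and is a common lever for longer-context training. A minimal sketch of that rule in plain PyTorch (illustrative only, not the repository's RotaryEmbedding implementation):

import torch

def rope_inv_freq(head_dim: int, base: float = 10000.0) -> torch.Tensor:
    # Standard RoPE inverse frequencies, one per dimension pair:
    # inv_freq[i] = base ** (-2i / head_dim)
    return 1.0 / (base ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))

print(rope_inv_freq(64, base=10_000)[-1])     # ~1.3e-4: slowest rotation at the default base
print(rope_inv_freq(64, base=1_000_000)[-1])  # ~1.5e-6: a larger base rotates even more slowly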
@@ -75,6 +76,7 @@ class PackedFlashBaseLayer1D(nn.Module):
         use_scaled_init: bool = True,
         use_swiglu: bool = True,
         use_flash_attn: bool = True,
+        rope_base: int = 10000,
     ):
         super().__init__()
         self.checkpoint = checkpoint
@@ -98,6 +100,7 @@ class PackedFlashBaseLayer1D(nn.Module):
             rotary_emb_dim=head_dim,
             rotary_emb_scale_base=0,
             use_flash_attn=use_flash_attn,
+            rope_base=rope_base,
             device=device,
             dtype=dtype,
         )
@@ -264,6 +267,7 @@ class PackedFlashInternLm1D(nn.Module):
         residual_in_fp32 (bool): Whether to use residual in fp32. False by default.
         norm_type (str): Normalization type. Use RMSNorm or LayerNorm. "rmsnorm" by default.
         use_flash_attn (bool): Whether to use flash-attn. True by default.
+        rope_base (int): The value of `base` for rotary position embeddings. 10000 by default.

     """

@@ -295,6 +299,7 @@ class PackedFlashInternLm1D(nn.Module):
         use_scaled_init: bool = True,
         use_swiglu: bool = True,
         use_flash_attn: bool = True,
+        rope_base: int = 10000,
     ):
         super().__init__()

@@ -344,6 +349,7 @@ class PackedFlashInternLm1D(nn.Module):
                     use_scaled_init=use_scaled_init,
                     use_swiglu=use_swiglu,
                     use_flash_attn=use_flash_attn,
+                    rope_base=rope_base,
                 )
                 for lid in range(num_layers)
             ]
@@ -490,6 +496,7 @@ def build_model_with_cfg(
     use_scaled_init: bool = True,
     use_swiglu: bool = True,
     use_flash_attn: bool = True,
+    rope_base: int = 10000,
 ):
     """
     Build model with config.
@@ -520,6 +527,7 @@ def build_model_with_cfg(
         use_scaled_init (bool): Whether to use scaled init. True by default.
         use_swiglu (bool): Whether to use swiglu. True by default.
         use_flash_attn (bool): Whether to use flash-attn. True by default.
+        rope_base (int): The value of `base` for rotary position embeddings. 10000 by default.

     """

@@ -545,6 +553,7 @@ def build_model_with_cfg(
         use_scaled_init=use_scaled_init,
         use_swiglu=use_swiglu,
         use_flash_attn=use_flash_attn,
+        rope_base=rope_base,
     )

     return _build_generic_model_1d(num_layers=num_layers, num_chunks=num_chunks, **cfg)
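With `build_model_with_cfg` now forwarding `rope_base`, the value can be set wherever the model kwargs come from; in InternLM training setups that is typically the `model = dict(...)` block of a config file. A hedged sketch of such a fragment (the surrounding field names mirror typical InternLM configs and are assumptions; only `rope_base` is what this commit adds):

# Illustrative config fragment; field names other than rope_base are assumptions.
model = dict(
    num_layers=32,
    hidden_size=4096,
    num_attention_heads=32,
    use_flash_attn=True,
    rope_base=10000,  # raise (e.g. to 1_000_000) when targeting longer contexts
)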
				
			
@@ -63,6 +63,7 @@ class MHA(nn.Module):
         device (Optional[Union[str, torch.device]]): The device to be used.
         dtype (Optional[torch.dtype]): The data type.
         use_flash_attn (bool): Whether to use flash-attn. True by default.
+        rope_base (int): The value of `base` for rotary position embeddings. 10000 by default.

     """

@@ -80,6 +81,7 @@ class MHA(nn.Module):
         rotary_emb_dim: int = 0,
         rotary_emb_scale_base: int = 0,
         use_flash_attn: bool = True,
+        rope_base: int = 10000,
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
     ) -> None:
@@ -100,13 +102,16 @@ class MHA(nn.Module):
             if self.use_dynamic_ntk_rope:
                 self.rotary_emb = DynamicNTKScalingRotaryEmbedding(
                     self.rotary_emb_dim,
+                    base=rope_base,
                     scale_base=rotary_emb_scale_base,
                     device=device,
                     max_position_embeddings=max_position_embeddings,
                     scaling_factor=1.0,  # Currently does not support dynamic scaling.
                 )
             else:
-                self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, scale_base=rotary_emb_scale_base, device=device)
+                self.rotary_emb = RotaryEmbedding(
+                    self.rotary_emb_dim, base=rope_base, scale_base=rotary_emb_scale_base, device=device
+                )

         # notice here should change bias=True
         self.Wqkv = ColumnParallelLinearTorch(
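Since both rotary branches now receive `base`, changing `rope_base` rescales every rotation wavelength by the same rule: pair i wraps around after roughly 2*pi * base**(2i/d) positions. A quick back-of-the-envelope check, independent of the repository's classes:

import math

def longest_wavelength(head_dim: int, base: float) -> float:
    # The slowest-rotating pair (the last one) wraps after about this many positions.
    return 2 * math.pi * base ** ((head_dim - 2) / head_dim)

print(longest_wavelength(64, 10_000))     # ~4.7e4 positions at the default base
print(longest_wavelength(64, 1_000_000))  # ~4.1e6 positions with a much larger base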
			