mirror of https://github.com/InternLM/InternLM
add rope doc
parent f2d9b63545
commit 845cccd756
@@ -105,6 +105,14 @@ class InternLMRMSNorm(nn.Module):
class InternLMRotaryEmbedding(torch.nn.Module):
    """Implement InternLM's rotary embedding.

    Args:
        dim (int): Dimension of each self-attention head.
        max_position_embeddings (int, optional): The model's training sequence length. Defaults to 2048.
        base (int, optional): Base of the rotary angle used in the position encoding. Defaults to 10000.
        device (Any, optional): Running device. Defaults to None.
    """

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
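Below is a minimal sketch of how an inv_freq buffer like the one above is typically turned into cos/sin caches and applied to query/key vectors. The helper names (build_rope_cache, rotate_half, apply_rotary) are illustrative assumptions, not part of the InternLM source in this diff.

import torch

def build_rope_cache(dim, seq_len, base=10000, device=None):
    # Same inverse-frequency construction as in __init__ above.
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
    t = torch.arange(seq_len, device=device).float()
    freqs = torch.outer(t, inv_freq)         # (seq_len, dim/2)
    emb = torch.cat((freqs, freqs), dim=-1)  # (seq_len, dim)
    return emb.cos(), emb.sin()

def rotate_half(x):
    # Swap and negate the two halves of the last dimension.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary(q, k, cos, sin):
    # Rotate each position's query/key by its position-dependent angle.
    return q * cos + rotate_half(q) * sin, k * cos + rotate_half(k) * sin

# Usage: rotate dummy query/key tensors of shape (seq_len, head_dim).
cos, sin = build_rope_cache(dim=64, seq_len=8)
q, k = torch.randn(8, 64), torch.randn(8, 64)
q_rot, k_rot = apply_rotary(q, k, cos, sin)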
@@ -137,7 +145,14 @@ class InternLMRotaryEmbedding(torch.nn.Module):
class InternLMDynamicNTKScalingRotaryEmbedding(torch.nn.Module):
    """Implement InternLM's Dynamic NTK extrapolation method, which extends the supported context length to 16K.

    Args:
        dim (int): Dimension of each self-attention head.
        max_position_embeddings (int, optional): The model's training sequence length. Defaults to 2048.
        base (int, optional): Base of the rotary angle used in the position encoding. Defaults to 10000.
        device (Any, optional): Running device. Defaults to None.
        scaling_factor (float, optional): Extrapolation coefficient of the NTK method. Defaults to 1.0.
    """

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
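For context, here is a sketch of the dynamic NTK rescaling rule such a class is commonly built around (the form popularized in the Hugging Face implementation; the exact InternLM update lives in the body of __init__, which this hunk elides). When the current sequence grows past the training length, the rotary base is enlarged so every frequency shrinks and the rotation angles stay within the trained range.

import torch

def ntk_scaled_inv_freq(dim, seq_len, max_position_embeddings=2048,
                        base=10000, scaling_factor=1.0, device=None):
    # Enlarge the base only when the context exceeds the training length;
    # shorter sequences keep the original rotary frequencies.
    if seq_len > max_position_embeddings:
        base = base * (
            scaling_factor * seq_len / max_position_embeddings - (scaling_factor - 1)
        ) ** (dim / (dim - 2))
    return 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))

# Usage: at 16K positions the effective base grows well past 10000,
# compressing each frequency relative to the 2048-token training setup.
print(ntk_scaled_inv_freq(dim=64, seq_len=16384)[:4])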