mirror of https://github.com/InternLM/InternLM

update rope sin&cos

commit: 9d400c262a
parent: 7cbdb6e1f5
@@ -140,8 +140,8 @@ class InternLMRotaryEmbedding(torch.nn.Module):
         freqs = torch.einsum("i,j->ij", t, self.inv_freq)
         # Different from paper, but it uses a different permutation in order to obtain the same calculation
         emb = torch.cat((freqs, freqs), dim=-1)
-        self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
-        self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
+        self.register_buffer("cos_cached", emb.cos(), persistent=False)
+        self.register_buffer("sin_cached", emb.sin(), persistent=False)

     def forward(self, x, seq_len=None):
         # x: [bs, num_attention_heads, seq_len, head_size]
@@ -152,11 +152,11 @@ class InternLMRotaryEmbedding(torch.nn.Module):
             freqs = torch.einsum("i,j->ij", t, self.inv_freq)
             # Different from paper, but it uses a different permutation in order to obtain the same calculation
             emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
-            self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
-            self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
+            self.register_buffer("cos_cached", emb.cos(), persistent=False)
+            self.register_buffer("sin_cached", emb.sin(), persistent=False)
         return (
-            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
-            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
+            self.cos_cached[:seq_len, ...].to(dtype=x.dtype),
+            self.sin_cached[:seq_len, ...].to(dtype=x.dtype),
         )

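Note: taken together, the two hunks above shrink the cached tables of InternLMRotaryEmbedding from a 4-D layout of shape [1, 1, seq_len, dim] (sliced as [:, :, :seq_len, ...]) to a plain 2-D [seq_len, dim] table (sliced as [:seq_len, ...]). A minimal standalone sketch in plain torch, with illustrative sizes rather than anything taken from the model config, showing that both layouts hold the same values:

import torch

# Illustrative sizes only (not taken from the model config).
dim, max_len, seq_len = 8, 16, 5
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
t = torch.arange(max_len, dtype=inv_freq.dtype)
freqs = torch.einsum("i,j->ij", t, inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)            # [max_len, dim]

cos_old = emb.cos()[None, None, :, :]              # old cache layout: [1, 1, max_len, dim]
cos_new = emb.cos()                                # new cache layout: [max_len, dim]

# Old slicing vs. new slicing return the same numbers, just with/without
# the two leading singleton dimensions.
assert torch.equal(cos_old[:, :, :seq_len, ...].squeeze(0).squeeze(0),
                   cos_new[:seq_len, ...])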
@@ -186,8 +186,8 @@ class InternLMDynamicNTKScalingRotaryEmbedding(torch.nn.Module):
         freqs = torch.einsum("i,j->ij", t, self.inv_freq)
         # Different from paper, but it uses a different permutation in order to obtain the same calculation
         emb = torch.cat((freqs, freqs), dim=-1)
-        self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
-        self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
+        self.register_buffer("cos_cached", emb.cos(), persistent=False)
+        self.register_buffer("sin_cached", emb.sin(), persistent=False)

     def _update_cached(self, x, seq_len=None):
         self.max_seq_len_cached = max(seq_len, self.max_position_embeddings)
@@ -201,8 +201,8 @@ class InternLMDynamicNTKScalingRotaryEmbedding(torch.nn.Module):
         t = torch.arange(self.max_seq_len_cached, device=inv_freq.device, dtype=inv_freq.dtype)
         freqs = torch.einsum("i,j->ij", t, inv_freq)
         emb = torch.cat((freqs, freqs), dim=-1)
-        self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
-        self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
+        self.register_buffer("cos_cached", emb.cos(), persistent=False)
+        self.register_buffer("sin_cached", emb.sin(), persistent=False)

     def forward(self, x, seq_len=None):
         # x: [bs, num_attention_heads, seq_len, head_size]
@@ -215,8 +215,8 @@ class InternLMDynamicNTKScalingRotaryEmbedding(torch.nn.Module):
             self._update_cached(x, seq_len)

         return (
-            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
-            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
+            self.cos_cached[:seq_len, ...].to(dtype=x.dtype),
+            self.sin_cached[:seq_len, ...].to(dtype=x.dtype),
         )

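Note: the dynamic-NTK variant gets the same treatment; _update_cached now registers 2-D cos/sin tables and forward slices their first dimension. The sketch below mirrors only the cache bookkeeping from the hunks above (the NTK rescaling of inv_freq that the real _update_cached performs when seq_len exceeds max_position_embeddings is omitted); names and sizes are illustrative:

import torch

dim, max_position_embeddings = 8, 16
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))

def build_cache(length, inv_freq):
    # Mirrors the buffer construction in _update_cached, in the new 2-D layout.
    t = torch.arange(length, dtype=inv_freq.dtype)
    freqs = torch.einsum("i,j->ij", t, inv_freq)
    emb = torch.cat((freqs, freqs), dim=-1)
    return emb.cos(), emb.sin()                    # each [length, dim]

seq_len = 24                                       # longer than max_position_embeddings
max_seq_len_cached = max(seq_len, max_position_embeddings)
cos_cached, sin_cached = build_cache(max_seq_len_cached, inv_freq)

# What forward now returns for this sequence length:
print(cos_cached[:seq_len, ...].shape)             # torch.Size([24, 8])
print(sin_cached[:seq_len, ...].shape)             # torch.Size([24, 8])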
@@ -229,19 +229,10 @@ def rotate_half(x):

 def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
     # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
-    cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
-    sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
-    cos = cos.unsqueeze(0).unsqueeze(0).expand(len(position_ids), -1, -1, -1)
-    sin = sin.unsqueeze(0).unsqueeze(0).expand(len(position_ids), -1, -1, -1)
-    if q.size(2) == 1:
-        q_embed = (q * cos[:, :, -1, :]) + (rotate_half(q) * sin[:, :, -1, :])
-    else:
-        q_embed = (q * cos) + (rotate_half(q) * sin)
-
-    if k.size(2) == 1:
-        k_embed = (k * cos[:, :, -1, :]) + (rotate_half(k) * sin[:, :, -1, :])
-    else:
-        k_embed = (k * cos) + (rotate_half(k) * sin)
+    cos = cos[position_ids].unsqueeze(1).expand(q.shape)
+    sin = sin[position_ids].unsqueeze(1).expand(q.shape)
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)

     return q_embed, k_embed

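Note: with 2-D caches, apply_rotary_pos_emb can gather the rows it needs via cos[position_ids] and broadcast them across the head dimension, which also covers the single-token decode case (q.size(2) == 1) that the old code special-cased, since position_ids already selects the correct row. A standalone restatement of the new-side function follows; the body of rotate_half is not shown in this diff, so the usual half-rotation is assumed, and all sizes below are illustrative:

import torch

def rotate_half(x):
    # Assumed standard implementation: swap the two halves of the last dim, negating the second.
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    # cos/sin arrive as the new 2-D caches, already sliced to [seq_len, dim].
    cos = cos[position_ids].unsqueeze(1).expand(q.shape)   # [bs, heads, q_len, dim]
    sin = sin[position_ids].unsqueeze(1).expand(q.shape)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

# Single-token decode step: q_len == 1, current position given by position_ids.
bs, heads, q_len, dim = 2, 4, 1, 8
q = torch.randn(bs, heads, q_len, dim)
k = torch.randn(bs, heads, q_len, dim)
cos_cached, sin_cached = torch.randn(16, dim), torch.randn(16, dim)   # stand-in caches
position_ids = torch.tensor([[7], [7]])
q_rot, k_rot = apply_rotary_pos_emb(q, k, cos_cached[:8], sin_cached[:8], position_ids)
print(q_rot.shape)                                          # torch.Size([2, 4, 1, 8])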
@@ -302,7 +293,7 @@ class InternLMAttention(nn.Module):
                 self.head_dim,
                 max_position_embeddings=self.max_position_embeddings,
                 base=self.config.rope_theta,
-                scaling_factor=scaling_factor
+                scaling_factor=scaling_factor,
             )
         else:
             raise ValueError("Currently we only support rotary embedding's type being 'dynamic'.")