mirror of https://github.com/InternLM/InternLM
update model
parent
9d400c262a
commit
695d76eb31
|
@ -140,8 +140,8 @@ class InternLMRotaryEmbedding(torch.nn.Module):
|
||||||
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
||||||
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
||||||
emb = torch.cat((freqs, freqs), dim=-1)
|
emb = torch.cat((freqs, freqs), dim=-1)
|
||||||
self.register_buffer("cos_cached", emb.cos(), persistent=False)
|
self.register_buffer("cos_cached", emb.cos().to(torch.float32), persistent=False)
|
||||||
self.register_buffer("sin_cached", emb.sin(), persistent=False)
|
self.register_buffer("sin_cached", emb.sin().to(torch.float32), persistent=False)
|
||||||
|
|
||||||
def forward(self, x, seq_len=None):
|
def forward(self, x, seq_len=None):
|
||||||
# x: [bs, num_attention_heads, seq_len, head_size]
|
# x: [bs, num_attention_heads, seq_len, head_size]
|
||||||
|
@ -228,12 +228,22 @@ def rotate_half(x):
|
||||||
|
|
||||||
|
|
||||||
def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
|
def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
|
||||||
# The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
|
if position_ids.size(1) == 1:
|
||||||
cos = cos[position_ids].unsqueeze(1).expand(q.shape)
|
q_cos = cos[position_ids].unsqueeze(1).expand(q.shape)
|
||||||
sin = sin[position_ids].unsqueeze(1).expand(q.shape)
|
q_sin = sin[position_ids].unsqueeze(1).expand(q.shape)
|
||||||
q_embed = (q * cos) + (rotate_half(q) * sin)
|
q_embed = (q * q_cos) + (rotate_half(q) * q_sin)
|
||||||
k_embed = (k * cos) + (rotate_half(k) * sin)
|
|
||||||
|
|
||||||
|
position_ids = position_ids.flatten() + 1
|
||||||
|
max_length = max(position_ids)
|
||||||
|
position_ids = torch.stack([torch.cat([torch.ones(max_length - w, dtype=torch.long), torch.arange(w)]) for w in position_ids])
|
||||||
|
k_cos = cos[position_ids].unsqueeze(1).expand(k.shape)
|
||||||
|
k_sin = sin[position_ids].unsqueeze(1).expand(k.shape)
|
||||||
|
k_embed = (k * k_cos) + (rotate_half(k) * k_sin)
|
||||||
|
else:
|
||||||
|
cos = cos[position_ids].unsqueeze(1).expand(q.shape)
|
||||||
|
sin = sin[position_ids].unsqueeze(1).expand(q.shape)
|
||||||
|
q_embed = (q * cos) + (rotate_half(q) * sin)
|
||||||
|
k_embed = (k * cos) + (rotate_half(k) * sin)
|
||||||
return q_embed, k_embed
|
return q_embed, k_embed
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue