update model

pull/536/head
x54-729 2024-01-03 15:21:37 +08:00
parent 9d400c262a
commit 695d76eb31
1 changed files with 17 additions and 7 deletions

View File

@ -140,8 +140,8 @@ class InternLMRotaryEmbedding(torch.nn.Module):
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("cos_cached", emb.cos(), persistent=False)
self.register_buffer("sin_cached", emb.sin(), persistent=False)
self.register_buffer("cos_cached", emb.cos().to(torch.float32), persistent=False)
self.register_buffer("sin_cached", emb.sin().to(torch.float32), persistent=False)
def forward(self, x, seq_len=None):
# x: [bs, num_attention_heads, seq_len, head_size]
@ -228,12 +228,22 @@ def rotate_half(x):
def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
# The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
cos = cos[position_ids].unsqueeze(1).expand(q.shape)
sin = sin[position_ids].unsqueeze(1).expand(q.shape)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
if position_ids.size(1) == 1:
q_cos = cos[position_ids].unsqueeze(1).expand(q.shape)
q_sin = sin[position_ids].unsqueeze(1).expand(q.shape)
q_embed = (q * q_cos) + (rotate_half(q) * q_sin)
position_ids = position_ids.flatten() + 1
max_length = max(position_ids)
position_ids = torch.stack([torch.cat([torch.ones(max_length - w, dtype=torch.long), torch.arange(w)]) for w in position_ids])
k_cos = cos[position_ids].unsqueeze(1).expand(k.shape)
k_sin = sin[position_ids].unsqueeze(1).expand(k.shape)
k_embed = (k * k_cos) + (rotate_half(k) * k_sin)
else:
cos = cos[position_ids].unsqueeze(1).expand(q.shape)
sin = sin[position_ids].unsqueeze(1).expand(q.shape)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed