diff --git a/tools/transformers/internlm_model/modeling_internlm.py b/tools/transformers/internlm_model/modeling_internlm.py
index 7f2bb1f..18c9af6 100644
--- a/tools/transformers/internlm_model/modeling_internlm.py
+++ b/tools/transformers/internlm_model/modeling_internlm.py
@@ -140,8 +140,8 @@ class InternLMRotaryEmbedding(torch.nn.Module):
         freqs = torch.einsum("i,j->ij", t, self.inv_freq)
         # Different from paper, but it uses a different permutation in order to obtain the same calculation
         emb = torch.cat((freqs, freqs), dim=-1)
-        self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
-        self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
+        self.register_buffer("cos_cached", emb.cos(), persistent=False)
+        self.register_buffer("sin_cached", emb.sin(), persistent=False)
 
     def forward(self, x, seq_len=None):
         # x: [bs, num_attention_heads, seq_len, head_size]
@@ -152,11 +152,11 @@ class InternLMRotaryEmbedding(torch.nn.Module):
             freqs = torch.einsum("i,j->ij", t, self.inv_freq)
             # Different from paper, but it uses a different permutation in order to obtain the same calculation
             emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
-            self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
-            self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
+            self.register_buffer("cos_cached", emb.cos(), persistent=False)
+            self.register_buffer("sin_cached", emb.sin(), persistent=False)
         return (
-            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
-            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
+            self.cos_cached[:seq_len, ...].to(dtype=x.dtype),
+            self.sin_cached[:seq_len, ...].to(dtype=x.dtype),
         )
 
 
@@ -186,8 +186,8 @@ class InternLMDynamicNTKScalingRotaryEmbedding(torch.nn.Module):
         freqs = torch.einsum("i,j->ij", t, self.inv_freq)
         # Different from paper, but it uses a different permutation in order to obtain the same calculation
         emb = torch.cat((freqs, freqs), dim=-1)
-        self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
-        self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
+        self.register_buffer("cos_cached", emb.cos(), persistent=False)
+        self.register_buffer("sin_cached", emb.sin(), persistent=False)
 
     def _update_cached(self, x, seq_len=None):
         self.max_seq_len_cached = max(seq_len, self.max_position_embeddings)
@@ -201,8 +201,8 @@ class InternLMDynamicNTKScalingRotaryEmbedding(torch.nn.Module):
         t = torch.arange(self.max_seq_len_cached, device=inv_freq.device, dtype=inv_freq.dtype)
         freqs = torch.einsum("i,j->ij", t, inv_freq)
         emb = torch.cat((freqs, freqs), dim=-1)
-        self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
-        self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
+        self.register_buffer("cos_cached", emb.cos(), persistent=False)
+        self.register_buffer("sin_cached", emb.sin(), persistent=False)
 
     def forward(self, x, seq_len=None):
         # x: [bs, num_attention_heads, seq_len, head_size]
@@ -215,8 +215,8 @@ class InternLMDynamicNTKScalingRotaryEmbedding(torch.nn.Module):
             self._update_cached(x, seq_len)
 
         return (
-            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
-            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
+            self.cos_cached[:seq_len, ...].to(dtype=x.dtype),
+            self.sin_cached[:seq_len, ...].to(dtype=x.dtype),
         )
 
 
@@ -229,19 +229,10 @@ def rotate_half(x):
 
 def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
-    # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
-    cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
-    sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
-    cos = cos.unsqueeze(0).unsqueeze(0).expand(len(position_ids), -1, -1, -1)
-    sin = sin.unsqueeze(0).unsqueeze(0).expand(len(position_ids), -1, -1, -1)
-    if q.size(2) == 1:
-        q_embed = (q * cos[:, :, -1, :]) + (rotate_half(q) * sin[:, :, -1, :])
-    else:
-        q_embed = (q * cos) + (rotate_half(q) * sin)
-
-    if k.size(2) == 1:
-        k_embed = (k * cos[:, :, -1, :]) + (rotate_half(k) * sin[:, :, -1, :])
-    else:
-        k_embed = (k * cos) + (rotate_half(k) * sin)
+    # cos and sin are cached as [seq_len, dim]; gather the rows selected by `position_ids` directly.
+    cos = cos[position_ids].unsqueeze(1).expand(q.shape)
+    sin = sin[position_ids].unsqueeze(1).expand(q.shape)
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
     return q_embed, k_embed
 
 
@@ -302,7 +293,7 @@ class InternLMAttention(nn.Module):
                 self.head_dim,
                 max_position_embeddings=self.max_position_embeddings,
                 base=self.config.rope_theta,
-                scaling_factor=scaling_factor
+                scaling_factor=scaling_factor,
             )
         else:
             raise ValueError("Currently we only support rotary embedding's type being 'dynamic'.")
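
For reference, below is a minimal, self-contained sketch of how the reshaped caches are consumed after this change. Only `rotate_half` and the new `apply_rotary_pos_emb` body mirror the patch; the toy sizes (`bs`, `heads`, `q_len`, `head_dim`, `max_len`), the random `q`/`k` tensors, and the `__main__` driver are illustrative assumptions, not code from the repository.

```python
import torch


def rotate_half(x):
    # Same helper as in modeling_internlm.py: rotate the two halves of the last dimension.
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    # New layout: cos/sin are [seq_len, dim]; gather rows by position_ids, add a
    # singleton head dimension, and broadcast to q/k of shape [bs, heads, q_len, head_dim].
    cos = cos[position_ids].unsqueeze(1).expand(q.shape)
    sin = sin[position_ids].unsqueeze(1).expand(q.shape)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


if __name__ == "__main__":
    bs, heads, q_len, head_dim, max_len = 2, 4, 3, 8, 16

    # Build a [max_len, head_dim] cache the way the patched register_buffer calls do.
    inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
    t = torch.arange(max_len, dtype=inv_freq.dtype)
    freqs = torch.einsum("i,j->ij", t, inv_freq)
    emb = torch.cat((freqs, freqs), dim=-1)
    cos_cached, sin_cached = emb.cos(), emb.sin()  # both [max_len, head_dim]

    q = torch.randn(bs, heads, q_len, head_dim)
    k = torch.randn(bs, heads, q_len, head_dim)
    position_ids = torch.arange(q_len).unsqueeze(0).expand(bs, -1)  # [bs, q_len]

    q_embed, k_embed = apply_rotary_pos_emb(q, k, cos_cached, sin_cached, position_ids)
    print(q_embed.shape, k_embed.shape)  # both torch.Size([2, 4, 3, 8])
```

The net effect of the patch is that `cos_cached`/`sin_cached` are stored as plain `[seq_len, dim]` tables, so a single `cos[position_ids]` gather replaces the earlier squeeze/unsqueeze/expand sequence and the special-casing of single-token decoding (`q.size(2) == 1`).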