def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    """Apply rotary position embeddings to the query and key states.

    ``cos`` and ``sin`` are lookup tables indexed by position along their
    first dimension. The selected rows get an extra axis inserted at dim 1
    and are expanded to the shape of the tensor being rotated.

    Args:
        q: Query states.
            # assumes [bs, num_heads, q_len, head_dim] — TODO confirm
        k: Key states, same layout as ``q``; its sequence axis may be longer
            than ``q``'s in the single-step branch.
        cos: Cosine table, indexed as ``cos[position_ids]``.
        sin: Sine table, indexed as ``sin[position_ids]``.
        position_ids: Integer tensor of shape ``[bs, n]``.

    Returns:
        Tuple ``(q_embed, k_embed)`` of rotated tensors with the same shapes
        as ``q`` and ``k``.

    When ``position_ids`` has more than one column, ``q`` and ``k`` are both
    rotated at exactly those positions. When it has a single column, ``q`` is
    rotated at that position and ``k`` is rotated at positions rebuilt per
    row: a row whose id is ``p`` gets ``0..p`` right-aligned to the longest
    row, with the leading slots filled with position ``1``.
    NOTE(review): this presumably matches incremental decoding over a
    left-padded batch where ``k`` already holds the cached keys — confirm
    against the caller.
    """

    def _rotate(states, positions):
        # Gather one table row per position, then broadcast over dim 1.
        cos_sel = cos[positions].unsqueeze(1).expand(states.shape)
        sin_sel = sin[positions].unsqueeze(1).expand(states.shape)
        return (states * cos_sel) + (rotate_half(states) * sin_sel)

    if position_ids.size(1) != 1:
        return _rotate(q, position_ids), _rotate(k, position_ids)

    q_embed = _rotate(q, position_ids)

    # Rebuild the key positions in one vectorized pass on the same device as
    # `position_ids`, instead of looping over rows in Python and creating
    # default-device tensors for each one.
    lengths = position_ids.flatten() + 1
    max_length = int(lengths.max())
    steps = torch.arange(max_length, dtype=torch.long, device=lengths.device)
    # offsets[b, j] is the position of slot j in row b; negative means padding.
    offsets = steps.unsqueeze(0) - (max_length - lengths).unsqueeze(1)
    key_position_ids = torch.where(offsets < 0, torch.ones_like(offsets), offsets)

    k_embed = _rotate(k, key_position_ids)
    return q_embed, k_embed