mirror of https://github.com/InternLM/InternLM
support dynamic ntk in transformers
parent
b9c813a972
commit
139b754f29
|
@ -124,6 +124,65 @@ class InternLMRotaryEmbedding(torch.nn.Module):
|
|||
)
|
||||
|
||||
|
||||
class InternLMDynamicNTKScalingRotaryEmbedding(torch.nn.Module):
|
||||
"""实现dynamic ntk rope;
|
||||
|
||||
需要保证:
|
||||
1. 长度小于 seq len 时能够不断地复用;
|
||||
2. 长度超过 seq len 时,每一个 新的token,都需要一个新的base;
|
||||
|
||||
Args:
|
||||
InternLMRotaryEmbedding (_type_): _description_
|
||||
"""
|
||||
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
|
||||
super().__init__()
|
||||
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
|
||||
self.register_buffer("inv_freq", inv_freq)
|
||||
self.dim = dim
|
||||
self.base = base
|
||||
self.scaling_factor = scaling_factor
|
||||
|
||||
# Build here to make `torch.jit.trace` work.
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.max_seq_len_cached = max_position_embeddings
|
||||
t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
|
||||
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
||||
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
|
||||
self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
|
||||
|
||||
def _update_cached(self, x, seq_len=None):
|
||||
self.max_seq_len_cached = max(seq_len, self.max_position_embeddings)
|
||||
if seq_len > self.max_position_embeddings:
|
||||
base = self.base * (
|
||||
(self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
|
||||
) ** (self.dim / (self.dim - 2))
|
||||
inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(x.device) / self.dim))
|
||||
else:
|
||||
inv_freq = self.inv_freq
|
||||
t = torch.arange(self.max_seq_len_cached, device=inv_freq.device, dtype=inv_freq.dtype)
|
||||
freqs = torch.einsum("i,j->ij", t, inv_freq)
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
|
||||
self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
|
||||
|
||||
|
||||
def forward(self, x, seq_len=None):
|
||||
# x: [bs, num_attention_heads, seq_len, head_size]
|
||||
# This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
|
||||
if seq_len <= self.max_position_embeddings:
|
||||
# Reset the tables if the sequence length has changed,
|
||||
if self.max_seq_len_cached > self.max_position_embeddings:
|
||||
self._update_cached(x, seq_len)
|
||||
else:
|
||||
self._update_cached(x, seq_len)
|
||||
|
||||
return (
|
||||
self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
|
||||
self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
|
||||
)
|
||||
|
||||
def rotate_half(x):
|
||||
"""Rotates half the hidden dims of the input."""
|
||||
x1 = x[..., : x.shape[-1] // 2]
|
||||
|
@ -135,10 +194,18 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
|
|||
# The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
|
||||
cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
|
||||
sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
|
||||
cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
|
||||
sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
|
||||
q_embed = (q * cos) + (rotate_half(q) * sin)
|
||||
k_embed = (k * cos) + (rotate_half(k) * sin)
|
||||
cos = cos.unsqueeze(0).unsqueeze(0).expand(len(position_ids), -1, -1, -1)
|
||||
sin = sin.unsqueeze(0).unsqueeze(0).expand(len(position_ids), -1, -1, -1)
|
||||
if q.size(2) == 1:
|
||||
q_embed = (q * cos[:, :, -1, :]) + (rotate_half(q) * sin[:, :, -1, :])
|
||||
else:
|
||||
q_embed = (q * cos) + (rotate_half(q) * sin)
|
||||
|
||||
if k.size(2) == 1:
|
||||
k_embed = (k * cos[:, :, -1, :]) + (rotate_half(k) * sin[:, :, -1, :])
|
||||
else:
|
||||
k_embed = (k * cos) + (rotate_half(k) * sin)
|
||||
|
||||
return q_embed, k_embed
|
||||
|
||||
|
||||
|
@ -179,7 +246,26 @@ class InternLMAttention(nn.Module):
|
|||
self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.bias)
|
||||
self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.bias)
|
||||
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.bias)
|
||||
self.rotary_emb = InternLMRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)
|
||||
self.rotary_emb = self._init_rope()
|
||||
|
||||
def _init_rope(self):
|
||||
if self.config.rotary["type"] == "origin":
|
||||
self.rotary_emb = InternLMRotaryEmbedding(
|
||||
self.head_dim,
|
||||
max_position_embeddings=self.max_position_embeddings,
|
||||
base=self.config.rotary["base"],
|
||||
)
|
||||
elif self.config.rotary["type"] == "dynamic":
|
||||
self.rotary_emb = InternLMDynamicNTKScalingRotaryEmbedding(
|
||||
self.head_dim,
|
||||
max_position_embeddings=self.max_position_embeddings,
|
||||
base=self.config.rotary["base"],
|
||||
scaling_factor=self.config.rotary.get("scaling_factor", 1.0)
|
||||
)
|
||||
else:
|
||||
raise ValueError("Currently we only support rotary embedding's type being one of ('origin', 'dynamic').")
|
||||
|
||||
return self.rotary_emb
|
||||
|
||||
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
||||
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
|
||||
|
@ -199,20 +285,18 @@ class InternLMAttention(nn.Module):
|
|||
key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
if past_key_value is not None:
|
||||
kv_seq_len += past_key_value[0].shape[-2]
|
||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
||||
# [bsz, nh, t, hd]
|
||||
|
||||
if past_key_value is not None:
|
||||
# reuse k, v, self_attention
|
||||
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
||||
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
||||
|
||||
# print(use_cache)
|
||||
past_key_value = (key_states, value_states) if use_cache else None
|
||||
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
||||
|
||||
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
||||
|
||||
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
|
||||
|
|
Loading…
Reference in New Issue