feat(tools): support dynamic ntk rope in transformers (#470)

* support dynamic ntk in transformers

* support dynamic ntk in transformers

* support dynamic ntk in transformers

* add rope doc

* add rotary config in configuration_internlm.py

---------

Co-authored-by: YWMditto <862779238@qq.com>
YWMditto 2023-11-06 23:15:06 +08:00 committed by GitHub
parent 42ad9cc786
commit 095ebfff9d
2 changed files with 170 additions and 57 deletions

configuration_internlm.py

@@ -19,9 +19,8 @@
# limitations under the License.
""" InternLM model configuration"""
-from transformers.utils import logging
from transformers.configuration_utils import PretrainedConfig
+from transformers.utils import logging
logger = logging.get_logger(__name__)

@@ -30,9 +29,9 @@ INTERNLM_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
class InternLMConfig(PretrainedConfig):
r"""
-This is the configuration class to store the configuration of a [`InternLMModel`]. It is used to instantiate an InternLM
-model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
-defaults will yield a similar configuration to that of the InternLM-7B.
+This is the configuration class to store the configuration of a [`InternLMModel`]. It is used to instantiate
+an InternLM model according to the specified arguments, defining the model architecture. Instantiating a
+configuration with the defaults will yield a similar configuration to that of the InternLM-7B.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.

@@ -81,7 +80,7 @@ class InternLMConfig(PretrainedConfig):
model_type = "internlm"
_auto_class = "AutoConfig"
-def __init__(
+def __init__(  # pylint: disable=W0102
self,
vocab_size=103168,
hidden_size=4096,

@@ -98,6 +97,7 @@ class InternLMConfig(PretrainedConfig):
eos_token_id=2,
tie_word_embeddings=False,
bias=True,
+rotary={"base": 10000, "type": "dynamic"},  # pylint: disable=W0102
**kwargs,
):
self.vocab_size = vocab_size

@@ -111,6 +111,7 @@ class InternLMConfig(PretrainedConfig):
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.bias = bias
+self.rotary = rotary
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
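As a quick illustration of how the new `rotary` field could be driven from user code (a hedged sketch, not part of the patch; the checkpoint name is a placeholder and `trust_remote_code=True` is assumed because these classes ship with the checkpoint rather than with `transformers` itself):

```python
from transformers import AutoConfig, AutoModelForCausalLM

checkpoint = "internlm/internlm-chat-7b"  # placeholder; use whichever InternLM checkpoint bundles this code

config = AutoConfig.from_pretrained(checkpoint, trust_remote_code=True)
# Mirror the new default, or switch "type" to "origin" to keep plain RoPE.
config.rotary = {"base": 10000, "type": "dynamic", "scaling_factor": 1.0}

model = AutoModelForCausalLM.from_pretrained(checkpoint, config=config, trust_remote_code=True)
```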

modeling_internlm.py

@@ -19,26 +19,36 @@
# limitations under the License.
""" PyTorch InternLM model."""
import math
+import queue
+import threading
from typing import List, Optional, Tuple, Union
-import threading, queue
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers.activations import ACT2FN
-from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
-from transformers.modeling_utils import PreTrainedModel
from transformers.generation.streamers import BaseStreamer
-from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
-from .configuration_internlm import InternLMConfig
+from transformers.modeling_outputs import (
+BaseModelOutputWithPast,
+CausalLMOutputWithPast,
+SequenceClassifierOutputWithPast,
+)
+from transformers.modeling_utils import PreTrainedModel
+from transformers.utils import (
+add_start_docstrings,
+add_start_docstrings_to_model_forward,
+logging,
+replace_return_docstrings,
+)
+from .configuration_internlm import InternLMConfig
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "InternLMConfig"
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
def _make_causal_mask(
input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0

@@ -73,6 +83,8 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
class InternLMRMSNorm(nn.Module):
+"""RMSNorm implementation."""

def __init__(self, hidden_size, eps=1e-6):
"""
InternLMRMSNorm is equivalent to T5LayerNorm

@@ -93,6 +105,14 @@ class InternLMRMSNorm(nn.Module):
class InternLMRotaryEmbedding(torch.nn.Module):
+"""Implement InternLM's rotary embedding.
+Args:
+dim (int): Dimension of each self-attention head.
+max_position_embeddings (int, optional): Model's training length. Defaults to 2048.
+base (int, optional): Base used for the rotary position encoding's rotation angles. Defaults to 10000.
+device (Any, optional): Running device. Defaults to None.
+"""
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__()
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))

@@ -124,6 +144,66 @@ class InternLMRotaryEmbedding(torch.nn.Module):
)
class InternLMDynamicNTKScalingRotaryEmbedding(torch.nn.Module):
"""Implement InternLM's Dynamic NTK extrapolation method, which broadens the supported context of the model to 16K.
Args:
dim (int): Dimension of each self-attention head.
max_position_embeddings (int, optional): Model's training length. Defaults to 2048.
base (int, optional): Base used for the rotary position encoding's rotation angles. Defaults to 10000.
device (Any, optional): Running device. Defaults to None.
scaling_factor (float, optional): NTK extrapolation coefficient. Defaults to 1.0.
"""
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
super().__init__()
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
self.register_buffer("inv_freq", inv_freq)
self.dim = dim
self.base = base
self.scaling_factor = scaling_factor
# Build here to make `torch.jit.trace` work.
self.max_position_embeddings = max_position_embeddings
self.max_seq_len_cached = max_position_embeddings
t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
def _update_cached(self, x, seq_len=None):
self.max_seq_len_cached = max(seq_len, self.max_position_embeddings)
if seq_len > self.max_position_embeddings:
base = self.base * (
(self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
) ** (self.dim / (self.dim - 2))
inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(x.device) / self.dim))
else:
inv_freq = self.inv_freq
t = torch.arange(self.max_seq_len_cached, device=inv_freq.device, dtype=inv_freq.dtype)
freqs = torch.einsum("i,j->ij", t, inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
def forward(self, x, seq_len=None):
# x: [bs, num_attention_heads, seq_len, head_size]
# This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
if seq_len <= self.max_position_embeddings:
# Reset the tables if the sequence length has changed,
if self.max_seq_len_cached > self.max_position_embeddings:
self._update_cached(x, seq_len)
else:
self._update_cached(x, seq_len)
return (
self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
)
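To see what `_update_cached` does to the frequencies, here is a small standalone sketch of the NTK base rescaling applied once `seq_len` exceeds the training length (illustrative only; the function name is made up):

```python
import torch

def dynamic_ntk_inv_freq(dim, seq_len, max_position_embeddings=2048, base=10000, scaling_factor=1.0):
    """Recompute inverse frequencies with the NTK-scaled base, mirroring _update_cached above."""
    if seq_len > max_position_embeddings:
        base = base * (
            (scaling_factor * seq_len / max_position_embeddings) - (scaling_factor - 1)
        ) ** (dim / (dim - 2))
    return 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))

# Past the 2048-token training length the base grows, so the lowest frequency shrinks
# and each rotation period stretches to cover the longer context.
print(dynamic_ntk_inv_freq(128, 8192)[-1], dynamic_ntk_inv_freq(128, 2048)[-1])
```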
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]

@@ -135,10 +215,18 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
# The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
-cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
-sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
+cos = cos.unsqueeze(0).unsqueeze(0).expand(len(position_ids), -1, -1, -1)
+sin = sin.unsqueeze(0).unsqueeze(0).expand(len(position_ids), -1, -1, -1)
+if q.size(2) == 1:
+q_embed = (q * cos[:, :, -1, :]) + (rotate_half(q) * sin[:, :, -1, :])
+else:
q_embed = (q * cos) + (rotate_half(q) * sin)
+if k.size(2) == 1:
+k_embed = (k * cos[:, :, -1, :]) + (rotate_half(k) * sin[:, :, -1, :])
+else:
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
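The reworked `apply_rotary_pos_emb` broadcasts full-length cos/sin tables and special-cases a single-token query (the usual shape during cached decoding, where only the last position's angles are needed). A self-contained toy check of the underlying rotation, assuming nothing beyond the helper reproduced here:

```python
import torch

def rotate_half(x):
    """Same helper as above: build (-x2, x1) from the two halves of the last dim."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

bs, n_heads, seq_len, head_dim = 1, 2, 5, 8  # toy sizes
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
freqs = torch.einsum("i,j->ij", torch.arange(seq_len).float(), inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)  # [seq_len, head_dim], same layout as the cached tables
cos, sin = emb.cos(), emb.sin()

q = torch.randn(bs, n_heads, seq_len, head_dim)
q_rot = q * cos + rotate_half(q) * sin  # broadcasts over batch and heads

# Each position is rotated pairwise, so per-position norms are preserved.
print(torch.allclose(q.norm(dim=-1), q_rot.norm(dim=-1), atol=1e-5))
```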

@@ -179,7 +267,25 @@ class InternLMAttention(nn.Module):
self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.bias)
self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.bias)
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.bias)
-self.rotary_emb = InternLMRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)
+self.rotary_emb = self._init_rope()
def _init_rope(self):
if self.config.rotary["type"] == "origin":
self.rotary_emb = InternLMRotaryEmbedding(
self.head_dim,
max_position_embeddings=self.max_position_embeddings,
base=self.config.rotary["base"],
)
elif self.config.rotary["type"] == "dynamic":
self.rotary_emb = InternLMDynamicNTKScalingRotaryEmbedding(
self.head_dim,
max_position_embeddings=self.max_position_embeddings,
base=self.config.rotary["base"],
scaling_factor=self.config.rotary.get("scaling_factor", 1.0),
)
else:
raise ValueError("Currently we only support rotary embedding's type being one of ('origin', 'dynamic').")
return self.rotary_emb
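A hedged way to confirm which embedding `_init_rope` picked after loading a checkpoint that bundles this file; the module path assumes the usual LLaMA-style layout (`model.layers[i].self_attn`) used here, and the checkpoint name is a placeholder:

```python
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("internlm/internlm-chat-7b", trust_remote_code=True)

rope = model.model.layers[0].self_attn.rotary_emb
print(type(rope).__name__)  # expect "InternLMDynamicNTKScalingRotaryEmbedding" with the new "dynamic" default
```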
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

@@ -199,20 +305,18 @@ class InternLMAttention(nn.Module):
key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-kv_seq_len = key_states.shape[-2]
-if past_key_value is not None:
-kv_seq_len += past_key_value[0].shape[-2]
-cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
-# [bsz, nh, t, hd]
if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
-# print(use_cache)
past_key_value = (key_states, value_states) if use_cache else None
+kv_seq_len = key_states.shape[-2]
+cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
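Note the reordering in this hunk: the rotary embedding is now applied after the cached keys are concatenated, so the cache appears to hold un-rotated keys and the angles are recomputed over the full `kv_seq_len` each step, which is what keeps positions consistent as the dynamically scaled base changes with context length. A toy check of the regrowing cos/sin cache, assuming the `InternLMDynamicNTKScalingRotaryEmbedding` class above has been copied into scope:

```python
import torch

rope = InternLMDynamicNTKScalingRotaryEmbedding(dim=64, max_position_embeddings=2048)
fake_kv = torch.randn(1, 8, 4096, 64)  # pretend key/value states twice the training length

cos, sin = rope(fake_kv, seq_len=4096)
print(cos.shape, rope.max_seq_len_cached)  # torch.Size([1, 1, 4096, 64]) 4096
```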

@@ -386,7 +490,8 @@ INTERNLM_INPUTS_DOCSTRING = r"""
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
config.n_positions - 1]`.
[What are position IDs?](../glossary#position-ids)
-past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or
+when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.

@@ -423,6 +528,7 @@ class InternLMModel(InternLMPreTrainedModel):
Args:
config: InternLMConfig
"""

_auto_class = "AutoModel"

def __init__(self, config: InternLMConfig):

@@ -754,7 +860,8 @@ class InternLMForCausalLM(InternLMPreTrainedModel):
return tokenizer([prompt], return_tensors="pt")

@torch.no_grad()
-def chat(self,
+def chat(
+self,
tokenizer,
query: str,
history: List[Tuple[str, str]] = [],

@@ -763,16 +870,19 @@ class InternLMForCausalLM(InternLMPreTrainedModel):
do_sample: bool = True,
temperature: float = 0.8,
top_p: float = 0.8,
-**kwargs):
+**kwargs,
+):
inputs = self.build_inputs(tokenizer, query, history)
inputs = {k: v.to(self.device) for k, v in inputs.items() if torch.is_tensor(v)}
-outputs = self.generate(**inputs,
+outputs = self.generate(
+**inputs,
streamer=streamer,
max_new_tokens=max_new_tokens,
do_sample=do_sample,
temperature=temperature,
top_p=top_p,
-**kwargs)
+**kwargs,
+)
outputs = outputs[0].cpu().tolist()[len(inputs["input_ids"][0]) :]
response = tokenizer.decode(outputs, skip_special_tokens=True)
response = response.split("<eoa>")[0]
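For completeness, a usage sketch of the reformatted `chat` helper under the same placeholder-checkpoint assumption as above:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint = "internlm/internlm-chat-7b"  # placeholder
tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(checkpoint, trust_remote_code=True).eval()

response, history = model.chat(tokenizer, "Hello!", history=[], temperature=0.8, top_p=0.8)
print(response)
```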

@@ -780,7 +890,8 @@ class InternLMForCausalLM(InternLMPreTrainedModel):
return response, history

@torch.no_grad()
-def stream_chat(self,
+def stream_chat(
+self,
tokenizer,
query: str,
history: List[Tuple[str, str]] = [],

@@ -788,7 +899,8 @@ class InternLMForCausalLM(InternLMPreTrainedModel):
do_sample: bool = True,
temperature: float = 0.8,
top_p: float = 0.8,
-**kwargs):
+**kwargs,
+):
""" """
Return a generator in format: (response, history) Return a generator in format: (response, history)
Eg. Eg.

@@ -839,7 +951,7 @@ class InternLMForCausalLM(InternLMPreTrainedModel):
do_sample=do_sample,
temperature=temperature,
top_p=top_p,
-**kwargs
+**kwargs,
)

def consumer():