Support dynamic NTK in transformers

pull/470/head
YWMditto 2023-11-03 16:41:55 +08:00
parent 139b754f29
commit c196825551
1 changed file with 66 additions and 46 deletions


@@ -19,26 +19,36 @@
 # limitations under the License.
 """ PyTorch InternLM model."""
 import math
+import queue
+import threading
 from typing import List, Optional, Tuple, Union
-import threading, queue
 
 import torch
 import torch.utils.checkpoint
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from transformers.activations import ACT2FN
-from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
-from transformers.modeling_utils import PreTrainedModel
 from transformers.generation.streamers import BaseStreamer
-from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
-from .configuration_internlm import InternLMConfig
+from transformers.modeling_outputs import (
+    BaseModelOutputWithPast,
+    CausalLMOutputWithPast,
+    SequenceClassifierOutputWithPast,
+)
+from transformers.modeling_utils import PreTrainedModel
+from transformers.utils import (
+    add_start_docstrings,
+    add_start_docstrings_to_model_forward,
+    logging,
+    replace_return_docstrings,
+)
+from .configuration_internlm import InternLMConfig
 
 logger = logging.get_logger(__name__)
 
 _CONFIG_FOR_DOC = "InternLMConfig"
 
 
 # Copied from transformers.models.bart.modeling_bart._make_causal_mask
 def _make_causal_mask(
     input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
@@ -73,6 +83,8 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
 
 
 class InternLMRMSNorm(nn.Module):
+    """RMSNorm implemention."""
+
     def __init__(self, hidden_size, eps=1e-6):
         """
         InternLMRMSNorm is equivalent to T5LayerNorm
@@ -134,6 +146,7 @@ class InternLMDynamicNTKScalingRotaryEmbedding(torch.nn.Module):
     Args:
         InternLMRotaryEmbedding (_type_): _description_
     """
+
     def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
         super().__init__()
         inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
@@ -151,7 +164,7 @@ class InternLMDynamicNTKScalingRotaryEmbedding(torch.nn.Module):
         emb = torch.cat((freqs, freqs), dim=-1)
         self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
         self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
 
     def _update_cached(self, x, seq_len=None):
         self.max_seq_len_cached = max(seq_len, self.max_position_embeddings)
         if seq_len > self.max_position_embeddings:
@@ -166,7 +179,6 @@ class InternLMDynamicNTKScalingRotaryEmbedding(torch.nn.Module):
         emb = torch.cat((freqs, freqs), dim=-1)
         self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
         self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
 
-
     def forward(self, x, seq_len=None):
         # x: [bs, num_attention_heads, seq_len, head_size]
@@ -177,12 +189,13 @@ class InternLMDynamicNTKScalingRotaryEmbedding(torch.nn.Module):
                 self._update_cached(x, seq_len)
         else:
             self._update_cached(x, seq_len)
 
         return (
             self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
             self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
         )
 
+
 def rotate_half(x):
     """Rotates half the hidden dims of the input."""
     x1 = x[..., : x.shape[-1] // 2]
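
The hunks above only touch the cache bookkeeping; the base rescaling itself lives in `_update_cached`, outside the changed lines. For reference, a minimal standalone sketch of the dynamic NTK-aware scaling commonly used for rotary embeddings (the helper name `dynamic_ntk_inv_freq` is illustrative, not code from this commit):

import torch


def dynamic_ntk_inv_freq(dim, seq_len, max_position_embeddings=2048, base=10000.0, scaling_factor=1.0):
    # Once the sequence outgrows the pretraining window, enlarge the RoPE base so the
    # rotary frequencies are stretched to cover the longer context instead of extrapolated.
    if seq_len > max_position_embeddings:
        base = base * (
            (scaling_factor * seq_len / max_position_embeddings) - (scaling_factor - 1)
        ) ** (dim / (dim - 2))
    # Same inverse-frequency construction as in InternLMDynamicNTKScalingRotaryEmbedding.__init__.
    return 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))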
@@ -251,20 +264,19 @@ class InternLMAttention(nn.Module):
     def _init_rope(self):
         if self.config.rotary["type"] == "origin":
             self.rotary_emb = InternLMRotaryEmbedding(
                 self.head_dim,
                 max_position_embeddings=self.max_position_embeddings,
                 base=self.config.rotary["base"],
             )
         elif self.config.rotary["type"] == "dynamic":
             self.rotary_emb = InternLMDynamicNTKScalingRotaryEmbedding(
                 self.head_dim,
                 max_position_embeddings=self.max_position_embeddings,
                 base=self.config.rotary["base"],
-                scaling_factor=self.config.rotary.get("scaling_factor", 1.0)
+                scaling_factor=self.config.rotary.get("scaling_factor", 1.0),
             )
         else:
             raise ValueError("Currently we only support rotary embedding's type being one of ('origin', 'dynamic').")
         return self.rotary_emb
 
-
     def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
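
As `_init_rope` shows, the rotary variant is selected from a `rotary` dict on the model config, with `type` one of "origin" or "dynamic" and an optional `scaling_factor`. A hedged configuration sketch (the repo id and field values are placeholders, not defaults from this commit):

from transformers import AutoConfig

config = AutoConfig.from_pretrained("internlm/internlm-chat-7b", trust_remote_code=True)
# Keys read by _init_rope: "type", "base", and optionally "scaling_factor".
config.rotary = {"type": "dynamic", "base": 10000, "scaling_factor": 1.0}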
@@ -470,7 +482,8 @@ INTERNLM_INPUTS_DOCSTRING = r"""
             Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
             config.n_positions - 1]`.
             [What are position IDs?](../glossary#position-ids)
-        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
+        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or
+            when `config.use_cache=True`):
             Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
             `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
             `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
@@ -507,6 +520,7 @@ class InternLMModel(InternLMPreTrainedModel):
     Args:
         config: InternLMConfig
     """
+
     _auto_class = "AutoModel"
 
     def __init__(self, config: InternLMConfig):
@@ -829,50 +843,56 @@ class InternLMForCausalLM(InternLMPreTrainedModel):
         for layer_past in past_key_values:
             reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
         return reordered_past
 
     def build_inputs(self, tokenizer, query: str, history: List[Tuple[str, str]] = []):
         prompt = ""
         for record in history:
             prompt += f"""<|User|>:{record[0]}<eoh>\n<|Bot|>:{record[1]}<eoa>\n"""
         prompt += f"""<|User|>:{query}<eoh>\n<|Bot|>:"""
         return tokenizer([prompt], return_tensors="pt")
 
     @torch.no_grad()
-    def chat(self,
-             tokenizer,
-             query: str,
-             history: List[Tuple[str, str]] = [],
-             streamer: Optional[BaseStreamer] = None,
-             max_new_tokens: int = 1024,
-             do_sample: bool = True,
-             temperature: float = 0.8,
-             top_p: float = 0.8,
-             **kwargs):
+    def chat(
+        self,
+        tokenizer,
+        query: str,
+        history: List[Tuple[str, str]] = [],
+        streamer: Optional[BaseStreamer] = None,
+        max_new_tokens: int = 1024,
+        do_sample: bool = True,
+        temperature: float = 0.8,
+        top_p: float = 0.8,
+        **kwargs,
+    ):
         inputs = self.build_inputs(tokenizer, query, history)
         inputs = {k: v.to(self.device) for k, v in inputs.items() if torch.is_tensor(v)}
-        outputs = self.generate(**inputs,
-                                streamer=streamer,
-                                max_new_tokens=max_new_tokens,
-                                do_sample=do_sample,
-                                temperature=temperature,
-                                top_p=top_p,
-                                **kwargs)
-        outputs = outputs[0].cpu().tolist()[len(inputs["input_ids"][0]):]
+        outputs = self.generate(
+            **inputs,
+            streamer=streamer,
+            max_new_tokens=max_new_tokens,
+            do_sample=do_sample,
+            temperature=temperature,
+            top_p=top_p,
+            **kwargs,
+        )
+        outputs = outputs[0].cpu().tolist()[len(inputs["input_ids"][0]) :]
         response = tokenizer.decode(outputs, skip_special_tokens=True)
         response = response.split("<eoa>")[0]
         history = history + [(query, response)]
         return response, history
 
     @torch.no_grad()
-    def stream_chat(self,
-                    tokenizer,
-                    query: str,
-                    history: List[Tuple[str, str]] = [],
-                    max_new_tokens: int = 1024,
-                    do_sample: bool = True,
-                    temperature: float = 0.8,
-                    top_p: float = 0.8,
-                    **kwargs):
+    def stream_chat(
+        self,
+        tokenizer,
+        query: str,
+        history: List[Tuple[str, str]] = [],
+        max_new_tokens: int = 1024,
+        do_sample: bool = True,
+        temperature: float = 0.8,
+        top_p: float = 0.8,
+        **kwargs,
+    ):
         """
         Return a generator in format: (response, history)
         Eg.
@@ -918,12 +938,12 @@ class InternLMForCausalLM(InternLMPreTrainedModel):
                 tokenizer=tokenizer,
                 query=query,
                 streamer=ChatStreamer(tokenizer=tokenizer),
                 history=history,
                 max_new_tokens=max_new_tokens,
                 do_sample=do_sample,
                 temperature=temperature,
                 top_p=top_p,
-                **kwargs
+                **kwargs,
            )
 
        def consumer():
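
The reformatted `chat` signature is called the same way as before; a usage sketch assuming a remote-code checkpoint (the model id and generation settings are placeholders):

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("internlm/internlm-chat-7b", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained("internlm/internlm-chat-7b", trust_remote_code=True).eval()

# chat() builds the <|User|>/<|Bot|> prompt, generates, and returns (response, history).
response, history = model.chat(tokenizer, "Hello!", history=[], max_new_tokens=64, do_sample=False)
print(response)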