Support dynamic NTK in transformers

pull/470/head
YWMditto 2023-11-03 16:41:55 +08:00
parent 139b754f29
commit c196825551
1 changed file with 66 additions and 46 deletions


@@ -19,26 +19,36 @@
# limitations under the License.
""" PyTorch InternLM model."""
import math
import queue
import threading
from typing import List, Optional, Tuple, Union
import threading, queue
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers.activations import ACT2FN
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
from transformers.modeling_utils import PreTrainedModel
from transformers.generation.streamers import BaseStreamer
from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
from .configuration_internlm import InternLMConfig
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
SequenceClassifierOutputWithPast,
)
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
replace_return_docstrings,
)
from .configuration_internlm import InternLMConfig
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "InternLMConfig"
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
def _make_causal_mask(
input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
@@ -73,6 +83,8 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
class InternLMRMSNorm(nn.Module):
"""RMSNorm implemention."""
def __init__(self, hidden_size, eps=1e-6):
"""
InternLMRMSNorm is equivalent to T5LayerNorm
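(Aside, not part of this commit: the RMSNorm referenced above normalizes by the root mean square of the hidden dimension only, with no mean subtraction. A minimal reference sketch:)

# Reference sketch only, not from the diff: the RMSNorm computation in brief.
import torch

def rms_norm(hidden_states: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Scale by the reciprocal root-mean-square of the last dimension, then apply the learned weight.
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
    return weight * hidden_states * torch.rsqrt(variance + eps)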
@@ -134,6 +146,7 @@ class InternLMDynamicNTKScalingRotaryEmbedding(torch.nn.Module):
Args:
InternLMRotaryEmbedding (nn.Module): the base rotary embedding that this class extends with dynamic NTK scaling.
"""
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
super().__init__()
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
@@ -167,7 +180,6 @@ class InternLMDynamicNTKScalingRotaryEmbedding(torch.nn.Module):
self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
def forward(self, x, seq_len=None):
# x: [bs, num_attention_heads, seq_len, head_size]
# This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
@@ -183,6 +195,7 @@ class InternLMDynamicNTKScalingRotaryEmbedding(torch.nn.Module):
self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
)
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
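For context (not part of this diff): dynamic NTK scaling extends the usable context length by enlarging the rotary base once the current sequence exceeds `max_position_embeddings`. Below is a minimal sketch of the idea, following the formula used by the Llama dynamic-NTK rotary embedding in transformers; the exact update inside the class above may differ in detail.

# Illustrative sketch only (assumptions noted above), not part of the commit.
import torch

def dynamic_ntk_inv_freq(dim: int, seq_len: int, max_position_embeddings: int = 2048,
                         base: float = 10000.0, scaling_factor: float = 1.0) -> torch.Tensor:
    # Beyond the trained context, grow the base so the rotary frequencies stretch to cover seq_len.
    if seq_len > max_position_embeddings:
        base = base * (
            (scaling_factor * seq_len / max_position_embeddings) - (scaling_factor - 1)
        ) ** (dim / (dim - 2))
    # Same inverse-frequency construction as in __init__ above.
    return 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))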
@@ -260,11 +273,10 @@ class InternLMAttention(nn.Module):
self.head_dim,
max_position_embeddings=self.max_position_embeddings,
base=self.config.rotary["base"],
scaling_factor=self.config.rotary.get("scaling_factor", 1.0)
scaling_factor=self.config.rotary.get("scaling_factor", 1.0),
)
else:
raise ValueError("Currently we only support rotary embeddings of type 'origin' or 'dynamic'.")
return self.rotary_emb
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
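Usage note (not part of the diff): with this change, the attention layer picks the rotary variant from the model config's `rotary` dict. A hypothetical sketch of enabling the dynamic variant follows; the `"type"` key and the checkpoint name are assumptions, while `"base"` and `"scaling_factor"` are the keys read in the code above.

# Hypothetical usage sketch; config keys other than "base"/"scaling_factor" and the
# checkpoint name are assumptions, not confirmed by this diff.
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("internlm/internlm-chat-7b", trust_remote_code=True)
config.rotary = {"type": "dynamic", "base": 10000, "scaling_factor": 1.0}
model = AutoModelForCausalLM.from_pretrained(
    "internlm/internlm-chat-7b", config=config, trust_remote_code=True
)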
@@ -470,7 +482,8 @@ INTERNLM_INPUTS_DOCSTRING = r"""
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
config.n_positions - 1]`.
[What are position IDs?](../glossary#position-ids)
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or
when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)` and 2 additional tensors of shape
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
@@ -507,6 +520,7 @@ class InternLMModel(InternLMPreTrainedModel):
Args:
config: InternLMConfig
"""
_auto_class = "AutoModel"
def __init__(self, config: InternLMConfig):
@@ -838,7 +852,8 @@ class InternLMForCausalLM(InternLMPreTrainedModel):
return tokenizer([prompt], return_tensors="pt")
@torch.no_grad()
def chat(self,
def chat(
self,
tokenizer,
query: str,
history: List[Tuple[str, str]] = [],
@@ -847,24 +862,28 @@ class InternLMForCausalLM(InternLMPreTrainedModel):
do_sample: bool = True,
temperature: float = 0.8,
top_p: float = 0.8,
**kwargs):
**kwargs,
):
inputs = self.build_inputs(tokenizer, query, history)
inputs = {k: v.to(self.device) for k, v in inputs.items() if torch.is_tensor(v)}
outputs = self.generate(**inputs,
outputs = self.generate(
**inputs,
streamer=streamer,
max_new_tokens=max_new_tokens,
do_sample=do_sample,
temperature=temperature,
top_p=top_p,
**kwargs)
outputs = outputs[0].cpu().tolist()[len(inputs["input_ids"][0]):]
**kwargs,
)
outputs = outputs[0].cpu().tolist()[len(inputs["input_ids"][0]) :]
response = tokenizer.decode(outputs, skip_special_tokens=True)
response = response.split("<eoa>")[0]
history = history + [(query, response)]
return response, history
@torch.no_grad()
def stream_chat(self,
def stream_chat(
self,
tokenizer,
query: str,
history: List[Tuple[str, str]] = [],
@@ -872,7 +891,8 @@ class InternLMForCausalLM(InternLMPreTrainedModel):
do_sample: bool = True,
temperature: float = 0.8,
top_p: float = 0.8,
**kwargs):
**kwargs,
):
"""
Return a generator in the format (response, history), e.g.:
@@ -923,7 +943,7 @@ class InternLMForCausalLM(InternLMPreTrainedModel):
do_sample=do_sample,
temperature=temperature,
top_p=top_p,
**kwargs
**kwargs,
)
def consumer():
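For reference (not part of the diff): the visible changes to `chat`/`stream_chat` are formatting only, so they are called as before. A minimal usage sketch, assuming an InternLM chat checkpoint loaded with `trust_remote_code=True`; the checkpoint name is an assumption.

# Minimal usage sketch; the checkpoint name is an assumption.
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("internlm/internlm-chat-7b", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained("internlm/internlm-chat-7b", trust_remote_code=True).eval()

# Blocking call: returns the full response and the updated history.
response, history = model.chat(tokenizer, "Hello!", history=[])

# Streaming call: yields (partial_response, history) tuples as generation proceeds.
for partial_response, history in model.stream_chat(tokenizer, "Hello!", history=history):
    print(partial_response)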