mirror of https://github.com/InternLM/InternLM
update hf model: add rope config and add qkv
parent 81ffb3d824
commit d9262da635
configuration_internlm.py

@@ -1,5 +1,5 @@
# coding=utf-8
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
# Copyright (c) InternLM. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
@@ -49,6 +49,14 @@ class InternLMConfig(PretrainedConfig):
Number of hidden layers in the Transformer encoder.
num_attention_heads (`int`, *optional*, defaults to 32):
Number of attention heads for each attention layer in the Transformer encoder.
num_key_value_heads (`int`, *optional*):
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
`num_key_value_heads=1` the model will use Multi Query Attention (MQA), otherwise GQA is used. When
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
by meanpooling all the original heads within that group. For more details, check out [this
paper](https://arxiv.org/pdf/2305.13245.pdf). If not specified, it will default to
`num_attention_heads`.
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
The non-linear activation function (function or string) in the decoder.
max_position_embeddings (`int`, *optional*, defaults to 2048):
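Illustrative sketch (not part of the commit): how the new `num_key_value_heads` argument selects the attention variant described above. The import path is an assumption based on the module name used later in this diff.

    from configuration_internlm import InternLMConfig  # assumed import path

    mha = InternLMConfig(num_attention_heads=32, num_key_value_heads=32)  # Multi Head Attention
    gqa = InternLMConfig(num_attention_heads=32, num_key_value_heads=8)   # Grouped Query Attention, 4 query heads per kv head
    mqa = InternLMConfig(num_attention_heads=32, num_key_value_heads=1)   # Multi Query Attention
    default = InternLMConfig(num_attention_heads=32)                      # num_key_value_heads defaults to num_attention_heads (MHA)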
@@ -87,6 +95,7 @@ class InternLMConfig(PretrainedConfig):
intermediate_size=11008,
num_hidden_layers=32,
num_attention_heads=32,
num_key_value_heads=None,
hidden_act="silu",
max_position_embeddings=2048,
initializer_range=0.02,
@@ -97,7 +106,8 @@ class InternLMConfig(PretrainedConfig):
eos_token_id=2,
tie_word_embeddings=False,
bias=True,
rotary={"base": 10000, "type": "dynamic"}, # pylint: disable=W0102
rope_theta=10000,
rope_scaling=None,
**kwargs,
):
self.vocab_size = vocab_size
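For orientation (not part of the diff): the old `rotary` dict is replaced by two keyword arguments, `rope_theta` for the base and `rope_scaling` for the optional scaling spec. A hypothetical instantiation might look like this.

    from configuration_internlm import InternLMConfig  # assumed import path

    # Previously: InternLMConfig(rotary={"base": 10000, "type": "dynamic"})
    cfg = InternLMConfig(
        rope_theta=10000,                                 # RoPE base
        rope_scaling={"type": "dynamic", "factor": 2.0},  # checked by _rope_scaling_validation (see below)
    )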
@@ -106,12 +116,19 @@ class InternLMConfig(PretrainedConfig):
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.bias = bias

if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads

self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.bias = bias
self.rotary = rotary
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self._rope_scaling_validation()
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
@@ -119,3 +136,24 @@ class InternLMConfig(PretrainedConfig):
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)

def _rope_scaling_validation(self):
"""
Validate the `rope_scaling` configuration.
"""
if self.rope_scaling is None:
return

if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
raise ValueError(
"`rope_scaling` must be a dictionary with two fields, `type` and `factor`, "
f"got {self.rope_scaling}"
)
rope_scaling_type = self.rope_scaling.get("type", None)
rope_scaling_factor = self.rope_scaling.get("factor", None)
if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
raise ValueError(
f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
)
if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor < 1.0:
raise ValueError(f"`rope_scaling`'s factor field must be a float >= 1, got {rope_scaling_factor}")
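A few hypothetical cases showing how `_rope_scaling_validation` treats its input (sketch only; the values are made up):

    from configuration_internlm import InternLMConfig  # assumed import path

    InternLMConfig(rope_scaling=None)                                # accepted: validation returns early
    InternLMConfig(rope_scaling={"type": "dynamic", "factor": 2.0})  # accepted
    # Each of these would raise ValueError:
    #   rope_scaling={"type": "ntk", "factor": 2.0}      -> type must be "linear" or "dynamic"
    #   rope_scaling={"type": "dynamic", "factor": 0.5}  -> factor must be a float >= 1
    #   rope_scaling={"factor": 2.0}                     -> dict must hold exactly the two fields type and factor

Note that the config accepts `"linear"` as well, while the modeling code further down only wires up the `"dynamic"` branch.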
modeling_internlm.py

@@ -1,5 +1,5 @@
# coding=utf-8
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
# Copyright (c) InternLM. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
@@ -28,7 +28,6 @@ import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers.activations import ACT2FN
from transformers.generation.streamers import BaseStreamer
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
@@ -42,6 +41,11 @@ from transformers.utils import (
replace_return_docstrings,
)

try:
from transformers.generation.streamers import BaseStreamer
except: # noqa # pylint: disable=bare-except
BaseStreamer = None

from .configuration_internlm import InternLMConfig

logger = logging.get_logger(__name__)
@@ -82,6 +86,17 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
(batch, num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
"""
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


class InternLMRMSNorm(nn.Module):
"""RMSNorm implementation."""
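A quick shape check (illustrative only) for the `repeat_kv` helper added above, assuming the function as defined in that hunk is in scope:

    import torch

    kv = torch.randn(2, 8, 16, 128)       # (batch, num_key_value_heads, seqlen, head_dim)
    out = repeat_kv(kv, n_rep=4)          # each kv head is repeated for its group of 4 query heads
    assert out.shape == (2, 32, 16, 128)  # (batch, num_attention_heads = 8 * 4, seqlen, head_dim)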
@@ -113,6 +128,7 @@ class InternLMRotaryEmbedding(torch.nn.Module):
base (int, optional): Base used to compute the rotary position encoding angles. Defaults to 10000.
device (Any, optional): Running device. Defaults to None.
"""

def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__()
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
@@ -158,7 +174,7 @@ class InternLMDynamicNTKScalingRotaryEmbedding(torch.nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
super().__init__()
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
self.register_buffer("inv_freq", inv_freq)
self.register_buffer("inv_freq", inv_freq, persistent=False)
self.dim = dim
self.base = base
self.scaling_factor = scaling_factor
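The rest of `InternLMDynamicNTKScalingRotaryEmbedding` falls outside this hunk. As a hedged reference, the dynamic-NTK rule used by the Hugging Face style implementations this file follows rescales the base once the sequence outgrows `max_position_embeddings`, roughly as sketched below (an assumption about the truncated class body, not code from the commit):

    import torch

    def dynamic_ntk_inv_freq(dim, seq_len, max_position_embeddings=2048, base=10000, scaling_factor=1.0):
        # Grow the RoPE base with sequence length so the rotation frequencies stay in range (sketch).
        if seq_len > max_position_embeddings:
            base = base * (
                (scaling_factor * seq_len / max_position_embeddings) - (scaling_factor - 1)
            ) ** (dim / (dim - 2))
        return 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))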
@@ -256,6 +272,8 @@ class InternLMAttention(nn.Module):
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.hidden_size // self.num_heads
self.num_key_value_heads = config.num_key_value_heads
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.max_position_embeddings = config.max_position_embeddings

if (self.head_dim * self.num_heads) != self.hidden_size:
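Worked numbers for the bookkeeping added here (hypothetical values, not the config defaults):

    hidden_size = 4096                 # hypothetical model width
    num_heads = 32                     # config.num_attention_heads
    num_key_value_heads = 8            # hypothetical GQA setting
    head_dim = hidden_size // num_heads                        # 128
    num_key_value_groups = num_heads // num_key_value_heads    # 4 query heads share each kv head
    q_features = num_heads * head_dim                          # 4096, q_proj output size
    kv_features = num_key_value_heads * head_dim               # 1024, k_proj / v_proj output size (next hunk)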
@@ -264,27 +282,30 @@ class InternLMAttention(nn.Module):
f" and `num_heads`: {self.num_heads})."
)
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.bias)
self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.bias)
self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.bias)
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.bias)
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.bias)
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.bias)
self.rotary_emb = self._init_rope()

def _init_rope(self):
if self.config.rotary["type"] == "origin":
if self.config.rope_scaling is None:
self.rotary_emb = InternLMRotaryEmbedding(
self.head_dim,
max_position_embeddings=self.max_position_embeddings,
base=self.config.rotary["base"],
)
elif self.config.rotary["type"] == "dynamic":
self.rotary_emb = InternLMDynamicNTKScalingRotaryEmbedding(
self.head_dim,
max_position_embeddings=self.max_position_embeddings,
base=self.config.rotary["base"],
scaling_factor=self.config.rotary.get("scaling_factor", 1.0),
base=self.config.rope_theta,
)
else:
raise ValueError("Currently we only support rotary embedding's type being one of ('origin', 'dynamic').")
scaling_type = self.config.rope_scaling["type"]
scaling_factor = self.config.rope_scaling["factor"]
if scaling_type == "dynamic":
self.rotary_emb = InternLMDynamicNTKScalingRotaryEmbedding(
self.head_dim,
max_position_embeddings=self.max_position_embeddings,
base=self.config.rope_theta,
scaling_factor=scaling_factor
)
else:
raise ValueError("Currently we only support rotary embedding's type being 'dynamic'.")
return self.rotary_emb

def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
@@ -302,21 +323,27 @@ class InternLMAttention(nn.Module):
bsz, q_len, _ = hidden_states.size()

query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = (
self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
)
value_states = (
self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
)

if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)

# print(use_cache)
past_key_value = (key_states, value_states) if use_cache else None

kv_seq_len = key_states.shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)

attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
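Because the cache tuple is built from the un-repeated key/value states, grouping shrinks the KV cache. A rough per-layer estimate with hypothetical numbers:

    batch, seq_len, head_dim = 1, 4096, 128
    num_key_value_heads, num_attention_heads = 8, 32
    bytes_per_elem = 2                 # fp16

    gqa_cache = 2 * batch * num_key_value_heads * seq_len * head_dim * bytes_per_elem   # keys + values
    mha_cache = 2 * batch * num_attention_heads * seq_len * head_dim * bytes_per_elem
    print(gqa_cache / 2**20, mha_cache / 2**20)   # ~16 MiB vs ~64 MiB per layer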
@@ -907,6 +934,11 @@ class InternLMForCausalLM(InternLMPreTrainedModel):
('你好,有什么可以帮助您的吗', [('你好', '你好,有什么可以帮助您的吗')])
('你好,有什么可以帮助您的吗?', [('你好', '你好,有什么可以帮助您的吗?')])
"""
if BaseStreamer is None:
raise ModuleNotFoundError(
"The version of `transformers` is too low. Please make sure "
"that you have installed `transformers>=4.28.0`."
)

response_queue = queue.Queue(maxsize=20)
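A hedged usage sketch for the guarded streaming path; only the `transformers>=4.28.0` requirement and the yielded `(response, history)` pairs come from the diff, while the model id and call pattern are assumptions:

    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("internlm/internlm-chat-7b", trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained("internlm/internlm-chat-7b", trust_remote_code=True).eval()

    for response, history in model.stream_chat(tokenizer, "你好"):
        print(response)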