mirror of https://github.com/InternLM/InternLM
update hf model: add rope config and add qkv
parent 81ffb3d824
commit d9262da635
@@ -1,5 +1,5 @@
 # coding=utf-8
-# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+# Copyright (c) InternLM. All rights reserved.
 #
 # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
 # and OPT implementations in this library. It has been modified from its
@@ -49,6 +49,14 @@ class InternLMConfig(PretrainedConfig):
             Number of hidden layers in the Transformer encoder.
         num_attention_heads (`int`, *optional*, defaults to 32):
             Number of attention heads for each attention layer in the Transformer encoder.
+        num_key_value_heads (`int`, *optional*):
+            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+            `num_key_value_heads=1` the model will use Multi Query Attention (MQA), otherwise GQA is used. When
+            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+            by meanpooling all the original heads within that group. For more details check out [this
+            paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
+            `num_attention_heads`.
         hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
             The non-linear activation function (function or string) in the decoder.
         max_position_embeddings (`int`, *optional*, defaults to 2048):
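
A minimal sketch of the mean-pooling conversion the new docstring refers to, for turning an MHA checkpoint's k_proj/v_proj weight into a GQA one; the helper name and shapes below are illustrative only, not part of this commit:

import torch

def meanpool_kv_weight(weight: torch.Tensor, num_heads: int, num_kv_heads: int, head_dim: int) -> torch.Tensor:
    # weight: (num_heads * head_dim, hidden_size) projection weight from an MHA checkpoint
    hidden_size = weight.shape[1]
    group = num_heads // num_kv_heads
    # split into (num_kv_heads, group, head_dim, hidden_size) and average each group of heads
    pooled = weight.view(num_kv_heads, group, head_dim, hidden_size).mean(dim=1)
    return pooled.reshape(num_kv_heads * head_dim, hidden_size)
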
@@ -87,6 +95,7 @@ class InternLMConfig(PretrainedConfig):
         intermediate_size=11008,
         num_hidden_layers=32,
         num_attention_heads=32,
+        num_key_value_heads=None,
         hidden_act="silu",
         max_position_embeddings=2048,
         initializer_range=0.02,
@@ -97,7 +106,8 @@ class InternLMConfig(PretrainedConfig):
         eos_token_id=2,
         tie_word_embeddings=False,
         bias=True,
-        rotary={"base": 10000, "type": "dynamic"},  # pylint: disable=W0102
+        rope_theta=10000,
+        rope_scaling=None,
         **kwargs,
     ):
         self.vocab_size = vocab_size
@@ -106,12 +116,19 @@ class InternLMConfig(PretrainedConfig):
         self.intermediate_size = intermediate_size
         self.num_hidden_layers = num_hidden_layers
         self.num_attention_heads = num_attention_heads
+        self.bias = bias
+
+        if num_key_value_heads is None:
+            num_key_value_heads = num_attention_heads
+        self.num_key_value_heads = num_key_value_heads
+
         self.hidden_act = hidden_act
         self.initializer_range = initializer_range
         self.rms_norm_eps = rms_norm_eps
         self.use_cache = use_cache
-        self.bias = bias
-        self.rotary = rotary
+        self.rope_theta = rope_theta
+        self.rope_scaling = rope_scaling
+        self._rope_scaling_validation()
         super().__init__(
             pad_token_id=pad_token_id,
             bos_token_id=bos_token_id,
@@ -119,3 +136,24 @@ class InternLMConfig(PretrainedConfig):
             tie_word_embeddings=tie_word_embeddings,
             **kwargs,
         )
+
+    def _rope_scaling_validation(self):
+        """
+        Validate the `rope_scaling` configuration.
+        """
+        if self.rope_scaling is None:
+            return
+
+        if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
+            raise ValueError(
+                "`rope_scaling` must be a dictionary with two fields, `type` and `factor`, "
+                f"got {self.rope_scaling}"
+            )
+        rope_scaling_type = self.rope_scaling.get("type", None)
+        rope_scaling_factor = self.rope_scaling.get("factor", None)
+        if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
+            raise ValueError(
+                f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
+            )
+        if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor < 1.0:
+            raise ValueError(f"`rope_scaling`'s factor field must be a float >= 1, got {rope_scaling_factor}")
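
For reference, a hedged example of how the new config fields fit together after this change. The values and the import path are illustrative assumptions (in practice the file is usually loaded as remote code rather than imported directly):

from configuration_internlm import InternLMConfig  # assumed local import of the file changed above

config = InternLMConfig(
    num_attention_heads=32,
    num_key_value_heads=8,  # 32 query heads share 8 KV heads -> GQA with groups of 4
    rope_theta=10000,
    rope_scaling={"type": "dynamic", "factor": 2.0},  # must satisfy _rope_scaling_validation()
)
print(config.num_key_value_heads, config.rope_scaling)
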
@@ -1,5 +1,5 @@
 # coding=utf-8
-# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+# Copyright (c) InternLM. All rights reserved.
 #
 # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
 # and OPT implementations in this library. It has been modified from its
@@ -28,7 +28,6 @@ import torch.utils.checkpoint
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from transformers.activations import ACT2FN
-from transformers.generation.streamers import BaseStreamer
 from transformers.modeling_outputs import (
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,
@@ -42,6 +41,11 @@ from transformers.utils import (
     replace_return_docstrings,
 )
 
+try:
+    from transformers.generation.streamers import BaseStreamer
+except:  # noqa # pylint: disable=bare-except
+    BaseStreamer = None
+
 from .configuration_internlm import InternLMConfig
 
 logger = logging.get_logger(__name__)
@@ -82,6 +86,17 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
     return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
 
 
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+    """
+    (batch, num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+    """
+    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+    if n_rep == 1:
+        return hidden_states
+    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
 class InternLMRMSNorm(nn.Module):
     """RMSNorm implemention."""
 
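
A quick shape check of the new repeat_kv helper (the sizes below are arbitrary, assuming the function from the hunk above is in scope):

import torch

kv = torch.randn(1, 2, 16, 64)   # (batch, num_key_value_heads, seqlen, head_dim)
out = repeat_kv(kv, n_rep=4)     # each KV head is repeated for its 4 query heads
print(out.shape)                 # torch.Size([1, 8, 16, 64])
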
@@ -113,6 +128,7 @@ class InternLMRotaryEmbedding(torch.nn.Module):
         base (int, optional): The rotation position encodes the rotation Angle base number. Defaults to 10000.
         device (Any, optional): Running device. Defaults to None.
     """
 
     def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
         super().__init__()
         inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
@@ -158,7 +174,7 @@ class InternLMDynamicNTKScalingRotaryEmbedding(torch.nn.Module):
     def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
         super().__init__()
         inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
-        self.register_buffer("inv_freq", inv_freq)
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
         self.dim = dim
         self.base = base
         self.scaling_factor = scaling_factor
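
The persistent=False change keeps the recomputable inv_freq buffer out of saved checkpoints; a small standalone demonstration (the module name and sizes are made up, not from this diff):

import torch
from torch import nn

class TinyRope(nn.Module):
    def __init__(self, dim=8, base=10000):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)  # excluded from state_dict

print("inv_freq" in TinyRope().state_dict())  # False: the buffer is no longer serialized
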
@@ -256,6 +272,8 @@ class InternLMAttention(nn.Module):
         self.hidden_size = config.hidden_size
         self.num_heads = config.num_attention_heads
         self.head_dim = self.hidden_size // self.num_heads
+        self.num_key_value_heads = config.num_key_value_heads
+        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
         self.max_position_embeddings = config.max_position_embeddings
 
         if (self.head_dim * self.num_heads) != self.hidden_size:
@@ -264,27 +282,30 @@ class InternLMAttention(nn.Module):
                 f" and `num_heads`: {self.num_heads})."
             )
         self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.bias)
-        self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.bias)
-        self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.bias)
+        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.bias)
+        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.bias)
         self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.bias)
         self.rotary_emb = self._init_rope()
 
     def _init_rope(self):
-        if self.config.rotary["type"] == "origin":
+        if self.config.rope_scaling is None:
             self.rotary_emb = InternLMRotaryEmbedding(
                 self.head_dim,
                 max_position_embeddings=self.max_position_embeddings,
-                base=self.config.rotary["base"],
-            )
-        elif self.config.rotary["type"] == "dynamic":
-            self.rotary_emb = InternLMDynamicNTKScalingRotaryEmbedding(
-                self.head_dim,
-                max_position_embeddings=self.max_position_embeddings,
-                base=self.config.rotary["base"],
-                scaling_factor=self.config.rotary.get("scaling_factor", 1.0),
+                base=self.config.rope_theta,
             )
         else:
-            raise ValueError("Currently we only support rotary embedding's type being one of ('origin', 'dynamic').")
+            scaling_type = self.config.rope_scaling["type"]
+            scaling_factor = self.config.rope_scaling["factor"]
+            if scaling_type == "dynamic":
+                self.rotary_emb = InternLMDynamicNTKScalingRotaryEmbedding(
+                    self.head_dim,
+                    max_position_embeddings=self.max_position_embeddings,
+                    base=self.config.rope_theta,
+                    scaling_factor=scaling_factor
+                )
+            else:
+                raise ValueError("Currently we only support rotary embedding's type being 'dynamic'.")
         return self.rotary_emb
 
     def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
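
For context on the "dynamic" branch above: dynamic NTK scaling rescales the RoPE base once the sequence grows past the training length. A rough sketch of the usual HF-style formula follows; treat it as an approximation of what the scaling class does, not a copy of this commit's code:

def dynamic_ntk_base(base, dim, seq_len, max_position_embeddings, scaling_factor):
    # beyond the training window, grow the base so rotation frequencies stretch smoothly
    if seq_len <= max_position_embeddings:
        return base
    return base * (
        (scaling_factor * seq_len / max_position_embeddings) - (scaling_factor - 1)
    ) ** (dim / (dim - 2))
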
@@ -302,21 +323,27 @@ class InternLMAttention(nn.Module):
         bsz, q_len, _ = hidden_states.size()
 
         query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = (
+            self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        )
+        value_states = (
+            self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        )
 
         if past_key_value is not None:
             # reuse k, v, self_attention
             key_states = torch.cat([past_key_value[0], key_states], dim=2)
             value_states = torch.cat([past_key_value[1], value_states], dim=2)
 
-        # print(use_cache)
         past_key_value = (key_states, value_states) if use_cache else None
 
         kv_seq_len = key_states.shape[-2]
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
 
+        key_states = repeat_kv(key_states, self.num_key_value_groups)
+        value_states = repeat_kv(value_states, self.num_key_value_groups)
+
         attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
 
         if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
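
A small sketch of the cache concatenation used in the forward pass above during incremental decoding (the tensors are stand-ins for the projected key states, with assumed sizes):

import torch

past_k = torch.randn(1, 8, 10, 128)    # 10 cached positions, 8 KV heads, head_dim 128
new_k = torch.randn(1, 8, 1, 128)      # one freshly projected position
k = torch.cat([past_k, new_k], dim=2)  # kv_seq_len grows from 10 to 11
print(k.shape)                         # torch.Size([1, 8, 11, 128])
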
@@ -907,6 +934,11 @@ class InternLMForCausalLM(InternLMPreTrainedModel):
         ('你好，有什么可以帮助您的吗', [('你好', '你好，有什么可以帮助您的吗')])
         ('你好，有什么可以帮助您的吗？', [('你好', '你好，有什么可以帮助您的吗？')])
         """
+        if BaseStreamer is None:
+            raise ModuleNotFoundError(
+                "The version of `transformers` is too low. Please make sure "
+                "that you have installed `transformers>=4.28.0`."
+            )
 
         response_queue = queue.Queue(maxsize=20)
 