update hf model: add rope config and add qkv

pull/536/head
x54-729 2023-12-12 12:19:20 +08:00
parent 81ffb3d824
commit d9262da635
2 changed files with 92 additions and 22 deletions

View File

@ -1,5 +1,5 @@
# coding=utf-8
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
# Copyright (c) InternLM. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
@ -49,6 +49,14 @@ class InternLMConfig(PretrainedConfig):
Number of hidden layers in the Transformer encoder.
num_attention_heads (`int`, *optional*, defaults to 32):
Number of attention heads for each attention layer in the Transformer encoder.
num_key_value_heads (`int`, *optional*):
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
`num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
by meanpooling all the original heads within that group. For more details checkout [this
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
`num_attention_heads`.
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
The non-linear activation function (function or string) in the decoder.
max_position_embeddings (`int`, *optional*, defaults to 2048):
@ -87,6 +95,7 @@ class InternLMConfig(PretrainedConfig):
intermediate_size=11008,
num_hidden_layers=32,
num_attention_heads=32,
num_key_value_heads=None,
hidden_act="silu",
max_position_embeddings=2048,
initializer_range=0.02,
@ -97,7 +106,8 @@ class InternLMConfig(PretrainedConfig):
eos_token_id=2,
tie_word_embeddings=False,
bias=True,
rotary={"base": 10000, "type": "dynamic"}, # pylint: disable=W0102
rope_theta=10000,
rope_scaling=None,
**kwargs,
):
self.vocab_size = vocab_size
@ -106,12 +116,19 @@ class InternLMConfig(PretrainedConfig):
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.bias = bias
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.bias = bias
self.rotary = rotary
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self._rope_scaling_validation()
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
@ -119,3 +136,24 @@ class InternLMConfig(PretrainedConfig):
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
def _rope_scaling_validation(self):
"""
Validate the `rope_scaling` configuration.
"""
if self.rope_scaling is None:
return
if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
raise ValueError(
"`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, "
f"got {self.rope_scaling}"
)
rope_scaling_type = self.rope_scaling.get("type", None)
rope_scaling_factor = self.rope_scaling.get("factor", None)
if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
raise ValueError(
f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
)
if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor < 1.0:
raise ValueError(f"`rope_scaling`'s factor field must be a float >= 1, got {rope_scaling_factor}")

View File

@ -1,5 +1,5 @@
# coding=utf-8
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
# Copyright (c) InternLM. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
@ -28,7 +28,6 @@ import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers.activations import ACT2FN
from transformers.generation.streamers import BaseStreamer
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
@ -42,6 +41,11 @@ from transformers.utils import (
replace_return_docstrings,
)
try:
from transformers.generation.streamers import BaseStreamer
except: # noqa # pylint: disable=bare-except
BaseStreamer = None
from .configuration_internlm import InternLMConfig
logger = logging.get_logger(__name__)
@ -82,6 +86,17 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
(batch, num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
"""
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
class InternLMRMSNorm(nn.Module):
"""RMSNorm implemention."""
@ -113,6 +128,7 @@ class InternLMRotaryEmbedding(torch.nn.Module):
base (int, optional): The rotation position encodes the rotation Angle base number. Defaults to 10000.
device (Any, optional): Running device. Defaults to None.
"""
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__()
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
@ -158,7 +174,7 @@ class InternLMDynamicNTKScalingRotaryEmbedding(torch.nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
super().__init__()
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
self.register_buffer("inv_freq", inv_freq)
self.register_buffer("inv_freq", inv_freq, persistent=False)
self.dim = dim
self.base = base
self.scaling_factor = scaling_factor
@ -256,6 +272,8 @@ class InternLMAttention(nn.Module):
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.hidden_size // self.num_heads
self.num_key_value_heads = config.num_key_value_heads
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.max_position_embeddings = config.max_position_embeddings
if (self.head_dim * self.num_heads) != self.hidden_size:
@ -264,27 +282,30 @@ class InternLMAttention(nn.Module):
f" and `num_heads`: {self.num_heads})."
)
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.bias)
self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.bias)
self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.bias)
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.bias)
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.bias)
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.bias)
self.rotary_emb = self._init_rope()
def _init_rope(self):
if self.config.rotary["type"] == "origin":
if self.config.rope_scaling is None:
self.rotary_emb = InternLMRotaryEmbedding(
self.head_dim,
max_position_embeddings=self.max_position_embeddings,
base=self.config.rotary["base"],
)
elif self.config.rotary["type"] == "dynamic":
self.rotary_emb = InternLMDynamicNTKScalingRotaryEmbedding(
self.head_dim,
max_position_embeddings=self.max_position_embeddings,
base=self.config.rotary["base"],
scaling_factor=self.config.rotary.get("scaling_factor", 1.0),
base=self.config.rope_theta,
)
else:
raise ValueError("Currently we only support rotary embedding's type being one of ('origin', 'dynamic').")
scaling_type = self.config.rope_scaling["type"]
scaling_factor = self.config.rope_scaling["factor"]
if scaling_type == "dynamic":
self.rotary_emb = InternLMDynamicNTKScalingRotaryEmbedding(
self.head_dim,
max_position_embeddings=self.max_position_embeddings,
base=self.config.rope_theta,
scaling_factor=scaling_factor
)
else:
raise ValueError("Currently we only support rotary embedding's type being 'dynamic'.")
return self.rotary_emb
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
@ -302,21 +323,27 @@ class InternLMAttention(nn.Module):
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = (
self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
)
value_states = (
self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
)
if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
# print(use_cache)
past_key_value = (key_states, value_states) if use_cache else None
kv_seq_len = key_states.shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
@ -907,6 +934,11 @@ class InternLMForCausalLM(InternLMPreTrainedModel):
('你好,有什么可以帮助您的吗', [('你好', '你好,有什么可以帮助您的吗')])
('你好,有什么可以帮助您的吗?', [('你好', '你好,有什么可以帮助您的吗?')])
"""
if BaseStreamer is None:
raise ModuleNotFoundError(
"The version of `transformers` is too low. Please make sure "
"that you have installed `transformers>=4.28.0`."
)
response_queue = queue.Queue(maxsize=20)