update hf model: add rope config and add qkv

pull/536/head
x54-729 2023-12-12 12:19:20 +08:00
parent 81ffb3d824
commit d9262da635
2 changed files with 92 additions and 22 deletions

configuration_internlm.py

@@ -1,5 +1,5 @@
 # coding=utf-8
-# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+# Copyright (c) InternLM. All rights reserved.
 #
 # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
 # and OPT implementations in this library. It has been modified from its
@@ -49,6 +49,14 @@ class InternLMConfig(PretrainedConfig):
             Number of hidden layers in the Transformer encoder.
         num_attention_heads (`int`, *optional*, defaults to 32):
             Number of attention heads for each attention layer in the Transformer encoder.
+        num_key_value_heads (`int`, *optional*):
+            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA); if
+            `num_key_value_heads=1`, the model will use Multi Query Attention (MQA); otherwise GQA is used. When
+            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be
+            constructed by meanpooling all the original heads within that group. For more details, check out
+            [this paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, it will default to
+            `num_attention_heads`.
         hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
             The non-linear activation function (function or string) in the decoder.
         max_position_embeddings (`int`, *optional*, defaults to 2048):
@@ -87,6 +95,7 @@ class InternLMConfig(PretrainedConfig):
         intermediate_size=11008,
         num_hidden_layers=32,
         num_attention_heads=32,
+        num_key_value_heads=None,
         hidden_act="silu",
         max_position_embeddings=2048,
         initializer_range=0.02,
@@ -97,7 +106,8 @@ class InternLMConfig(PretrainedConfig):
         eos_token_id=2,
         tie_word_embeddings=False,
         bias=True,
-        rotary={"base": 10000, "type": "dynamic"},  # pylint: disable=W0102
+        rope_theta=10000,
+        rope_scaling=None,
         **kwargs,
     ):
         self.vocab_size = vocab_size
@@ -106,12 +116,19 @@ class InternLMConfig(PretrainedConfig):
         self.intermediate_size = intermediate_size
         self.num_hidden_layers = num_hidden_layers
         self.num_attention_heads = num_attention_heads
+        self.bias = bias
+        if num_key_value_heads is None:
+            num_key_value_heads = num_attention_heads
+        self.num_key_value_heads = num_key_value_heads
         self.hidden_act = hidden_act
         self.initializer_range = initializer_range
         self.rms_norm_eps = rms_norm_eps
         self.use_cache = use_cache
-        self.bias = bias
-        self.rotary = rotary
+        self.rope_theta = rope_theta
+        self.rope_scaling = rope_scaling
+        self._rope_scaling_validation()
         super().__init__(
             pad_token_id=pad_token_id,
             bos_token_id=bos_token_id,
@@ -119,3 +136,24 @@ class InternLMConfig(PretrainedConfig):
             tie_word_embeddings=tie_word_embeddings,
             **kwargs,
         )
+
+    def _rope_scaling_validation(self):
+        """
+        Validate the `rope_scaling` configuration.
+        """
+        if self.rope_scaling is None:
+            return
+
+        if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
+            raise ValueError(
+                "`rope_scaling` must be a dictionary with two fields, `type` and `factor`, "
+                f"got {self.rope_scaling}"
+            )
+        rope_scaling_type = self.rope_scaling.get("type", None)
+        rope_scaling_factor = self.rope_scaling.get("factor", None)
+        if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
+            raise ValueError(
+                f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
+            )
+        if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor < 1.0:
+            raise ValueError(f"`rope_scaling`'s factor field must be a float >= 1, got {rope_scaling_factor}")

modeling_internlm.py

@@ -1,5 +1,5 @@
 # coding=utf-8
-# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+# Copyright (c) InternLM. All rights reserved.
 #
 # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
 # and OPT implementations in this library. It has been modified from its
@@ -28,7 +28,6 @@ import torch.utils.checkpoint
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from transformers.activations import ACT2FN
-from transformers.generation.streamers import BaseStreamer
 from transformers.modeling_outputs import (
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,
@@ -42,6 +41,11 @@ from transformers.utils import (
     replace_return_docstrings,
 )
 
+try:
+    from transformers.generation.streamers import BaseStreamer
+except:  # noqa # pylint: disable=bare-except
+    BaseStreamer = None
+
 from .configuration_internlm import InternLMConfig
 
 logger = logging.get_logger(__name__)
@@ -82,6 +86,17 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
     return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
 
 
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+    """
+    (batch, num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+    """
+    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+    if n_rep == 1:
+        return hidden_states
+    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
 class InternLMRMSNorm(nn.Module):
     """RMSNorm implemention."""
@@ -113,6 +128,7 @@ class InternLMRotaryEmbedding(torch.nn.Module):
         base (int, optional): The rotation position encodes the rotation Angle base number. Defaults to 10000.
         device (Any, optional): Running device. Defaults to None.
     """
+
     def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
         super().__init__()
         inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
@@ -158,7 +174,7 @@ class InternLMDynamicNTKScalingRotaryEmbedding(torch.nn.Module):
     def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
         super().__init__()
         inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
-        self.register_buffer("inv_freq", inv_freq)
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
         self.dim = dim
         self.base = base
         self.scaling_factor = scaling_factor
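The only functional change in this hunk is persistent=False, which keeps the inv_freq buffer out of state_dict, so checkpoints no longer carry it and it is rebuilt from base and dim at construction time. A minimal illustration of the flag, independent of InternLM:

# Minimal illustration of register_buffer(..., persistent=False):
# non-persistent buffers are usable at runtime but never serialized.
import torch
from torch import nn

class WithBuffers(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("kept", torch.ones(3))                           # saved in state_dict
        self.register_buffer("recomputed", torch.ones(3), persistent=False)   # not saved

m = WithBuffers()
print(sorted(m.state_dict().keys()))  # ['kept'] -- 'recomputed' is absent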
@@ -256,6 +272,8 @@ class InternLMAttention(nn.Module):
         self.hidden_size = config.hidden_size
         self.num_heads = config.num_attention_heads
         self.head_dim = self.hidden_size // self.num_heads
+        self.num_key_value_heads = config.num_key_value_heads
+        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
         self.max_position_embeddings = config.max_position_embeddings
 
         if (self.head_dim * self.num_heads) != self.hidden_size:
@@ -264,27 +282,30 @@ class InternLMAttention(nn.Module):
                 f" and `num_heads`: {self.num_heads})."
             )
         self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.bias)
-        self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.bias)
-        self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.bias)
+        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.bias)
+        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.bias)
         self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.bias)
         self.rotary_emb = self._init_rope()
 
     def _init_rope(self):
-        if self.config.rotary["type"] == "origin":
+        if self.config.rope_scaling is None:
             self.rotary_emb = InternLMRotaryEmbedding(
                 self.head_dim,
                 max_position_embeddings=self.max_position_embeddings,
-                base=self.config.rotary["base"],
-            )
-        elif self.config.rotary["type"] == "dynamic":
-            self.rotary_emb = InternLMDynamicNTKScalingRotaryEmbedding(
-                self.head_dim,
-                max_position_embeddings=self.max_position_embeddings,
-                base=self.config.rotary["base"],
-                scaling_factor=self.config.rotary.get("scaling_factor", 1.0),
+                base=self.config.rope_theta,
             )
         else:
-            raise ValueError("Currently we only support rotary embedding's type being one of ('origin', 'dynamic').")
+            scaling_type = self.config.rope_scaling["type"]
+            scaling_factor = self.config.rope_scaling["factor"]
+            if scaling_type == "dynamic":
+                self.rotary_emb = InternLMDynamicNTKScalingRotaryEmbedding(
+                    self.head_dim,
+                    max_position_embeddings=self.max_position_embeddings,
+                    base=self.config.rope_theta,
+                    scaling_factor=scaling_factor
+                )
+            else:
+                raise ValueError("Currently we only support rotary embedding's type being 'dynamic'.")
         return self.rotary_emb
 
     def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
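For context on the "dynamic" branch selected above: the usual dynamic NTK recipe (as in the analogous Hugging Face Llama rotary class) enlarges the RoPE base once the sequence grows past max_position_embeddings, so low frequencies stretch to cover the longer context instead of wrapping around. The sketch below shows that adjustment; it is an assumption about what InternLMDynamicNTKScalingRotaryEmbedding computes, not code from this commit.

# Assumed dynamic-NTK base adjustment (mirrors the common Hugging Face recipe;
# InternLM's actual class may differ in details).
import torch

def dynamic_ntk_inv_freq(dim, seq_len, max_position_embeddings=2048, base=10000.0, scaling_factor=1.0):
    # Beyond the training context, grow the base so the rotation frequencies slow down.
    if seq_len > max_position_embeddings:
        base = base * ((scaling_factor * seq_len / max_position_embeddings) - (scaling_factor - 1)) ** (dim / (dim - 2))
    return 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))

short = dynamic_ntk_inv_freq(dim=128, seq_len=1024)                        # within context: base unchanged
long = dynamic_ntk_inv_freq(dim=128, seq_len=8192, scaling_factor=2.0)     # past context: base enlarged
assert torch.all(long <= short)  # frequencies shrink for long inputs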
@@ -302,21 +323,27 @@ class InternLMAttention(nn.Module):
         bsz, q_len, _ = hidden_states.size()
 
         query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = (
+            self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        )
+        value_states = (
+            self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        )
 
         if past_key_value is not None:
             # reuse k, v, self_attention
             key_states = torch.cat([past_key_value[0], key_states], dim=2)
             value_states = torch.cat([past_key_value[1], value_states], dim=2)
 
-        # print(use_cache)
         past_key_value = (key_states, value_states) if use_cache else None
 
         kv_seq_len = key_states.shape[-2]
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
 
+        key_states = repeat_kv(key_states, self.num_key_value_groups)
+        value_states = repeat_kv(value_states, self.num_key_value_groups)
+
         attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
 
         if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
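Putting the attention changes together, a simplified end-to-end shape walk-through of the grouped-query path (self-contained, with invented dimensions; the real forward pass also applies RoPE, the attention mask, the KV cache, and o_proj):

# Shape sketch of the GQA path added to InternLMAttention above (illustrative only).
import math
import torch
from torch import nn

bsz, q_len, hidden_size, num_heads, num_kv_heads = 2, 7, 512, 8, 2
head_dim = hidden_size // num_heads
n_rep = num_heads // num_kv_heads          # num_key_value_groups

q_proj = nn.Linear(hidden_size, num_heads * head_dim, bias=True)
k_proj = nn.Linear(hidden_size, num_kv_heads * head_dim, bias=True)   # smaller kv projections
v_proj = nn.Linear(hidden_size, num_kv_heads * head_dim, bias=True)

x = torch.randn(bsz, q_len, hidden_size)
q = q_proj(x).view(bsz, q_len, num_heads, head_dim).transpose(1, 2)      # (2, 8, 7, 64)
k = k_proj(x).view(bsz, q_len, num_kv_heads, head_dim).transpose(1, 2)   # (2, 2, 7, 64)
v = v_proj(x).view(bsz, q_len, num_kv_heads, head_dim).transpose(1, 2)   # (2, 2, 7, 64)

# repeat_kv broadcasts the 2 kv heads to 8 so the matmul sees matching head counts.
k = k[:, :, None, :, :].expand(bsz, num_kv_heads, n_rep, q_len, head_dim).reshape(bsz, num_heads, q_len, head_dim)
v = v[:, :, None, :, :].expand(bsz, num_kv_heads, n_rep, q_len, head_dim).reshape(bsz, num_heads, q_len, head_dim)

attn = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(head_dim)          # (2, 8, 7, 7)
out = torch.matmul(attn.softmax(dim=-1), v)                              # (2, 8, 7, 64)
assert out.shape == (bsz, num_heads, q_len, head_dim)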
@@ -907,6 +934,11 @@ class InternLMForCausalLM(InternLMPreTrainedModel):
         ('你好,有什么可以帮助您的吗', [('你好', '你好,有什么可以帮助您的吗')])
         ('你好,有什么可以帮助您的吗?', [('你好', '你好,有什么可以帮助您的吗?')])
         """
+        if BaseStreamer is None:
+            raise ModuleNotFoundError(
+                "The version of `transformers` is too low. Please make sure "
+                "that you have installed `transformers>=4.28.0`."
+            )
         response_queue = queue.Queue(maxsize=20)