From d9262da635039bdf218114ab371d213984b630f3 Mon Sep 17 00:00:00 2001
From: x54-729
Date: Tue, 12 Dec 2023 12:19:20 +0800
Subject: [PATCH] update hf model: add rope config and add qkv

---
 .../internlm_model/configuration_internlm.py  | 46 +++++++++++--
 .../internlm_model/modeling_internlm.py       | 68 ++++++++++++++-----
 2 files changed, 92 insertions(+), 22 deletions(-)

diff --git a/tools/transformers/internlm_model/configuration_internlm.py b/tools/transformers/internlm_model/configuration_internlm.py
index a76c1b8..0965741 100644
--- a/tools/transformers/internlm_model/configuration_internlm.py
+++ b/tools/transformers/internlm_model/configuration_internlm.py
@@ -1,5 +1,5 @@
 # coding=utf-8
-# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+# Copyright (c) InternLM. All rights reserved.
 #
 # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
 # and OPT implementations in this library. It has been modified from its
@@ -49,6 +49,14 @@ class InternLMConfig(PretrainedConfig):
             Number of hidden layers in the Transformer encoder.
         num_attention_heads (`int`, *optional*, defaults to 32):
             Number of attention heads for each attention layer in the Transformer encoder.
+        num_key_value_heads (`int`, *optional*):
+            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+            `num_key_value_heads=1` the model will use Multi Query Attention (MQA), otherwise GQA is used. When
+            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+            by mean-pooling all the original heads within that group. For more details, check out [this
+            paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, it will default to
+            `num_attention_heads`.
         hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
             The non-linear activation function (function or string) in the decoder.
         max_position_embeddings (`int`, *optional*, defaults to 2048):
@@ -87,6 +95,7 @@ class InternLMConfig(PretrainedConfig):
         intermediate_size=11008,
         num_hidden_layers=32,
         num_attention_heads=32,
+        num_key_value_heads=None,
         hidden_act="silu",
         max_position_embeddings=2048,
         initializer_range=0.02,
@@ -97,7 +106,8 @@ class InternLMConfig(PretrainedConfig):
         eos_token_id=2,
         tie_word_embeddings=False,
         bias=True,
-        rotary={"base": 10000, "type": "dynamic"},  # pylint: disable=W0102
+        rope_theta=10000,
+        rope_scaling=None,
         **kwargs,
     ):
         self.vocab_size = vocab_size
@@ -106,12 +116,19 @@ class InternLMConfig(PretrainedConfig):
         self.intermediate_size = intermediate_size
         self.num_hidden_layers = num_hidden_layers
         self.num_attention_heads = num_attention_heads
+        self.bias = bias
+
+        if num_key_value_heads is None:
+            num_key_value_heads = num_attention_heads
+        self.num_key_value_heads = num_key_value_heads
+
         self.hidden_act = hidden_act
         self.initializer_range = initializer_range
         self.rms_norm_eps = rms_norm_eps
         self.use_cache = use_cache
-        self.bias = bias
-        self.rotary = rotary
+        self.rope_theta = rope_theta
+        self.rope_scaling = rope_scaling
+        self._rope_scaling_validation()
         super().__init__(
             pad_token_id=pad_token_id,
             bos_token_id=bos_token_id,
@@ -119,3 +136,24 @@ class InternLMConfig(PretrainedConfig):
             tie_word_embeddings=tie_word_embeddings,
             **kwargs,
         )
+
+    def _rope_scaling_validation(self):
+        """
+        Validate the `rope_scaling` configuration.
+        """
+        if self.rope_scaling is None:
+            return
+
+        if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
+            raise ValueError(
+                "`rope_scaling` must be a dictionary with two fields, `type` and `factor`, "
+                f"got {self.rope_scaling}"
+            )
+        rope_scaling_type = self.rope_scaling.get("type", None)
+        rope_scaling_factor = self.rope_scaling.get("factor", None)
+        if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
+            raise ValueError(
+                f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
+            )
+        if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor < 1.0:
+            raise ValueError(f"`rope_scaling`'s factor field must be a float >= 1, got {rope_scaling_factor}")
diff --git a/tools/transformers/internlm_model/modeling_internlm.py b/tools/transformers/internlm_model/modeling_internlm.py
index e2d52ed..7f2bb1f 100644
--- a/tools/transformers/internlm_model/modeling_internlm.py
+++ b/tools/transformers/internlm_model/modeling_internlm.py
@@ -1,5 +1,5 @@
 # coding=utf-8
-# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+# Copyright (c) InternLM. All rights reserved.
 #
 # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
 # and OPT implementations in this library. It has been modified from its
@@ -28,7 +28,6 @@ import torch.utils.checkpoint
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from transformers.activations import ACT2FN
-from transformers.generation.streamers import BaseStreamer
 from transformers.modeling_outputs import (
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,
@@ -42,6 +41,11 @@ from transformers.utils import (
     replace_return_docstrings,
 )

+try:
+    from transformers.generation.streamers import BaseStreamer
+except:  # noqa # pylint: disable=bare-except
+    BaseStreamer = None
+
 from .configuration_internlm import InternLMConfig

 logger = logging.get_logger(__name__)
@@ -82,6 +86,17 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
     return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)


+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+    """
+    Expand KV heads: (batch, num_key_value_heads, seqlen, head_dim) -> (batch, num_attention_heads, seqlen, head_dim)
+    """
+    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+    if n_rep == 1:
+        return hidden_states
+    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
 class InternLMRMSNorm(nn.Module):
     """RMSNorm implemention."""

@@ -113,6 +128,7 @@ class InternLMRotaryEmbedding(torch.nn.Module):
         base (int, optional): The rotation position encodes the rotation Angle base number. Defaults to 10000.
         device (Any, optional): Running device. Defaults to None.
     """
+
     def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
         super().__init__()
         inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
@@ -158,7 +174,7 @@ class InternLMDynamicNTKScalingRotaryEmbedding(torch.nn.Module):
     def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
         super().__init__()
         inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
-        self.register_buffer("inv_freq", inv_freq)
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
         self.dim = dim
         self.base = base
         self.scaling_factor = scaling_factor
@@ -256,6 +272,8 @@ class InternLMAttention(nn.Module):
         self.hidden_size = config.hidden_size
         self.num_heads = config.num_attention_heads
         self.head_dim = self.hidden_size // self.num_heads
+        self.num_key_value_heads = config.num_key_value_heads
+        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
         self.max_position_embeddings = config.max_position_embeddings

         if (self.head_dim * self.num_heads) != self.hidden_size:
@@ -264,27 +282,30 @@ class InternLMAttention(nn.Module):
                 f" and `num_heads`: {self.num_heads})."
             )
         self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.bias)
-        self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.bias)
-        self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.bias)
+        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.bias)
+        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.bias)
         self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.bias)
         self.rotary_emb = self._init_rope()

     def _init_rope(self):
-        if self.config.rotary["type"] == "origin":
+        if self.config.rope_scaling is None:
             self.rotary_emb = InternLMRotaryEmbedding(
                 self.head_dim,
                 max_position_embeddings=self.max_position_embeddings,
-                base=self.config.rotary["base"],
-            )
-        elif self.config.rotary["type"] == "dynamic":
-            self.rotary_emb = InternLMDynamicNTKScalingRotaryEmbedding(
-                self.head_dim,
-                max_position_embeddings=self.max_position_embeddings,
-                base=self.config.rotary["base"],
-                scaling_factor=self.config.rotary.get("scaling_factor", 1.0),
+                base=self.config.rope_theta,
             )
         else:
-            raise ValueError("Currently we only support rotary embedding's type being one of ('origin', 'dynamic').")
+            scaling_type = self.config.rope_scaling["type"]
+            scaling_factor = self.config.rope_scaling["factor"]
+            if scaling_type == "dynamic":
+                self.rotary_emb = InternLMDynamicNTKScalingRotaryEmbedding(
+                    self.head_dim,
+                    max_position_embeddings=self.max_position_embeddings,
+                    base=self.config.rope_theta,
+                    scaling_factor=scaling_factor,
+                )
+            else:
+                raise ValueError("Currently we only support rope scaling of type 'dynamic'.")
         return self.rotary_emb

     def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
@@ -302,21 +323,27 @@ class InternLMAttention(nn.Module):
         bsz, q_len, _ = hidden_states.size()

         query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = (
+            self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        )
+        value_states = (
+            self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        )

         if past_key_value is not None:
             # reuse k, v, self_attention
             key_states = torch.cat([past_key_value[0], key_states], dim=2)
             value_states = torch.cat([past_key_value[1], value_states], dim=2)

-        # print(use_cache)
         past_key_value = (key_states, value_states) if use_cache else None

         kv_seq_len = key_states.shape[-2]
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

+        key_states = repeat_kv(key_states, self.num_key_value_groups)
+        value_states = repeat_kv(value_states, self.num_key_value_groups)
+
         attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

         if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
@@ -907,6 +934,11 @@ class InternLMForCausalLM(InternLMPreTrainedModel):
             ('你好，有什么可以帮助您的吗', [('你好', '你好，有什么可以帮助您的吗')])
             ('你好，有什么可以帮助您的吗？', [('你好', '你好，有什么可以帮助您的吗？')])
         """
+        if BaseStreamer is None:
+            raise ModuleNotFoundError(
+                "The version of `transformers` is too low. Please make sure "
+                "that you have installed `transformers>=4.28.0`."
+            )
         response_queue = queue.Queue(maxsize=20)
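
A minimal sketch (not part of the patch itself) of how the new configuration surface is intended to be used: `num_key_value_heads` for GQA/MQA, plus `rope_theta` and `rope_scaling`. It assumes the patched configuration_internlm.py is importable as a top-level module, which is an assumption about the reader's setup rather than something the patch provides.

# Illustrative only: exercising num_key_value_heads, rope_theta and rope_scaling.
# Assumes the patched configuration_internlm.py is on the import path.
from configuration_internlm import InternLMConfig

# MHA baseline: num_key_value_heads falls back to num_attention_heads.
cfg_mha = InternLMConfig(num_attention_heads=32)
assert cfg_mha.num_key_value_heads == 32

# GQA with 8 KV heads plus dynamic NTK rope scaling; `factor` must be a float >= 1.
cfg_gqa = InternLMConfig(
    num_attention_heads=32,
    num_key_value_heads=8,
    rope_theta=10000,
    rope_scaling={"type": "dynamic", "factor": 2.0},
)

# _rope_scaling_validation() rejects unknown scaling types at construction time.
try:
    InternLMConfig(rope_scaling={"type": "ntk", "factor": 2.0})
except ValueError as exc:
    print(exc)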
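
Also illustrative rather than part of the patch: a self-contained sketch of what the new repeat_kv helper does when InternLMAttention expands the grouped key/value heads back to the full set of query heads before computing attention scores.

import torch

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # Same logic as the helper added to modeling_internlm.py in this patch.
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)

# 32 query heads sharing 8 KV heads -> each KV head serves a group of 4 query heads.
kv = torch.randn(1, 8, 16, 128)            # (batch, num_key_value_heads, seqlen, head_dim)
expanded = repeat_kv(kv, n_rep=32 // 8)    # (batch, num_attention_heads, seqlen, head_dim)
assert expanded.shape == (1, 32, 16, 128)
assert torch.equal(expanded[0, 0], expanded[0, 3])  # heads 0-3 are copies of KV head 0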