diff --git a/README-ja-JP.md b/README-ja-JP.md
index aeb7b02..77736fc 100644
--- a/README-ja-JP.md
+++ b/README-ja-JP.md
@@ -103,6 +103,22 @@ Transformers を使用して InternLM 7B チャットモデルをロードする
 これらの提案を実践することで、時間管理のスキルを向上させ、効果的に日々のタスクをこなしていくことができます。
 ```
 
+ストリーミング生成を行いたい場合は、「stream_chat」関数を使用できます。
+
+```python
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+model_path = "/mnt/petrelfs/share_data/xingshuhao/internlm-chat-7b/"
+model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)
+tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+
+model = model.eval()
+length = 0
+for response, history in model.stream_chat(tokenizer, "你好", history=[]):
+    print(response[length:], flush=True, end="")
+    length = len(response)
+```
+
 ### 対話
 
 以下のコードを実行することで、フロントエンドインターフェースを通して InternLM Chat 7B モデルと対話することができます:
diff --git a/README-zh-Hans.md b/README-zh-Hans.md
index 67946ea..edb64df 100644
--- a/README-zh-Hans.md
+++ b/README-zh-Hans.md
@@ -178,6 +178,22 @@ InternLM-7B 包含了一个拥有70亿参数的基础模型和一个为实际场
 3. 集中注意力：避免分心，集中注意力完成任务。关闭社交媒体和电子邮件通知，专注于任务，这将帮助您更快地完成任务，并减少错误的可能性。
 ```
 
+如果想进行流式生成，则可以使用 `stream_chat` 接口：
+
+```python
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+model_path = "/mnt/petrelfs/share_data/xingshuhao/internlm-chat-7b/"
+model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)
+tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+
+model = model.eval()
+length = 0
+for response, history in model.stream_chat(tokenizer, "你好", history=[]):
+    print(response[length:], flush=True, end="")
+    length = len(response)
+```
+
 ### 通过 ModelScope 加载
 
 通过以下的代码从 ModelScope 加载 InternLM 模型 (可修改模型名称替换不同的模型)
diff --git a/README.md b/README.md
index c3e4286..9983337 100644
--- a/README.md
+++ b/README.md
@@ -175,6 +175,22 @@ Sure, here are three tips for effective time management:
 Remember, good time management skills take practice and patience. Start with small steps and gradually incorporate these habits into your daily routine.
 ```
 
+The responses can be streamed using `stream_chat`:
+
+```python
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+model_path = "/mnt/petrelfs/share_data/xingshuhao/internlm-chat-7b/"
+model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)
+tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+
+model = model.eval()
+length = 0
+for response, history in model.stream_chat(tokenizer, "你好", history=[]):
+    print(response[length:], flush=True, end="")
+    length = len(response)
+```
+
 ### Import from ModelScope
 
 To load the InternLM model using ModelScope, use the following code:
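All three README snippets stream a single query; carrying the yielded `history` back into the next call turns the same loop into a multi-turn console chat. A minimal sketch under that assumption (the public `internlm/internlm-chat-7b` identifier is used here as a stand-in for the local path in the patch):

```python
# Hypothetical multi-turn loop built on the stream_chat generator shown in the READMEs.
from transformers import AutoModelForCausalLM, AutoTokenizer

model_path = "internlm/internlm-chat-7b"  # placeholder checkpoint id, not part of the patch
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True).eval()

history = []
while True:
    query = input("\nUser: ")
    if not query:
        break
    printed = 0
    # Each yield carries the full response so far plus the updated history,
    # so only the newly generated suffix is printed.
    for response, history in model.stream_chat(tokenizer, query, history=history):
        print(response[printed:], end="", flush=True)
        printed = len(response)
```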
diff --git a/tools/tokenizer.py b/tools/tokenizer.py
index fc3800e..8e0a0d0 100644
--- a/tools/tokenizer.py
+++ b/tools/tokenizer.py
@@ -8,7 +8,7 @@ import numpy as np
 current_dir = os.path.dirname(os.path.abspath(__file__))
 model_path = os.path.join(current_dir, "V7_sft.model")
 sys.path.append(os.path.join(current_dir, "transformers"))
-from tokenization_internlm import InternLMTokenizer
+from internlm_model import InternLMTokenizer
 
 tokenizer = InternLMTokenizer(vocab_file=model_path)
 
diff --git a/tools/transformers/convert2hf.py b/tools/transformers/convert2hf.py
index f8604df..7f3aa02 100644
--- a/tools/transformers/convert2hf.py
+++ b/tools/transformers/convert2hf.py
@@ -6,8 +6,7 @@ import re
 import tempfile
 
 import torch
-from modeling_internlm import InternLMConfig, InternLMForCausalLM
-from tokenization_internlm import InternLMTokenizer
+from internlm_model import InternLMConfig, InternLMForCausalLM, InternLMTokenizer
 
 NUM_SHARDS = {
     "7B": 1,
diff --git a/tools/transformers/internlm_model/__init__.py b/tools/transformers/internlm_model/__init__.py
new file mode 100644
index 0000000..c549c66
--- /dev/null
+++ b/tools/transformers/internlm_model/__init__.py
@@ -0,0 +1,3 @@
+from .configuration_internlm import InternLMConfig
+from .modeling_internlm import InternLMForCausalLM
+from .tokenization_internlm import InternLMTokenizer
\ No newline at end of file
diff --git a/tools/transformers/configuration_internlm.py b/tools/transformers/internlm_model/configuration_internlm.py
similarity index 100%
rename from tools/transformers/configuration_internlm.py
rename to tools/transformers/internlm_model/configuration_internlm.py
diff --git a/tools/transformers/modeling_internlm.py b/tools/transformers/internlm_model/modeling_internlm.py
similarity index 96%
rename from tools/transformers/modeling_internlm.py
rename to tools/transformers/internlm_model/modeling_internlm.py
index 1dd31cd..269cdd2 100644
--- a/tools/transformers/modeling_internlm.py
+++ b/tools/transformers/internlm_model/modeling_internlm.py
@@ -28,27 +28,17 @@ from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from transformers.activations import ACT2FN
-from transformers.modeling_outputs import (
-    BaseModelOutputWithPast,
-    CausalLMOutputWithPast,
-    SequenceClassifierOutputWithPast,
-)
+from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
 from transformers.modeling_utils import PreTrainedModel
 from transformers.generation.streamers import BaseStreamer
-from transformers.utils import (
-    add_start_docstrings,
-    add_start_docstrings_to_model_forward,
-    logging,
-    replace_return_docstrings,
-)
-from configuration_internlm import InternLMConfig
+from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
+from .configuration_internlm import InternLMConfig
 
 logger = logging.get_logger(__name__)
 
 _CONFIG_FOR_DOC = "InternLMConfig"
 
-
 # Copied from transformers.models.bart.modeling_bart._make_causal_mask
 def _make_causal_mask(
     input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
@@ -106,7 +96,7 @@ class InternLMRotaryEmbedding(torch.nn.Module):
     def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
         super().__init__()
         inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
-        self.register_buffer("inv_freq", inv_freq)
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
 
         # Build here to make `torch.jit.trace` work.
         self.max_seq_len_cached = max_position_embeddings
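The new `internlm_model` package re-exports the config, model, and tokenizer classes, so the tools now import them with a single `from internlm_model import ...`. The other functional change in this hunk is `register_buffer(..., persistent=False)`: a non-persistent buffer still follows the module across `.to()`/`.cuda()` calls but is excluded from `state_dict`, so the rotary `inv_freq` table is recomputed instead of being (de)serialized with checkpoints. A standalone sketch of that behaviour (the toy module below is illustrative, not part of the patch):

```python
import torch
from torch import nn

class RotaryFreqs(nn.Module):
    """Toy module mirroring how InternLMRotaryEmbedding caches inv_freq."""

    def __init__(self, dim: int = 8, base: float = 10000.0):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        # persistent=False: moved by .to()/.cuda(), but skipped by state_dict()
        self.register_buffer("inv_freq", inv_freq, persistent=False)

print("inv_freq" in RotaryFreqs().state_dict())  # False: nothing extra to serialize or load
```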
@@ -332,11 +322,9 @@ INTERNLM_START_DOCSTRING = r"""
     This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
     library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
     etc.)
-
     This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
     Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
     and behavior.
-
     Parameters:
         config ([`InternLMConfig`]):
             Model configuration class with all the parameters of the model. Initializing with a config file does not
@@ -377,44 +365,33 @@ INTERNLM_INPUTS_DOCSTRING = r"""
         input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
             Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
             it.
-
             Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
             [`PreTrainedTokenizer.__call__`] for details.
-
             [What are input IDs?](../glossary#input-ids)
         attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
             Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
-
             - 1 for tokens that are **not masked**,
             - 0 for tokens that are **masked**.
-
             [What are attention masks?](../glossary#attention-mask)
-
             Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
             [`PreTrainedTokenizer.__call__`] for details.
-
             If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
             `past_key_values`).
-
             If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
             and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
             information on the default strategy.
-
             - 1 indicates the head is **not masked**,
             - 0 indicates the head is **masked**.
         position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
             Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
             config.n_positions - 1]`.
-
             [What are position IDs?](../glossary#position-ids)
         past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
             Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
             `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
             `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
-
             Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
             blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
-
             If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
             don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
             `decoder_input_ids` of shape `(batch_size, sequence_length)`.
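The `past_key_values` contract described in `INTERNLM_INPUTS_DOCSTRING` is the standard decoder cache: once a prefix has been run, each later step only needs the newest token plus the cached key/value tensors. A generic illustration using the Transformers API (`gpt2` is just a small stand-in causal LM, not part of this patch):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Any causal LM demonstrates the cache contract; gpt2 keeps the example small.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").eval()

inputs = tokenizer("InternLM streams tokens", return_tensors="pt")
with torch.no_grad():
    out = model(**inputs, use_cache=True)           # full prefix pass fills the cache
    past = out.past_key_values
    next_id = out.logits[:, -1].argmax(dim=-1, keepdim=True)
    # Later steps feed only the newest token together with the cached keys/values.
    out = model(input_ids=next_id, past_key_values=past, use_cache=True)
```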
@@ -443,11 +420,9 @@ INTERNLM_INPUTS_DOCSTRING = r"""
 class InternLMModel(InternLMPreTrainedModel):
     """
     Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`InternLMDecoderLayer`]
-
     Args:
         config: InternLMConfig
     """
-
     _auto_class = "AutoModel"
 
     def __init__(self, config: InternLMConfig):
@@ -673,20 +648,14 @@ class InternLMForCausalLM(InternLMPreTrainedModel):
             Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
             config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
             (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
-
         Returns:
-
         Example:
-
         ```python
         >>> from transformers import AutoTokenizer, InternLMForCausalLM
-
         >>> model = InternLMForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
         >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
-
         >>> prompt = "Hey, are you consciours? Can you talk to me?"
         >>> inputs = tokenizer(prompt, return_tensors="pt")
-
         >>> # Generate
         >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
         >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
@@ -776,58 +745,50 @@ class InternLMForCausalLM(InternLMPreTrainedModel):
         for layer_past in past_key_values:
             reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
         return reordered_past
-    
+
     def build_inputs(self, tokenizer, query: str, history: List[Tuple[str, str]] = []):
         prompt = ""
         for record in history:
-            prompt += f"""<|User|>:{record[0]}\n<|Bot|>:{record[1]}\n"""
-        if len(prompt) == 0:
-            prompt += "<s>"
+            prompt += f"""<|User|>:{record[0]}\n<|Bot|>:{record[1]}\n"""
         prompt += f"""<|User|>:{query}\n<|Bot|>:"""
         return tokenizer([prompt], return_tensors="pt")
-    
+
     @torch.no_grad()
-    def chat(
-        self,
-        tokenizer,
-        query: str,
-        history: List[Tuple[str, str]] = [],
-        streamer: Optional[BaseStreamer] = None,
-        max_new_tokens: int = 1024,
-        do_sample: bool = True,
-        temperature: float = 0.8,
-        top_p: float = 0.8,
-        **kwargs,
-    ):
+    def chat(self,
+             tokenizer,
+             query: str,
+             history: List[Tuple[str, str]] = [],
+             streamer: Optional[BaseStreamer] = None,
+             max_new_tokens: int = 1024,
+             do_sample: bool = True,
+             temperature: float = 0.8,
+             top_p: float = 0.8,
+             **kwargs):
         inputs = self.build_inputs(tokenizer, query, history)
         inputs = {k: v.to(self.device) for k, v in inputs.items() if torch.is_tensor(v)}
-        outputs = self.generate(
-            **inputs,
-            streamer=streamer,
-            max_new_tokens=max_new_tokens,
-            do_sample=do_sample,
-            temperature=temperature,
-            top_p=top_p,
-            **kwargs,
-        )
-        outputs = outputs[0].cpu().tolist()[len(inputs["input_ids"][0]) :]
+        outputs = self.generate(**inputs,
+                                streamer=streamer,
+                                max_new_tokens=max_new_tokens,
+                                do_sample=do_sample,
+                                temperature=temperature,
+                                top_p=top_p,
+                                **kwargs)
+        outputs = outputs[0].cpu().tolist()[len(inputs["input_ids"][0]):]
         response = tokenizer.decode(outputs, skip_special_tokens=True)
         response = response.split("<eoa>")[0]
         history = history + [(query, response)]
         return response, history
-    
+
     @torch.no_grad()
-    def stream_chat(
-        self,
-        tokenizer,
-        query: str,
-        history: List[Tuple[str, str]] = [],
-        max_new_tokens: int = 1024,
-        do_sample: bool = True,
-        temperature: float = 0.8,
-        top_p: float = 0.8,
-        **kwargs,
-    ):
+    def stream_chat(self,
+                    tokenizer,
+                    query: str,
+                    history: List[Tuple[str, str]] = [],
+                    max_new_tokens: int = 1024,
+                    do_sample: bool = True,
+                    temperature: float = 0.8,
+                    top_p: float = 0.8,
+                    **kwargs):
         """
         Return a generator in format: (response, history)
         Eg.
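The `build_inputs` helper above is the whole prompt template: each past turn is rendered as `<|User|>:`/`<|Bot|>:` lines and the current query ends with an open `<|Bot|>:` tag for the model to complete. A simplified mirror of that string construction, useful for inspecting what the model actually sees (the released checkpoints may additionally insert special tokens such as `<s>`/`<eoa>`, which this sketch omits):

```python
from typing import List, Tuple

def render_prompt(query: str, history: List[Tuple[str, str]] = []) -> str:
    """Mirror of build_inputs' string construction, for inspection only (special tokens omitted)."""
    prompt = ""
    for user_turn, bot_turn in history:
        prompt += f"<|User|>:{user_turn}\n<|Bot|>:{bot_turn}\n"
    prompt += f"<|User|>:{query}\n<|Bot|>:"
    return prompt

print(render_prompt("And in English?", history=[("你好", "你好！有什么可以帮你的吗？")]))
```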
@@ -873,12 +834,12 @@ class InternLMForCausalLM(InternLMPreTrainedModel):
                 tokenizer=tokenizer,
                 query=query,
                 streamer=ChatStreamer(tokenizer=tokenizer),
-                history=history,
+                history=history, 
                 max_new_tokens=max_new_tokens,
                 do_sample=do_sample,
                 temperature=temperature,
                 top_p=top_p,
-                **kwargs,
+                **kwargs
             )
 
         def consumer():
@@ -886,7 +847,7 @@ class InternLMForCausalLM(InternLMPreTrainedModel):
             producer.start()
             while True:
                 res = response_queue.get()
-                if res is not None:
+                if res is None:
                     return
                 yield res
 
@@ -896,10 +857,8 @@ class InternLMForCausalLM(InternLMPreTrainedModel):
 @add_start_docstrings(
     """
     The InternLM Model transformer with a sequence classification head on top (linear layer).
-
     [`InternLMForSequenceClassification`] uses the last token in order to do the classification, as other causal models
     (e.g. GPT-2) do.
-
     Since it does classification on the last token, it requires to know the position of the last token. If a
     `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
     no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
diff --git a/tools/transformers/tokenization_internlm.py b/tools/transformers/internlm_model/tokenization_internlm.py
similarity index 100%
rename from tools/transformers/tokenization_internlm.py
rename to tools/transformers/internlm_model/tokenization_internlm.py
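The `if res is None: return` change is what makes streaming terminate correctly: the producer thread pushes partial results into a queue and signals completion with a `None` sentinel, so the consumer must stop on `None` rather than on every non-`None` value (the old condition returned before yielding anything). A self-contained sketch of the same producer/consumer pattern (names and data are illustrative, independent of the InternLM classes):

```python
from queue import Queue
from threading import Thread

def stream_tokens():
    """Yield items produced on a worker thread until a None sentinel arrives."""
    q: Queue = Queue()

    def producer():
        for piece in ["Inter", "Intern", "InternLM"]:  # stand-in for partial responses
            q.put(piece)
        q.put(None)                                    # sentinel: generation finished

    Thread(target=producer, daemon=True).start()
    while True:
        item = q.get()
        if item is None:   # with `is not None` the generator would exit immediately
            return
        yield item

print(list(stream_tokens()))  # ['Inter', 'Intern', 'InternLM']
```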