fix(tools): fix streaming_chat and update docs (#467)

* move hf model to tools/transformers/internlm_model

* fix stream_chat

* Add stream_chat example

* fix import

* Add __init__ to internlm_model

* Add hf link

* fix import of tools/tokenizer.py

* fix huggingface url in readme
x54-729 authored on 2023-11-03 16:12:37 +08:00, committed by GitHub
parent debb7e77b9
commit b9c813a972
9 changed files with 91 additions and 82 deletions
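The "fix stream_chat" bullet refers to the consumer generator near the end of modeling_internlm.py (last hunks below): the old check `if res is not None: return` ended the stream as soon as a real chunk arrived, while the fixed check `if res is None: return` stops only on the sentinel pushed after generation finishes. A minimal, standalone sketch of that queue-and-sentinel pattern (names here are illustrative, not the repo's exact helpers):

```python
import queue
import threading

def producer(out_q: queue.Queue):
    # Push generated chunks, then a None sentinel to signal completion.
    for chunk in ["Hello", " there", "!"]:
        out_q.put(chunk)
    out_q.put(None)

def consumer():
    out_q: queue.Queue = queue.Queue()
    threading.Thread(target=producer, args=(out_q,)).start()
    while True:
        res = out_q.get()
        if res is None:  # fixed check: stop only on the sentinel
            return       # the old `if res is not None: return` bailed out on the first real chunk
        yield res

for piece in consumer():
    print(piece, end="", flush=True)
```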


@@ -103,6 +103,22 @@ To load the InternLM 7B Chat model using Transformers
By putting these suggestions into practice, you can improve your time-management skills and get through your daily tasks effectively.
```
If you want streaming generation, you can use the `stream_chat` function.
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
model_path = "internlm/internlm-chat-7b"
model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = model.eval()
length = 0
for response, history in model.stream_chat(tokenizer, "你好", history=[]):
    print(response[length:], flush=True, end="")
    length = len(response)
```
### Dialogue
You can interact with the InternLM Chat 7B model through a frontend interface by running the following code:


@@ -178,6 +178,22 @@ InternLM-7B contains a 7-billion-parameter base model and a chat model tailored for practical scenarios
3. Stay focused: avoid distractions and concentrate on completing the task. Turn off social media and email notifications and focus on the work at hand; this will help you finish tasks faster and reduce the chance of errors.
```
If you want streaming generation, you can use the `stream_chat` interface:
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
model_path = "internlm/internlm-chat-7b"
model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = model.eval()
length = 0
for response, history in model.stream_chat(tokenizer, "你好", history=[]):
    print(response[length:], flush=True, end="")
    length = len(response)
```
### Import from ModelScope
Load the InternLM model from ModelScope with the following code (change the model name to load a different model)


@@ -175,6 +175,22 @@ Sure, here are three tips for effective time management:
Remember, good time management skills take practice and patience. Start with small steps and gradually incorporate these habits into your daily routine.
```
The responses can be streamed using `stream_chat`:
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
model_path = "internlm/internlm-chat-7b"
model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = model.eval()
length = 0
for response, history in model.stream_chat(tokenizer, "你好", history=[]):
    print(response[length:], flush=True, end="")
    length = len(response)
```
### Import from ModelScope
To load the InternLM model using ModelScope, use the following code:


@@ -8,7 +8,7 @@ import numpy as np
current_dir = os.path.dirname(os.path.abspath(__file__))
model_path = os.path.join(current_dir, "V7_sft.model")
sys.path.append(os.path.join(current_dir, "transformers"))
from tokenization_internlm import InternLMTokenizer
from internlm_model import InternLMTokenizer
tokenizer = InternLMTokenizer(vocab_file=model_path)
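For context, tools/tokenizer.py instantiates InternLMTokenizer directly from the bundled V7_sft.model SentencePiece file, and the class follows the usual `PreTrainedTokenizer` interface, so encoding and decoding work as expected. A quick sketch (assuming the `sys.path` setup above and that V7_sft.model sits next to the script):

```python
ids = tokenizer.encode("Hello, InternLM!")  # list of token ids
text = tokenizer.decode(ids)                # decodes back to a string
print(len(ids), text)
```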


@@ -6,8 +6,7 @@ import re
import tempfile
import torch
from modeling_internlm import InternLMConfig, InternLMForCausalLM
from tokenization_internlm import InternLMTokenizer
from internlm_model import InternLMConfig, InternLMForCausalLM, InternLMTokenizer
NUM_SHARDS = {
"7B": 1,


@@ -0,0 +1,3 @@
from .configuration_internlm import InternLMConfig
from .modeling_internlm import InternLMForCausalLM
from .tokenization_internlm import InternLMTokenizer
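These three re-exports make `internlm_model` usable as a single package, which is what the updated imports in tools/tokenizer.py and the conversion script above rely on. A sketch of loading a converted checkpoint with the packaged classes (the checkpoint path is illustrative, and `tools/transformers` is assumed to be on `sys.path`):

```python
from internlm_model import InternLMForCausalLM, InternLMTokenizer

ckpt_dir = "./internlm-chat-7b-hf"  # illustrative: output directory of an HF-format conversion
tokenizer = InternLMTokenizer.from_pretrained(ckpt_dir)
model = InternLMForCausalLM.from_pretrained(ckpt_dir).eval()
```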


@@ -28,27 +28,17 @@ from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers.activations import ACT2FN
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
SequenceClassifierOutputWithPast,
)
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
from transformers.modeling_utils import PreTrainedModel
from transformers.generation.streamers import BaseStreamer
from transformers.utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
replace_return_docstrings,
)
from configuration_internlm import InternLMConfig
from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
from .configuration_internlm import InternLMConfig
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "InternLMConfig"
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
def _make_causal_mask(
input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
@@ -106,7 +96,7 @@ class InternLMRotaryEmbedding(torch.nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__()
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
self.register_buffer("inv_freq", inv_freq)
self.register_buffer("inv_freq", inv_freq, persistent=False)
# Build here to make `torch.jit.trace` work.
self.max_seq_len_cached = max_position_embeddings
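The functional change in this hunk is `persistent=False` on the rotary `inv_freq` buffer: a non-persistent buffer still lives on the module and moves with `.to()`/`.cuda()`, but it is excluded from `state_dict`, so checkpoints no longer carry it and it is not reported as a missing or unexpected key on load. A tiny standalone illustration of the PyTorch behavior:

```python
import torch
from torch import nn

class Demo(nn.Module):
    def __init__(self):
        super().__init__()
        inv_freq = 1.0 / (10000 ** (torch.arange(0, 8, 2).float() / 8))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

print(list(Demo().state_dict().keys()))  # [] -- non-persistent buffers are not saved
```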
@@ -332,11 +322,9 @@ INTERNLM_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
and behavior.
Parameters:
config ([`InternLMConfig`]):
Model configuration class with all the parameters of the model. Initializing with a config file does not
@@ -377,44 +365,33 @@ INTERNLM_INPUTS_DOCSTRING = r"""
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
it.
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
`past_key_values`).
If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
information on the default strategy.
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
config.n_positions - 1]`.
[What are position IDs?](../glossary#position-ids)
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
@@ -443,11 +420,9 @@ INTERNLM_INPUTS_DOCSTRING = r"""
class InternLMModel(InternLMPreTrainedModel):
"""
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`InternLMDecoderLayer`]
Args:
config: InternLMConfig
"""
_auto_class = "AutoModel"
def __init__(self, config: InternLMConfig):
@@ -673,20 +648,14 @@ class InternLMForCausalLM(InternLMPreTrainedModel):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
Returns:
Example:
```python
>>> from transformers import AutoTokenizer, InternLMForCausalLM
>>> model = InternLMForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
>>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
>>> prompt = "Hey, are you consciours? Can you talk to me?"
>>> inputs = tokenizer(prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
@@ -776,58 +745,50 @@ class InternLMForCausalLM(InternLMPreTrainedModel):
for layer_past in past_key_values:
reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
return reordered_past
def build_inputs(self, tokenizer, query: str, history: List[Tuple[str, str]] = []):
prompt = ""
for record in history:
prompt += f"""<s><|User|>:{record[0]}<eoh>\n<|Bot|>:{record[1]}<eoa>\n"""
if len(prompt) == 0:
prompt += "<s>"
prompt += f"""<|User|>:{record[0]}<eoh>\n<|Bot|>:{record[1]}<eoa>\n"""
prompt += f"""<|User|>:{query}<eoh>\n<|Bot|>:"""
return tokenizer([prompt], return_tensors="pt")
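For reference, `build_inputs` flattens the conversation into a single prompt: each history turn becomes `<|User|>:…<eoh>\n<|Bot|>:…<eoa>\n` and the new query is appended as `<|User|>:…<eoh>\n<|Bot|>:`, which the model then continues. A rough sketch of a one-turn prompt (whether `<s>` prefixes every turn or only an empty history differs between the two variants shown above):

```python
history = [("Hi", "Hello! How can I help you?")]
query = "Tell me a joke"

prompt = ""
for user, bot in history:
    prompt += f"<|User|>:{user}<eoh>\n<|Bot|>:{bot}<eoa>\n"
prompt += f"<|User|>:{query}<eoh>\n<|Bot|>:"
print(prompt)
```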
@torch.no_grad()
def chat(
self,
tokenizer,
query: str,
history: List[Tuple[str, str]] = [],
streamer: Optional[BaseStreamer] = None,
max_new_tokens: int = 1024,
do_sample: bool = True,
temperature: float = 0.8,
top_p: float = 0.8,
**kwargs,
):
def chat(self,
tokenizer,
query: str,
history: List[Tuple[str, str]] = [],
streamer: Optional[BaseStreamer] = None,
max_new_tokens: int = 1024,
do_sample: bool = True,
temperature: float = 0.8,
top_p: float = 0.8,
**kwargs):
inputs = self.build_inputs(tokenizer, query, history)
inputs = {k: v.to(self.device) for k, v in inputs.items() if torch.is_tensor(v)}
outputs = self.generate(
**inputs,
streamer=streamer,
max_new_tokens=max_new_tokens,
do_sample=do_sample,
temperature=temperature,
top_p=top_p,
**kwargs,
)
outputs = outputs[0].cpu().tolist()[len(inputs["input_ids"][0]) :]
outputs = self.generate(**inputs,
streamer=streamer,
max_new_tokens=max_new_tokens,
do_sample=do_sample,
temperature=temperature,
top_p=top_p,
**kwargs)
outputs = outputs[0].cpu().tolist()[len(inputs["input_ids"][0]):]
response = tokenizer.decode(outputs, skip_special_tokens=True)
response = response.split("<eoa>")[0]
history = history + [(query, response)]
return response, history
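The non-streaming `chat` method above returns the full reply together with the updated history, so multi-turn use simply feeds the returned history back in. A usage sketch (model and tokenizer loaded as in the README examples earlier):

```python
response, history = model.chat(tokenizer, "hello", history=[])
print(response)
response, history = model.chat(tokenizer, "please give me three tips for time management", history=history)
print(response)
```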
@torch.no_grad()
def stream_chat(
self,
tokenizer,
query: str,
history: List[Tuple[str, str]] = [],
max_new_tokens: int = 1024,
do_sample: bool = True,
temperature: float = 0.8,
top_p: float = 0.8,
**kwargs,
):
def stream_chat(self,
tokenizer,
query: str,
history: List[Tuple[str, str]] = [],
max_new_tokens: int = 1024,
do_sample: bool = True,
temperature: float = 0.8,
top_p: float = 0.8,
**kwargs):
"""
Return a generator in format: (response, history)
Eg.
@@ -873,12 +834,12 @@ class InternLMForCausalLM(InternLMPreTrainedModel):
tokenizer=tokenizer,
query=query,
streamer=ChatStreamer(tokenizer=tokenizer),
history=history,
history=history,
max_new_tokens=max_new_tokens,
do_sample=do_sample,
temperature=temperature,
top_p=top_p,
**kwargs,
**kwargs
)
def consumer():
@@ -886,7 +847,7 @@ class InternLMForCausalLM(InternLMPreTrainedModel):
producer.start()
while True:
res = response_queue.get()
if res is not None:
if res is None:
return
yield res
@@ -896,10 +857,8 @@ class InternLMForCausalLM(InternLMPreTrainedModel):
@add_start_docstrings(
"""
The InternLM Model transformer with a sequence classification head on top (linear layer).
[`InternLMForSequenceClassification`] uses the last token in order to do the classification, as other causal models
(e.g. GPT-2) do.
Since it does classification on the last token, it requires to know the position of the last token. If a
`pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the