modifications by pre-commit hook

pull/537/head
ly015 2023-12-12 16:17:56 +08:00
parent 4187bfbfe8
commit d4a81fad5d
50 changed files with 175 additions and 160 deletions

View File

@@ -103,4 +103,3 @@ msgstr ""
#~ msgid "traning dataloader object"
#~ msgstr ""

View File

@@ -47,4 +47,3 @@ msgstr "Training Results"
#: ../../source/example/30B_demo.rst:175 615a3481b0aa49729b7219b1365519aa
msgid "基于以上训练配置和启动命令,两节点 16GPU 下的模型训练部分日志展示如下:"
msgstr "Taking the configuration of the demo training on two nodes with 16 GPUs on slurm as an example, the training result log is shown below:"

View File

@@ -47,4 +47,3 @@ msgstr "Training Results"
#: ../../source/example/7B_demo.rst:173 33ec81f34e3c4340beacdb5254069d08
msgid "基于以上训练配置和启动命令,单节点 8GPU 下的模型训练部分日志展示如下:"
msgstr "Taking the configuration of the demo training on a single machine with 8 GPUs on slurm as an example, the training result log is shown below:"

View File

@@ -30,4 +30,3 @@ msgstr ""
#: ../../source/example/index.rst:13 b095e27dfc924a7a943b7cba5361700a
msgid "30B Demo"
msgstr ""

View File

@@ -78,4 +78,3 @@ msgstr ""
#: ../../source/index.rst:95 a164b772960f4ab8b18c7e8820f69f55
msgid ":ref:`search`"
msgstr ""

View File

@@ -245,4 +245,3 @@ msgid ""
"A tuple of ``(trainer, train_dataloader, test_dataloader, lr_scheduler)``"
" where only ``trainer`` could not be None."
msgstr ""

View File

@@ -137,4 +137,3 @@ msgstr "For the local standard image built with dockerfile or pulled, use the fo
#: ../../../install.md:87 66613606256e4094a6be5ab2af1269ae
msgid "容器内默认目录即 `/InternLM`,根据[使用文档](./usage.md)即可启动训练。"
msgstr "The default directory in the container is `/InternLM`, please start training according to the [Usage](./usage.md)."

View File

@@ -195,4 +195,3 @@ msgstr ""
#: internlm.monitor.alert.send_feishu_msg_with_webhook:12 of
msgid "An exception rasied by the HTTP post request."
msgstr ""

View File

@@ -454,4 +454,3 @@ msgstr ""
#: internlm.solver.optimizer.hybrid_zero_optim.HybridZeroOptimizer.step:7 of
msgid "Whether the gradient is success updated, and the gradient."
msgstr ""

View File

@@ -172,4 +172,3 @@ msgstr ""
#: internlm.utils.simple_memory_profiler.SimpleMemoryProfiler.step:1 of
msgid "Update the memory state of the optimizer state."
msgstr ""

View File

@@ -22,4 +22,3 @@ msgstr ""
#: ../../source/qa.rst:2 e3b22a39640a40cfb527068a7f4bbfc9
msgid "问&答"
msgstr "Q&A"

View File

@@ -159,4 +159,3 @@ msgstr ""
#~ msgid "InternLM训练流程图"
#~ msgstr "InternLM training process"

View File

@@ -364,4 +364,3 @@ msgstr ""
#~ msgstr ""
#~ "`load_model_only_folder` and `load_ckpt_folder` "
#~ "cannot be set at the same time."

View File

@@ -90,4 +90,3 @@ When `Activation Ckpt` is turned off, the test results are as shown in the table
<div align="left">
<img src="../imgs/flops.png" width="580"/>
</div>

View File

@@ -87,4 +87,3 @@ InternLM中`zero1`的配置决定了优化器状态的分配范围。
<div align="left">
<img src="../doc/imgs/flops.png" width="580"/>
</div>

View File

@@ -1,3 +1,5 @@
+# flake8: noqa
# This file is modified from:
# hhttps://github.com/reasoning-machines/pal/blob/main/pal/core/interface.py
#
@@ -27,8 +29,8 @@ import tqdm
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
-from tools.transformers.interface import GenerationConfig, generate_interactive
from internlm.utils.timeout import Timeout
+from tools.transformers.interface import GenerationConfig, generate_interactive
def parse_args():

View File

@@ -19,9 +19,8 @@
# limitations under the License.
""" InternLM model configuration"""
-from transformers.utils import logging
from transformers.configuration_utils import PretrainedConfig
+from transformers.utils import logging
logger = logging.get_logger(__name__)
@@ -30,9 +29,9 @@ INTERNLM_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
class InternLMConfig(PretrainedConfig):
r"""
-This is the configuration class to store the configuration of a [`InternLMModel`]. It is used to instantiate an InternLM
-model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
-defaults will yield a similar configuration to that of the InternLM-7B.
+This is the configuration class to store the configuration of a [`InternLMModel`]. It is used to instantiate an
+InternLM model according to the specified arguments, defining the model architecture. Instantiating a
+configuration with the defaults will yield a similar configuration to that of the InternLM-7B.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.

View File

@@ -1,6 +1,6 @@
import argparse
-import math
import json
+import math
import os
import re
import tempfile
@@ -110,7 +110,7 @@ def merge_pp(states_tp_pp):
states = states_tp_pp[tp][pp]
keys = list(states.keys())
for key in keys:
-match = re.search("\.\d+\.", key)
+match = re.search("\.\d+\.", key)  # noqa: W605
if match is not None:
s, e = match.span()
layer_idx = int(key[s + 1 : e - 1]) + layer_shift
@@ -126,9 +126,9 @@ def merge_pp(states_tp_pp):
if __name__ == "__main__":
parser = argparse.ArgumentParser()
-parser.add_argument('--src_folder', type=str, default='~/test/')  # 需要转换为hf格式的checkpoint文件夹
-parser.add_argument('--tgt_folder', type=str, default='~/output/')  # 存放转换后checkpoint的目标文件夹
-parser.add_argument('--tokenizer', type=str, default='~/test/tokenizer.model')  # Tokenizer 文件的路径
+parser.add_argument("--src_folder", type=str, default="~/test/")  # 需要转换为hf格式的checkpoint文件夹
+parser.add_argument("--tgt_folder", type=str, default="~/output/")  # 存放转换后checkpoint的目标文件夹
+parser.add_argument("--tokenizer", type=str, default="~/test/tokenizer.model")  # Tokenizer 文件的路径
args = parser.parse_args()
def load(fp):

View File

@@ -5,7 +5,6 @@ from typing import Callable, List, Optional
import torch
from torch import nn
-from transformers import AutoModel, AutoTokenizer
from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList
from transformers.utils import logging
@@ -38,12 +37,12 @@ def generate_interactive(
for k, v in inputs.items():
inputs[k] = v.cuda()
input_ids = inputs["input_ids"]
-batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]
+batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]  # noqa: F841
if generation_config is None:
generation_config = model.generation_config
generation_config = copy.deepcopy(generation_config)
model_kwargs = generation_config.update(**kwargs)
-bos_token_id, eos_token_id = generation_config.bos_token_id, generation_config.eos_token_id
+bos_token_id, eos_token_id = generation_config.bos_token_id, generation_config.eos_token_id  # noqa: F841
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
if additional_eos_token_id is not None:
@@ -119,9 +118,7 @@ def generate_interactive(
# update generated ids, model inputs, and length for next step
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
-model_kwargs = model._update_model_kwargs_for_generation(
-outputs, model_kwargs, is_encoder_decoder=False
-)
+model_kwargs = model._update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder=False)
unfinished_sequences = unfinished_sequences.mul((min(next_tokens != i for i in eos_token_id)).long())
output_token_ids = input_ids[0].cpu().tolist()

View File

@@ -1,11 +1,13 @@
import torch
+from moss_002_sft import collate_fn, get_dataset
+from peft import LoraConfig, TaskType, get_peft_model
from torch.utils.data import DataLoader
-from peft import get_peft_model, LoraConfig, TaskType
-from transformers import get_linear_schedule_with_warmup
-from transformers import AutoModelForCausalLM, AutoTokenizer
from tqdm import tqdm
-from moss_002_sft import get_dataset, collate_fn
+from transformers import (
+AutoModelForCausalLM,
+AutoTokenizer,
+get_linear_schedule_with_warmup,
+)
model_path = "model_path"
data_dir = "moss_002_sft"
@@ -16,8 +18,11 @@ epochs = 5
val_per_steps = 1000
lr = 9e-6
peft_config = LoraConfig(
-task_type=TaskType.CAUSAL_LM, r=32, lora_alpha=32, lora_dropout=0.1,
-target_modules=["gate_proj", "down_proj", "up_proj", "q_proj", "k_proj", "v_proj", "o_proj"]
+task_type=TaskType.CAUSAL_LM,
+r=32,
+lora_alpha=32,
+lora_dropout=0.1,
+target_modules=["gate_proj", "down_proj", "up_proj", "q_proj", "k_proj", "v_proj", "o_proj"],
)
@@ -29,12 +34,12 @@ model.cuda()
# dataset
train_dataset, val_dataset = get_dataset(tokenizer, data_dir, num=data_num, test_size=test_size)
-train_dataloader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True, collate_fn=lambda x: collate_fn(x, tokenizer))
+train_dataloader = DataLoader(
+train_dataset, batch_size=train_batch_size, shuffle=True, collate_fn=lambda x: collate_fn(x, tokenizer)
+)
optimizer = torch.optim.AdamW(model.parameters(), lr)
-scheduler = get_linear_schedule_with_warmup(
-optimizer, 1000, epochs * len(train_dataloader)
-)
+scheduler = get_linear_schedule_with_warmup(optimizer, 1000, epochs * len(train_dataloader))
# train
fp = open("output", "w")
@@ -42,7 +47,7 @@ model.train()
for epoch in tqdm(range(epochs), desc="Traning Epoch"):
batch_bar = tqdm(train_dataloader, desc="Training Batch")
for step, batch in enumerate(batch_bar):
-batch = {k:v.cuda() for k, v in batch.items()}
+batch = {k: v.cuda() for k, v in batch.items()}
with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
output = model(**batch)
@@ -58,7 +63,15 @@ for epoch in tqdm(range(epochs), desc="Traning Epoch"):
data, label = val_dataset[i]
prefix = tokenizer.decode(data.tolist(), skip_special_tokens=True)
try:
-generate = model.generate(input_ids=data.unsqueeze(0).cuda(), temperature=0.7, top_k=50, do_sample=True, repetition_penalty=1.02, max_new_tokens=100, top_p=0.9)
+generate = model.generate(
+input_ids=data.unsqueeze(0).cuda(),
+temperature=0.7,
+top_k=50,
+do_sample=True,
+repetition_penalty=1.02,
+max_new_tokens=100,
+top_p=0.9,
+)
text = tokenizer.decode(generate[0].tolist(), skip_special_tokens=True)
text = text.replace(prefix, "")
fp.write(f"Prefix: {prefix}\nGenerated: {text}" + "\n---------------------------------\n")

View File

@@ -1,9 +1,11 @@
-import os
import copy
+import os
import torch
+from datasets import Dataset as HFDataset
+from datasets import load_dataset
from torch.utils.data import Dataset
-from datasets import load_dataset, Dataset as HFDataset
class SFTDataset(Dataset):
# https://github.com/OpenLMLab/MOSS/blob/main/finetune_moss.py
@@ -26,21 +28,25 @@ class SFTDataset(Dataset):
return data, label
def collate_fn(batch, tokenizer):
batch_input_ids, batch_labels = [], []
for input_ids, label in batch:
batch_input_ids.append(input_ids)
batch_labels.append(label)
-batch_input_ids = torch.nn.utils.rnn.pad_sequence(batch_input_ids, batch_first=True, padding_value=tokenizer.eos_token_id)
+batch_input_ids = torch.nn.utils.rnn.pad_sequence(
+batch_input_ids, batch_first=True, padding_value=tokenizer.eos_token_id
+)
batch_labels = torch.nn.utils.rnn.pad_sequence(batch_labels, batch_first=True, padding_value=-100)
return {
"input_ids": batch_input_ids,
"attention_mask": (batch_input_ids == tokenizer.eos_token_id).long(),
-"labels": batch_labels
+"labels": batch_labels,
}
def process(sample, tokenizer, max_len):
chat = sample["plain_text"].split("<eoa>")[:-1]
num_turns = sample["num_turns"]
@@ -81,20 +87,20 @@ def load_data(save_dir, tokenizer, max_len, num=-1) -> HFDataset:
if os.path.exists(save_dir):
print(f"Loading moss-002-sft from {save_dir}")
else:
-print(f"Loading moss-002-sft from datasets")
+print("Loading moss-002-sft from datasets")
moss_sft = load_dataset("fnlp/moss-002-sft-data", split="train")
-moss_sft = moss_sft.map(lambda x:process(x, tokenizer, max_len), num_proc=10)
+moss_sft = moss_sft.map(lambda x: process(x, tokenizer, max_len), num_proc=10)
-moss_sft = moss_sft.filter(lambda x:len(x["input_ids"]) != 0)
+moss_sft = moss_sft.filter(lambda x: len(x["input_ids"]) != 0)
moss_sft.save_to_disk(save_dir)
moss_sft = HFDataset.load_from_disk(save_dir)
if num != -1:
moss_sft = moss_sft.select(range(num))
-print(
-f"Load successfully, total {len(moss_sft)} samples.")
+print(f"Load successfully, total {len(moss_sft)} samples.")
return moss_sft
def get_dataset(tokenizer, save_dir, max_len=1024, num=-1, test_size=0.1):
moss_sft_data = load_data(save_dir, tokenizer, max_len, num)
moss_sft_split = moss_sft_data.train_test_split(test_size=test_size)
@@ -102,4 +108,3 @@ def get_dataset(tokenizer, save_dir, max_len=1024, num=-1, test_size=0.1):
val_dataset = SFTDataset(moss_sft_split["test"])
return train_dataset, val_dataset

View File

@@ -19,26 +19,35 @@
# limitations under the License.
""" PyTorch InternLM model."""
import math
+import queue
+import threading
from typing import List, Optional, Tuple, Union
-import threading, queue
import torch
import torch.utils.checkpoint
+from configuration_internlm import InternLMConfig
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers.activations import ACT2FN
-from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
-from transformers.modeling_utils import PreTrainedModel
from transformers.generation.streamers import BaseStreamer
-from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
-from configuration_internlm import InternLMConfig
+from transformers.modeling_outputs import (
+BaseModelOutputWithPast,
+CausalLMOutputWithPast,
+SequenceClassifierOutputWithPast,
+)
+from transformers.modeling_utils import PreTrainedModel
+from transformers.utils import (
+add_start_docstrings,
+add_start_docstrings_to_model_forward,
+logging,
+replace_return_docstrings,
+)
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "InternLMConfig"
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
def _make_causal_mask(
input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
@@ -423,7 +432,7 @@ INTERNLM_INPUTS_DOCSTRING = r"""
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
-"""
+"""  # noqa: E501
@add_start_docstrings(
@@ -437,6 +446,7 @@ class InternLMModel(InternLMPreTrainedModel):
Args:
config: InternLMConfig
"""
_auto_class = "AutoModel"
def __init__(self, config: InternLMConfig):
@@ -776,7 +786,8 @@ class InternLMForCausalLM(InternLMPreTrainedModel):
return tokenizer([prompt], return_tensors="pt")
@torch.no_grad()
-def chat(self,
+def chat(
+self,
tokenizer,
query: str,
history: List[Tuple[str, str]] = [],
@@ -785,24 +796,28 @@ class InternLMForCausalLM(InternLMPreTrainedModel):
do_sample: bool = True,
temperature: float = 0.8,
top_p: float = 0.8,
-**kwargs):
+**kwargs,
+):
inputs = self.build_inputs(tokenizer, query, history)
inputs = {k: v.to(self.device) for k, v in inputs.items() if torch.is_tensor(v)}
-outputs = self.generate(**inputs,
+outputs = self.generate(
+**inputs,
streamer=streamer,
max_new_tokens=max_new_tokens,
do_sample=do_sample,
temperature=temperature,
top_p=top_p,
-**kwargs)
-outputs = outputs[0].cpu().tolist()[len(inputs["input_ids"][0]):]
+**kwargs,
+)
+outputs = outputs[0].cpu().tolist()[len(inputs["input_ids"][0]) :]
response = tokenizer.decode(outputs, skip_special_tokens=True)
response = response.split("<eoa>")[0]
history = history + [(query, response)]
return response, history
@torch.no_grad()
-def stream_chat(self,
+def stream_chat(
+self,
tokenizer,
query: str,
history: List[Tuple[str, str]] = [],
@@ -810,7 +825,8 @@ class InternLMForCausalLM(InternLMPreTrainedModel):
do_sample: bool = True,
temperature: float = 0.8,
top_p: float = 0.8,
-**kwargs):
+**kwargs,
+):
"""
Return a generator in format: (response, history)
Eg.
@@ -861,7 +877,7 @@ class InternLMForCausalLM(InternLMPreTrainedModel):
do_sample=do_sample,
temperature=temperature,
top_p=top_p,
-**kwargs
+**kwargs,
)
def consumer():

View File

@@ -24,11 +24,9 @@ from shutil import copyfile
from typing import Any, Dict, List, Optional, Tuple
import sentencepiece as spm
from transformers.tokenization_utils import PreTrainedTokenizer
from transformers.utils import logging
logger = logging.get_logger(__name__)
VOCAB_FILES_NAMES = {"vocab_file": "./tokenizer.model"}