Merge branch 'InternLM:main' into add-ja_readme

pull/43/head
Ikko Eltociear Ashimine 2023-07-14 15:14:27 +09:00 committed by GitHub
commit 0f5ee05717
13 changed files with 85 additions and 136 deletions


@@ -1,12 +1,9 @@
 blank_issues_enabled: false
 contact_links:
   - name: 📚 InternLM Documentation (official documentation)
     url: https://internlm.readthedocs.io/en/latest/
     about: Check if your question is answered in docs
   - name: 💬 General questions (seeking help)
     url: https://github.com/InternLM/InternLM/discussions
     about: Ask general usage questions and discuss with other InternLM community members
   - name: 🌐 Explore InternLM (official website)
-    url: https://https://internlm.org/
+    url: https://internlm.org/
     about: Get to know more about InternLM


@@ -21,6 +21,7 @@
 [🛠️Installation Guide](./doc/install.md) |
 [📊Training Performance](./doc/train_performance.md) |
 [👀Model Zoo](#model-zoo) |
+[🤗HuggingFace](https://huggingface.co/spaces/internlm/InternLM-Chat-7B) |
 [🆕Update News](./CHANGE_LOG.md) |
 [🤔Reporting Issues](https://github.com/InternLM/InternLM/issues/new)

@@ -34,7 +35,7 @@
 InternLM (书生·浦语) includes a 7-billion-parameter base model and a chat model (InternLM-7B) oriented toward practical scenarios. The model has the following features:
-- Uses trillions of tokens of high-quality corpus to build a powerful knowledge system;
+- Uses trillions of tokens of high-quality corpus to build a powerful knowledge system;
 - Supports an 8k context window, enabling longer inputs and a stronger reasoning experience;
 - General tool-calling capability, allowing users to flexibly build their own workflows;


@@ -21,6 +21,7 @@
 [🛠Installation](./doc/en/install.md) |
 [📊Train Performance](./doc/en/train_performance.md) |
 [👀Model](#model-zoo) |
+[🤗HuggingFace](https://huggingface.co/spaces/internlm/InternLM-Chat-7B) |
 [🆕Update News](./CHANGE_LOG.md) |
 [🤔Reporting Issues](https://github.com/InternLM/InternLM/issues/new)


@@ -8,7 +8,8 @@ The required packages and corresponding version are shown as follows:
 - CUDA == 11.7
 - Pytorch == 1.13.1+cu117
 - Transformers >= 4.25.1
-- Flash-Attention == 23.05
+- Flash-Attention == v1.0.5
+- Apex == 23.05
 - GPU with Ampere or Hopper architecture (such as H100, A100)
 - Linux OS
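
A quick way to check an existing environment against these pins might look like this (a hypothetical spot check; the module names are the ones the listed packages install):

```bash
$ python -c "import torch; print(torch.__version__, torch.version.cuda)"   # expect 1.13.1 and 11.7
$ python -c "import transformers; print(transformers.__version__)"         # expect >= 4.25.1
$ nvidia-smi --query-gpu=name --format=csv,noheader                        # expect an Ampere/Hopper GPU
```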


@@ -8,16 +8,16 @@ Please refer to the [installation guide](./install.md) for instructions on how t
 ### Dataset Preparation (Pre-training)

-The dataset for InternLM training consists of a series of `bin` and `meta` files. To generate the training dataset, you need to use the `tokenizer` tool to tokenize the raw text data. The tokenizer model can be imported by specifying the model path in the `tools/tokenizer.py` script. The current provided model is `V7.model`. If you want to use a different model, you can modify the model path directly in the `tokenizer.py` script.
+The dataset for the InternLM training task includes a series of `bin` and `meta` files. A `tokenizer` is used to generate the training dataset from the original text files. The tokenizer model is imported by specifying the model parameter path in `tools/tokenizer.py`. Currently, `V7_sft.model` is provided to generate tokens. If you want to use a different model, you can directly modify the model parameter path in `tokenizer.py`.

-You can generate the `bin` and `meta` files for your raw data by running the following command, where the `raw_data_name` parameter represents the name of your raw data file, `input_file_type` represents the format of your raw data file (currently supports `txt`, `json`, and `jsonl`), and `bin` represents the path to save the generated `bin` files.
+You can run the following command to generate `bin` and `meta` files corresponding to the original data. The parameter `text_input_path` represents the path of the original text data, currently supporting `txt`, `json`, and `jsonl` formats, while `bin_output_path` represents the save path of the generated `bin` files.

 ```bash
-$ python tools/tokenizer.py --raw_data_name your_raw_data_file_name(without suffix) --input_file_type 'txt' or 'json' or 'jsonl' --bin your_output_bin_path
+$ python tools/tokenizer.py --text_input_path your_input_text_path --bin_output_path your_output_bin_path
 ```

-Here is an example of data processing (only the data processing example for the `txt` format is provided here, the data processing process for `json` and `jsonl` is exactly the same as for `txt`):
+Here is an example of data processing:

 Given a file `raw_data.txt` containing the raw dataset, the raw dataset is shown below:

@@ -30,7 +30,7 @@ Learn to be tolerant and understanding to establish truly harmonious interperson
 You can generate the `bin` and `meta` files by running the following command:

 ```bash
-$ python tools/tokenizer.py --raw_data_name raw_data --input_file_type 'text' --bin cn/output.bin
+$ python tools/tokenizer.py --text_input_path raw_data.txt --bin_output_path cn/output.bin
 ```

 It should be noted that the generated `bin` files need to be saved in one of the following directories: `cn`, `en`, `code`, `ja`, `ar`, or `kaoshi`, depending on the type of dataset.
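
For reference, here is a minimal sketch of inspecting the generated files (assuming the `cn/output.bin` and `cn/output.bin.meta` produced above; per `tools/tokenizer.py`, each row of the meta array stores the byte offset and byte length of one tokenized sample in the `bin` file):

```bash
$ python - <<'EOF'
import json
import numpy as np

meta = np.load("cn/output.bin.meta")        # array of (offset, length) pairs, dtype int32
offset, length = meta[0]
with open("cn/output.bin", "rb") as f:
    f.seek(int(offset))
    print(json.loads(f.read(int(length))))  # first tokenized sample
EOF
```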
@@ -192,7 +192,7 @@ $ srun -p internllm -N 2 -n 16 --ntasks-per-node=8 --gpus-per-task=1 python trai
 If you want to start distributed training on torch with 8 GPUs on a single node, use the following command:

 ```bash
-$ torchrun --nnodes=1 --nproc_per_node=8 train.py --config ./configs/7B_sft.py
+$ torchrun --nnodes=1 --nproc_per_node=8 train.py --config ./configs/7B_sft.py --launcher "torch"
 ```

 ### Training Results


@@ -8,7 +8,8 @@
 - CUDA == 11.7
 - Pytorch == 1.13.1+cu117
 - Transformers >= 4.25.1
-- Flash-Attention == 23.05
+- Flash-Attention == v1.0.5
+- Apex == 23.05
 - GPU with Ampere or Hopper architecture (e.g., H100, A100)
 - Linux OS


@@ -7,14 +7,14 @@
 ### Dataset Preparation (Pre-training)

-The dataset for an InternLM training task consists of a series of `bin` and `meta` files. A `tokenizer` is used to generate the training dataset from the raw text files. The tokenizer model is imported by specifying its parameter path in `tools/tokenizer.py`; currently `V7.model` is provided to generate tokens. To use a different model, modify the model parameter path in `tokenizer.py` directly.
+The dataset for an InternLM training task consists of a series of `bin` and `meta` files. A `tokenizer` is used to generate the training dataset from the raw text files. The tokenizer model is imported by specifying its parameter path in `tools/tokenizer.py`; currently `V7_sft.model` is provided to generate tokens. To use a different model, modify the model parameter path in `tokenizer.py` directly.

-You can run the following command to generate the `bin` and `meta` files for the raw data, where `raw_data_name` is the file name of the raw dataset, `input_file_type` is the file format of the raw dataset (currently `txt`, `json`, and `jsonl` are supported), and `bin` is the save path of the generated `bin` files.
+You can run the following command to generate the `bin` and `meta` files for the raw data, where `text_input_path` is the path of the raw text data (currently `txt`, `json`, and `jsonl` input formats are supported) and `bin_output_path` is the save path of the generated `bin` files.

 ```bash
-$ python tools/tokenizer.py --raw_data_name your_raw_data_file_name(without suffix) --input_file_type 'txt' or 'json' or 'jsonl' --bin your_output_bin_path
+$ python tools/tokenizer.py --text_input_path your_input_text_path --bin_output_path your_output_bin_path
 ```

-Here is a data processing example (only the `txt` case is shown here; the processing for `json` and `jsonl` is exactly the same as for `txt`):
+Here is a data processing example:

 Given a file `raw_data.txt` containing the raw dataset, shown below:

 ```bash
@@ -25,7 +25,7 @@ $ python tools/tokenizer.py --raw_data_name your_raw_data_file_name(without suff
 You can generate the `bin` and `meta` files by running the following command:

 ```bash
-$ python tools/tokenizer.py --raw_data_name raw_data --input_file_type 'text' --bin cn/output.bin
+$ python tools/tokenizer.py --text_input_path raw_data.txt --bin_output_path cn/output.bin
 ```

 Note that the generated `bin` files must be saved under one of the directories `cn`, `en`, `code`, `ja`, `ar`, or `kaoshi`, to distinguish the dataset type.

@@ -175,7 +175,7 @@ $ srun -p internllm -N 2 -n 16 --ntasks-per-node=8 --gpus-per-task=1 python trai
 To launch a distributed run on torch with 8 GPUs on a single node, use the following command:

 ```bash
-$ torchrun --nnodes=1 --nproc_per_node=8 train.py --config ./configs/7B_sft.py
+$ torchrun --nnodes=1 --nproc_per_node=8 train.py --config ./configs/7B_sft.py --launcher "torch"
 ```

 ### Training Results


@@ -1,4 +1,4 @@
-transformers>=4.25.1
+transformers<4.30.0
 sentencepiece
 numpy
 tqdm
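
Applied directly, the new pin amounts to the following (a hypothetical one-liner; installing from the requirements file itself has the same effect):

```bash
$ pip install "transformers<4.30.0" sentencepiece numpy tqdm
```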

third_party/apex (vendored)

@@ -1 +1 @@
-Subproject commit 8ffc901e50bbf740fdb6d5bccb17f66a6ec8604e
+Subproject commit 0da3ffb92ee6fbe5336602f0e3989db1cd16f880

@@ -1 +1 @@
-Subproject commit d2f4324f4c56e017fbf22dc421943793a8ca6c3b
+Subproject commit eff9fe6b8076df59d64d7a3f464696738a3c7c24
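
After checking out a commit that moves vendored submodule pointers like these, the local copies are synced with the standard git command:

```bash
$ git submodule update --init --recursive
```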


@@ -9,14 +9,14 @@
 ```
 # tokenizer.py

-A `tokenizer` is needed to generate the `bin` and `meta` files for raw data. We import the tokenizer model by specifying its parameter path in `tools/tokenizer.py`; currently `V7.model` is provided to generate tokens. To use a different model, modify the model parameter path in `tokenizer.py` directly.
+A `tokenizer` is needed to generate the `bin` and `meta` files for raw data. We import the tokenizer model by specifying its parameter path in `tools/tokenizer.py`; currently `V7_sft.model` is provided to generate tokens. To use a different model, modify the model parameter path in `tokenizer.py` directly.

-We can run the following command to generate the `bin` and `meta` files for the raw data, where `raw_data_name` is the file name of the raw dataset, `input_file_type` is the file format of the raw dataset (we currently support `txt`, `json`, and `jsonl`), and `bin` is the save path of the generated `bin` files.
+You can run the following command to generate the `bin` and `meta` files for the raw data, where `text_input_path` is the path of the raw text data (currently `txt`, `json`, and `jsonl` input formats are supported) and `bin_output_path` is the save path of the generated `bin` files.

 ```bash
-$ python tools/tokenizer.py --raw_data_name your_raw_data_file_name(without suffix) --input_file_type 'text' or 'json' or 'jsonl' --bin your_output_bin_path
+$ python tools/tokenizer.py --text_input_path your_input_text_path --bin_output_path your_output_bin_path
 ```

-Here is a data processing example (only the `txt` case is shown here; the processing for `json` and `jsonl` is exactly the same as for `txt`):
+Here is a data processing example:

 Given a file `raw_data.txt` containing the raw dataset, shown below:

 ```bash
@@ -25,9 +25,9 @@ $ python tools/tokenizer.py --raw_data_name your_raw_data_file_name(without suff
 Learn to be tolerant and understanding to establish truly harmonious interpersonal relationships.
 ```

-Next, we can generate the `bin` and `meta` files by running the following command:
+We can generate the `bin` and `meta` files by running the following command:

 ```bash
-$ python tools/tokenizer.py --raw_data_name raw_data --input_file_type 'text' --bin cn/output.bin
+$ python tools/tokenizer.py --text_input_path raw_data.txt --bin_output_path cn/output.bin
 ```

 Note that the generated `bin` files must be saved under one of the directories `cn`, `en`, `code`, `ja`, `ar`, or `kaoshi`, to distinguish the dataset type.


@@ -11,12 +11,12 @@ This directory provide some tools for model training with the following file str
 # tokenizer.py

 We need to use a `tokenizer` to generate `bin` and `meta` files for raw data. We import the tokenizer model by specifying the model weight path in `tools/tokenizer.py`. Currently, we provide `V7.model` to generate tokens. If you want to use a different model, you can modify the model weight path in `tokenizer.py` directly.

-We can run the following command to generate `bin` and `meta` files for raw data, where the parameter `raw_data_name` indicates the file name of raw data, `input_file_type` denotes the raw data format, which should be `txt`, `json` and `jsonl`, and `bin` indicates the path to save the generated `bin` file.
+We can run the following command to generate `bin` and `meta` files corresponding to the original data. The parameter `text_input_path` represents the path of the original text data, currently supporting `txt`, `json`, and `jsonl` formats, while `bin_output_path` represents the save path of the generated `bin` files.

 ```bash
-$ python tools/tokenizer.py --raw_data_name your_raw_data_file_name(without suffix) --input_file_type 'text' or 'json' or 'jsonl' --bin your_output_bin_path
+$ python tools/tokenizer.py --text_input_path your_input_text_path --bin_output_path your_output_bin_path
 ```

-An example of data processing in `txt` format is given here (the data processing for `json` and `jsonl` is identical to that for `txt`).
+An example of data processing in `txt` format is given here:

 Given a file `raw_data.txt` containing raw data with the following content.

 ```bash
@@ -26,7 +26,7 @@ Learn to be tolerant and understanding to establish truly harmonious interperson
 ```

 Next, we can run the following command to generate `bin` and `meta` files for raw data.

 ```bash
-$ python tools/tokenizer.py --raw_data_name raw_data --input_file_type 'text' --bin cn/output.bin
+$ python tools/tokenizer.py --text_input_path your_input_text_path --bin_output_path your_output_bin_path
 ```

 It should be noted that the generated `bin` files should be placed in one of the following directories to clarify the data type: `cn`(Chinese), `en`(English), `code`(code data), `ja`(Japanese), `ar`(Arabic) and `kaoshi`(kaoshi data).
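
For the `jsonl` case, which the refactored `tools/tokenizer.py` handles through the same entry point, a minimal sketch could look like this (`demo.jsonl` and its contents are hypothetical; each line is one JSON record and is written to the `bin` file as-is):

```bash
$ cat demo.jsonl
{"question": "What is InternLM?", "answer": "A 7B base model and chat model."}
{"question": "What context length is supported?", "answer": "8k."}
$ python tools/tokenizer.py --text_input_path demo.jsonl --bin_output_path en/output.bin
```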


@@ -1,24 +1,25 @@
 import argparse
 import json
 import os
-import warnings
+import sys

 import numpy as np
-from sentencepiece import SentencePieceProcessor
-from termcolor import colored

 current_dir = os.path.dirname(os.path.abspath(__file__))
-model_path = os.path.join(current_dir, "V7.model")
-tokenizer = SentencePieceProcessor(model_file=model_path)
+model_path = os.path.join(current_dir, "V7_sft.model")
+sys.path.append(os.path.join(current_dir, "transformers"))
+from tokenization_internlm import InternLMTokenizer
+
+tokenizer = InternLMTokenizer(vocab_file=model_path)


-def write_bin(context: str, path: str) -> None:
+def write_bin(context: str, bin_file) -> None:
     """
-    Write bin file.
+    Write bin file based on the context.

     Args:
         context (str): the context of raw file.
-        path (str): the path for output bin file.
+        bin_file (file handler): the opened bin file.

     Example:
     >>> write_bin("今天天气晴朗适合出门散步", "out.bin") # the output file format is 'txt'
@@ -33,21 +34,20 @@ def write_bin(context: str, path: str) -> None:
     # encode the data into bytes to save
     saved_bin = str.encode(json.dumps(data) + "\n")

-    # write bytes into bin path
-    with open(path, "ab") as f:
-        f.write(saved_bin)
+    # write bytes into bin_file
+    bin_file.write(saved_bin)


-def prepare_meta(bin_file_path: str):
+def prepare_meta(bin_output_path: str):
     """
     Prepare metadata for the given bin file.

     Args:
-        bin_file_path (str): the bin file path.
+        bin_output_path (str): Output bin file path.
     """
     meta = []
     cur = 0
-    with open(bin_file_path, "rb") as f:
+    with open(bin_output_path, "rb") as f:
         while True:
             # read lines
             line = f.readline()
@@ -62,109 +62,66 @@ def prepare_meta(bin_file_path: str):
             meta.append((cur, length))
             # update the cur to generate the meta information of next line
             cur += len(line)
-    print(meta)

     # define path of the generated meta file
-    meta_fp = bin_file_path + ".meta"
+    meta_fp = bin_output_path + ".meta"
     # save the generated meta information
     with open(meta_fp, "wb") as f:
         meta = np.array(meta, dtype=np.int32)
         np.save(f, meta)


-def txt2bin(txt_file_path: str, bin_file_path: str):
+def text2bin(text_input_path: str, bin_output_path: str):
     """
-    Read content from txt file and write to bin file
+    Read content from the input file and write to bin file.
+    Currently support 3 input formats: 'txt', 'json' and 'jsonl'.

     Args:
-        txt_file_path (str): txt file path.
-        bin_file_path (str): output bin file path.
+        text_input_path (str): txt file path.
+        bin_output_path (str): output bin file path.
     """
     # Check if the txt file exists
-    if not os.path.isfile(txt_file_path):
-        warnings.warn(colored(f"{txt_file_path} does not exist.", "red"))
-        return
+    if not os.path.isfile(text_input_path):
+        raise FileNotFoundError(f"{text_input_path} does not exist.")

-    try:
-        # Open the text file
-        with open(txt_file_path, "r") as txt_file:
-            for line in txt_file:
+    file_format = text_input_path.split(".")[-1]
+    assert file_format in ["txt", "json", "jsonl"], print(
+        "Invalid input file type. Currently support `txt`, `json` and `jsonl`."
+    )
+
+    with open(text_input_path, "r") as text_file, open(bin_output_path, "ab") as bin_file:
+        if file_format == "txt":
+            for line in text_file:
                 # Strip any leading/trailing whitespace
                 stripped_line = line.strip()
                 if stripped_line:
                     # Pass each line to the write_bin function
-                    write_bin(stripped_line, bin_file_path)
+                    write_bin(stripped_line, bin_file)

-        print(colored(f"Successfully converted {txt_file_path} to {bin_file_path}", "green"))
-
-    except Exception as e:
-        print(colored(f"Error while converting {txt_file_path} to {bin_file_path}: {str(e)}", "red"))
-
-
-def json2bin(json_file_path: str, bin_file_path: str):
-    """
-    Read content from json file and write to bin file
-
-    Args:
-        json_file_path (str): json file path.
-        bin_file_path (str): output bin file path.
-    """
-
-    if not os.path.isfile(json_file_path):
-        warnings.warn(colored(f"{json_file_path} does not exist.", "red"))
-        return
-
-    try:
-        # load json file
-        with open(json_file_path, "r") as json_file:
-            data = json.load(json_file)
+        elif file_format == "json":
+            data = json.load(text_file)
             # assuming data is a list of dictionaries
             for record in data:
                 # the type of record is dict, transfer the dict into str
                 context = json.dumps(record)
                 # encode the str and write into bin
-                write_bin(context, bin_file_path)
+                write_bin(context, bin_file)

-        print(colored(f"Successfully converted {json_file_path} to {bin_file_path}", "green"))
-
-    except Exception as e:
-        print(colored(f"Error while converting {json_file_path} to {bin_file_path}: {str(e)}", "red"))
-
-
-def jsonl2bin(jsonl_file_path: str, bin_file_path: str):
-    """
-    Read content from jsonl file and write to bin file
-
-    Args:
-        jsonl_file_path: jsonl file path.
-        bin_file_path: bin file path.
-    """
-
-    if not os.path.isfile(jsonl_file_path):
-        warnings.warn(colored(f"{jsonl_file_path} does not exist.", "red"))
-        return
-
-    try:
-        with open(jsonl_file_path, "r") as jsonl_file:
-            for line in jsonl_file:
+        elif file_format == "jsonl":
+            for line in text_file:
                 # encode the str and write into bin
-                write_bin(line, bin_file_path)
-
-        print(colored(f"Successfully converted {jsonl_file_path} to {bin_file_path}", "green"))
-
-    except Exception as e:
-        print(colored(f"Error while converting {jsonl_file_path} to {bin_file_path}: {str(e)}", "red"))
+                write_bin(line, bin_file)


 def parse_args():
     parser = argparse.ArgumentParser()
-    parser.add_argument("--raw_data_name", required=True, help="Input file name")
     parser.add_argument(
-        "--input_file_type",
-        choices=["txt", "json", "jsonl"],
+        "--text_input_path",
+        type=str,
         required=True,
-        help="Input file format (either txt, json or jsonl)",
+        help="Path to the input text file.",
     )
-    parser.add_argument("--bin", required=True, help="Path to the output bin file")
+    parser.add_argument("--bin_output_path", type=str, required=True, help="Path to the output bin file.")

     return parser.parse_args()
@@ -173,21 +130,12 @@ def main():
     # parse arguments
     args = parse_args()

-    # obtain the raw data path
-    input_file_path = f"{args.raw_data_name}.{args.input_file_type}"
-
-    # different methods for different raw data type, we only support "txt", "json" and "jsonl" data type.
-    if args.input_file_type == "txt":
-        txt2bin(input_file_path, args.bin)
-    elif args.input_file_type == "json":
-        json2bin(input_file_path, args.bin)
-    elif args.input_file_type == "jsonl":
-        jsonl2bin(input_file_path, args.bin)
-    else:
-        print(colored("Invalid input file type. Use --help for more information.", "red"))
+    text2bin(args.text_input_path, args.bin_output_path)
+    print(f"Successfully converted {args.text_input_path} to {args.bin_output_path}")

     # To avoid potential read/write errors, the metadata preparation follows after creating the .bin file.
-    prepare_meta(args.bin)
+    prepare_meta(args.bin_output_path)
+    print(f"Successfully generated {args.bin_output_path}.meta")


 if __name__ == "__main__":